diff --git a/extras/go.mod b/extras/go.mod index 4143c22..90db4aa 100644 --- a/extras/go.mod +++ b/extras/go.mod @@ -5,7 +5,9 @@ go 1.20 require ( github.com/apernet/hysteria/core v0.0.0-00010101000000-000000000000 github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6 + github.com/hashicorp/golang-lru/v2 v2.0.5 github.com/miekg/dns v1.1.55 + github.com/oschwald/geoip2-golang v1.9.0 github.com/stretchr/testify v1.8.4 golang.org/x/crypto v0.11.0 ) @@ -17,6 +19,7 @@ require ( github.com/golang/mock v1.6.0 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/oschwald/maxminddb-golang v1.11.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect github.com/quic-go/qtls-go1-20 v0.3.1 // indirect diff --git a/extras/go.sum b/extras/go.sum index 6fba549..cfcd0c5 100644 --- a/extras/go.sum +++ b/extras/go.sum @@ -17,6 +17,8 @@ github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38/go.mod h1:kpwsk12EmLew5upagYY7GY0pfYCcupk39gWOCRROcvE= +github.com/hashicorp/golang-lru/v2 v2.0.5 h1:wW7h1TG88eUIJ2i69gaE3uNVtEPIagzhGvHgwfx2Vm4= +github.com/hashicorp/golang-lru/v2 v2.0.5/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/miekg/dns v1.1.55 h1:GoQ4hpsj0nFLYe+bWiCToyrBEJXkQfOOIvFGFy0lEgo= @@ -25,6 +27,10 @@ github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e h1:fD57ERR4JtEqsWb github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= +github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc= +github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y= +github.com/oschwald/maxminddb-golang v1.11.0 h1:aSXMqYR/EPNjGE8epgqwDay+P30hCBZIveY0WZbAWh0= +github.com/oschwald/maxminddb-golang v1.11.0/go.mod h1:YmVI+H0zh3ySFR3w+oz8PCfglAFj3PuCmui13+P9zDg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/quic-go/qpack v0.4.0 h1:Cr9BXA1sQS2SmDUWjSofMPNKmvF6IiIfDRmgU0w1ZCo= diff --git a/extras/outbounds/acl/GeoLite2-Country.mmdb b/extras/outbounds/acl/GeoLite2-Country.mmdb new file mode 100644 index 0000000..0eb0e50 Binary files /dev/null and b/extras/outbounds/acl/GeoLite2-Country.mmdb differ diff --git a/extras/outbounds/acl/compile.go b/extras/outbounds/acl/compile.go new file mode 100644 index 0000000..359add1 --- /dev/null +++ b/extras/outbounds/acl/compile.go @@ -0,0 +1,223 @@ +package acl + +import ( + "fmt" + "net" + "strconv" + "strings" + + lru "github.com/hashicorp/golang-lru/v2" + "github.com/oschwald/geoip2-golang" +) + +type protocol int + +const ( + protocolBoth protocol = iota + protocolTCP + protocolUDP +) + +type Outbound interface { + any +} + +type HostInfo struct { + Name string + IPv4 net.IP + IPv6 net.IP +} + +func (h HostInfo) String() string { + return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6) +} + +type CompiledRuleSet[O Outbound] interface { + Match(host HostInfo, proto protocol, port uint16) (O, net.IP) +} + +type compiledRule[O Outbound] struct { + Outbound O + HostMatcher hostMatcher + Protocol protocol + Port uint16 + HijackAddress net.IP +} + +func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool { + if r.Protocol != protocolBoth && r.Protocol != proto { + return false + } + if r.Port != 0 && r.Port != port { + return false + } + return r.HostMatcher.Match(host) +} + +type matchResult[O Outbound] struct { + Outbound O + HijackAddress net.IP +} + +type compiledRuleSetImpl[O Outbound] struct { + Rules []compiledRule[O] + Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String() +} + +func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) { + host.Name = strings.ToLower(host.Name) // Normalize host name to lower case + key := host.String() + if result, ok := s.Cache.Get(key); ok { + return result.Outbound, result.HijackAddress + } + for _, rule := range s.Rules { + if rule.Match(host, proto, port) { + result := matchResult[O]{rule.Outbound, rule.HijackAddress} + s.Cache.Add(key, result) + return result.Outbound, result.HijackAddress + } + } + // No match should also be cached + var zero O + s.Cache.Add(key, matchResult[O]{zero, nil}) + return zero, nil +} + +type CompilationError struct { + Index int + Message string +} + +func (e *CompilationError) Error() string { + return fmt.Sprintf("error at index %d: %s", e.Index, e.Message) +} + +func Compile[O Outbound](rules []TextRule, outbounds map[string]O, + cacheSize int, geoipFunc func() *geoip2.Reader, +) (CompiledRuleSet[O], error) { + compiledRules := make([]compiledRule[O], len(rules)) + for i, rule := range rules { + outbound, ok := outbounds[rule.Outbound] + if !ok { + return nil, &CompilationError{i, fmt.Sprintf("outbound %s not found", rule.Outbound)} + } + hm, errStr := compileHostMatcher(rule.Address, geoipFunc) + if errStr != "" { + return nil, &CompilationError{i, errStr} + } + proto, port, ok := parseProtoPort(rule.ProtoPort) + if !ok { + return nil, &CompilationError{i, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)} + } + var hijackAddress net.IP + if rule.HijackAddress != "" { + hijackAddress = net.ParseIP(rule.HijackAddress) + if hijackAddress == nil { + return nil, &CompilationError{i, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)} + } + } + compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress} + } + cache, err := lru.New[string, matchResult[O]](cacheSize) + if err != nil { + return nil, err + } + return &compiledRuleSetImpl[O]{compiledRules, cache}, nil +} + +// parseProtoPort parses the protocol and port from a protoPort string. +// protoPort must be in one of the following formats: +// +// proto/port +// proto/* +// proto +// */port +// */* +// * +// [empty] (same as *) +// +// proto must be either "tcp" or "udp", case-insensitive. +func parseProtoPort(protoPort string) (protocol, uint16, bool) { + protoPort = strings.ToLower(protoPort) + if protoPort == "" || protoPort == "*" || protoPort == "*/*" { + return protocolBoth, 0, true + } + parts := strings.SplitN(protoPort, "/", 2) + if len(parts) == 1 { + // No port, only protocol + switch parts[0] { + case "tcp": + return protocolTCP, 0, true + case "udp": + return protocolUDP, 0, true + default: + return protocolBoth, 0, false + } + } else { + // Both protocol and port + var proto protocol + var port uint16 + switch parts[0] { + case "tcp": + proto = protocolTCP + case "udp": + proto = protocolUDP + case "*": + proto = protocolBoth + default: + return protocolBoth, 0, false + } + if parts[1] != "*" { + p64, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return protocolBoth, 0, false + } + port = uint16(p64) + } + return proto, port, true + } +} + +func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatcher, string) { + addr = strings.ToLower(addr) // Normalize to lower case + if addr == "*" || addr == "all" { + // Match all hosts + return &allMatcher{}, "" + } + if strings.HasPrefix(addr, "geoip:") { + // GeoIP matcher + country := strings.ToUpper(addr[6:]) + if len(country) != 2 { + return nil, fmt.Sprintf("invalid country code: %s", country) + } + db := geoipFunc() + if db == nil { + return nil, "failed to load GeoIP database" + } + return &geoIPMatcher{db, country}, "" + } + if strings.Contains(addr, "/") { + // CIDR matcher + _, ipnet, err := net.ParseCIDR(addr) + if err != nil { + return nil, fmt.Sprintf("invalid CIDR address: %s", addr) + } + return &cidrMatcher{ipnet}, "" + } + if ip := net.ParseIP(addr); ip != nil { + // Single IP matcher + return &ipMatcher{ip}, "" + } + if strings.Contains(addr, "*") { + // Wildcard domain matcher + return &domainMatcher{ + Pattern: addr, + Wildcard: true, + }, "" + } + // Nothing else matched, treat it as a non-wildcard domain + return &domainMatcher{ + Pattern: addr, + Wildcard: false, + }, "" +} diff --git a/extras/outbounds/acl/compile_test.go b/extras/outbounds/acl/compile_test.go new file mode 100644 index 0000000..6f41112 --- /dev/null +++ b/extras/outbounds/acl/compile_test.go @@ -0,0 +1,156 @@ +package acl + +import ( + "net" + "testing" + + "github.com/oschwald/geoip2-golang" + "github.com/stretchr/testify/assert" +) + +func TestCompile(t *testing.T) { + ob1, ob2, ob3 := 1, 2, 3 + rules := []TextRule{ + { + Outbound: "ob1", + Address: "1.2.3.4", + ProtoPort: "", + HijackAddress: "", + }, + { + Outbound: "ob2", + Address: "8.8.8.0/24", + ProtoPort: "*", + HijackAddress: "1.1.1.1", + }, + { + Outbound: "ob3", + Address: "all", + ProtoPort: "udp/443", + HijackAddress: "", + }, + { + Outbound: "ob1", + Address: "2606:4700::6810:85e5", + ProtoPort: "tcp", + HijackAddress: "2606:4700::6810:85e6", + }, + { + Outbound: "ob2", + Address: "2606:4700::/44", + ProtoPort: "*/8888", + HijackAddress: "", + }, + { + Outbound: "ob3", + Address: "*.v2ex.com", + ProtoPort: "udp", + HijackAddress: "", + }, + { + Outbound: "ob1", + Address: "crap.v2ex.com", + ProtoPort: "tcp/80", + HijackAddress: "2.2.2.2", + }, + { + Outbound: "ob2", + Address: "geoip:JP", + ProtoPort: "*/*", + HijackAddress: "", + }, + } + reader, err := geoip2.Open("GeoLite2-Country.mmdb") + assert.NoError(t, err) + comp, err := Compile[int](rules, map[string]int{"ob1": ob1, "ob2": ob2, "ob3": ob3}, 100, func() *geoip2.Reader { + return reader + }) + assert.NoError(t, err) + + tests := []struct { + host HostInfo + proto protocol + port uint16 + wantOutbound int + wantIP net.IP + }{ + { + host: HostInfo{ + IPv4: net.ParseIP("1.2.3.4"), + }, + proto: protocolTCP, + port: 1234, + wantOutbound: ob1, + wantIP: nil, + }, + { + host: HostInfo{ + IPv4: net.ParseIP("8.8.8.4"), + }, + proto: protocolUDP, + port: 5353, + wantOutbound: ob2, + wantIP: net.ParseIP("1.1.1.1"), + }, + { + host: HostInfo{ + Name: "lean.delicious.com", + }, + proto: protocolUDP, + port: 443, + wantOutbound: ob3, + wantIP: nil, + }, + { + host: HostInfo{ + IPv6: net.ParseIP("2606:4700::6810:85e5"), + }, + proto: protocolTCP, + port: 80, + wantOutbound: ob1, + wantIP: net.ParseIP("2606:4700::6810:85e6"), + }, + { + host: HostInfo{ + IPv6: net.ParseIP("2606:4700:0:0:0:0:0:1"), + }, + proto: protocolUDP, + port: 8888, + wantOutbound: ob2, + wantIP: nil, + }, + { + host: HostInfo{ + Name: "www.v2ex.com", + }, + proto: protocolUDP, + port: 1234, + wantOutbound: ob3, + wantIP: nil, + }, + { + host: HostInfo{ + Name: "crap.v2ex.com", + }, + proto: protocolTCP, + port: 80, + wantOutbound: ob1, + wantIP: net.ParseIP("2.2.2.2"), + }, + { + host: HostInfo{ + IPv4: net.ParseIP("210.140.92.187"), + }, + proto: protocolTCP, + port: 25, + wantOutbound: ob2, + wantIP: nil, + }, + } + + for _, test := range tests { + gotOutbound, gotIP := comp.Match(test.host, test.proto, test.port) + assert.Equal(t, test.wantOutbound, gotOutbound) + assert.Equal(t, test.wantIP, gotIP) + } +} diff --git a/extras/outbounds/acl/matchers.go b/extras/outbounds/acl/matchers.go new file mode 100644 index 0000000..33ec53a --- /dev/null +++ b/extras/outbounds/acl/matchers.go @@ -0,0 +1,83 @@ +package acl + +import ( + "net" + + "github.com/oschwald/geoip2-golang" +) + +type hostMatcher interface { + Match(HostInfo) bool +} + +type ipMatcher struct { + IP net.IP +} + +func (m *ipMatcher) Match(host HostInfo) bool { + return m.IP.Equal(host.IPv4) || m.IP.Equal(host.IPv6) +} + +type cidrMatcher struct { + IPNet *net.IPNet +} + +func (m *cidrMatcher) Match(host HostInfo) bool { + return m.IPNet.Contains(host.IPv4) || m.IPNet.Contains(host.IPv6) +} + +type domainMatcher struct { + Pattern string + Wildcard bool +} + +func (m *domainMatcher) Match(host HostInfo) bool { + if m.Wildcard { + return deepMatchRune([]rune(host.Name), []rune(m.Pattern)) + } + return m.Pattern == host.Name +} + +func deepMatchRune(str, pattern []rune) bool { + for len(pattern) > 0 { + switch pattern[0] { + default: + if len(str) == 0 || str[0] != pattern[0] { + return false + } + case '*': + return deepMatchRune(str, pattern[1:]) || + (len(str) > 0 && deepMatchRune(str[1:], pattern)) + } + str = str[1:] + pattern = pattern[1:] + } + return len(str) == 0 && len(pattern) == 0 +} + +type geoIPMatcher struct { + DB *geoip2.Reader + Country string // must be uppercase ISO 3166-1 alpha-2 code +} + +func (m *geoIPMatcher) Match(host HostInfo) bool { + if host.IPv4 != nil { + record, err := m.DB.Country(host.IPv4) + if err == nil && record.Country.IsoCode == m.Country { + return true + } + } + if host.IPv6 != nil { + record, err := m.DB.Country(host.IPv6) + if err == nil && record.Country.IsoCode == m.Country { + return true + } + } + return false +} + +type allMatcher struct{} + +func (m *allMatcher) Match(host HostInfo) bool { + return true +} diff --git a/extras/outbounds/acl/matchers_test.go b/extras/outbounds/acl/matchers_test.go new file mode 100644 index 0000000..5ec3884 --- /dev/null +++ b/extras/outbounds/acl/matchers_test.go @@ -0,0 +1,310 @@ +package acl + +import ( + "net" + "testing" + + "github.com/oschwald/geoip2-golang" + "github.com/stretchr/testify/assert" +) + +func Test_ipMatcher_Match(t *testing.T) { + tests := []struct { + name string + IP net.IP + host HostInfo + want bool + }{ + { + name: "ipv4 match", + IP: net.IPv4(127, 0, 0, 1), + host: HostInfo{ + IPv4: net.IPv4(127, 0, 0, 1), + IPv6: nil, + }, + want: true, + }, + { + name: "ipv6 match", + IP: net.IPv6loopback, + host: HostInfo{ + IPv4: nil, + IPv6: net.IPv6loopback, + }, + want: true, + }, + { + name: "no match", + IP: net.IPv4(127, 0, 0, 1), + host: HostInfo{ + IPv4: net.IPv4(127, 0, 0, 2), + IPv6: net.IPv6loopback, + }, + want: false, + }, + { + name: "both nil", + IP: net.IPv4(127, 0, 0, 1), + host: HostInfo{ + IPv4: nil, + IPv6: nil, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &ipMatcher{ + IP: tt.IP, + } + if got := m.Match(tt.host); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_cidrMatcher_Match(t *testing.T) { + _, cidr1, _ := net.ParseCIDR("192.168.1.0/24") + _, cidr2, _ := net.ParseCIDR("::1/128") + _, cidr3, _ := net.ParseCIDR("0.0.0.0/0") + _, cidr4, _ := net.ParseCIDR("::/0") + + tests := []struct { + name string + IPNet *net.IPNet + host HostInfo + want bool + }{ + { + name: "ipv4 match", + IPNet: cidr1, + host: HostInfo{ + IPv4: net.ParseIP("192.168.1.100"), + IPv6: net.ParseIP("::1"), + }, + want: true, + }, + { + name: "ipv6 match", + IPNet: cidr2, + host: HostInfo{ + IPv4: net.ParseIP("10.0.0.1"), + IPv6: net.ParseIP("::1"), + }, + want: true, + }, + { + name: "no match", + IPNet: cidr1, + host: HostInfo{ + IPv4: net.ParseIP("10.0.0.1"), + IPv6: net.ParseIP("2001:db8::2:1"), + }, + want: false, + }, + { + name: "ipv4 broad", + IPNet: cidr3, + host: HostInfo{ + IPv4: net.ParseIP("10.0.0.1"), + IPv6: net.ParseIP("::1"), + }, + want: true, + }, + { + name: "ipv6 broad", + IPNet: cidr4, + host: HostInfo{ + IPv4: net.ParseIP("10.0.0.1"), + IPv6: net.ParseIP("2001:db8::2:1"), + }, + want: true, + }, + { + name: "both nil", + IPNet: cidr1, + host: HostInfo{ + IPv4: nil, + IPv6: nil, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &cidrMatcher{ + IPNet: tt.IPNet, + } + if got := m.Match(tt.host); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_domainMatcher_Match(t *testing.T) { + type fields struct { + Pattern string + Wildcard bool + } + tests := []struct { + name string + fields fields + host HostInfo + want bool + }{ + { + name: "non-wildcard match", + fields: fields{ + Pattern: "example.com", + Wildcard: false, + }, + host: HostInfo{ + Name: "example.com", + }, + want: true, + }, + { + name: "non-wildcard no match", + fields: fields{ + Pattern: "example.com", + Wildcard: false, + }, + host: HostInfo{ + Name: "example.org", + }, + want: false, + }, + { + name: "wildcard match 1", + fields: fields{ + Pattern: "*.example.com", + Wildcard: true, + }, + host: HostInfo{ + Name: "www.example.com", + }, + want: true, + }, + { + name: "wildcard match 2", + fields: fields{ + Pattern: "example*.com", + Wildcard: true, + }, + host: HostInfo{ + Name: "example2.com", + }, + want: true, + }, + { + name: "wildcard no match", + fields: fields{ + Pattern: "*.example.com", + Wildcard: true, + }, + host: HostInfo{ + Name: "example.com", + }, + want: false, + }, + { + name: "empty", + fields: fields{ + Pattern: "*.example.com", + Wildcard: true, + }, + host: HostInfo{ + Name: "", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &domainMatcher{ + Pattern: tt.fields.Pattern, + Wildcard: tt.fields.Wildcard, + } + if got := m.Match(tt.host); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_geoIPMatcher_Match(t *testing.T) { + db, err := geoip2.Open("GeoLite2-Country.mmdb") + assert.NoError(t, err) + defer db.Close() + + type fields struct { + DB *geoip2.Reader + Country string + } + tests := []struct { + name string + fields fields + host HostInfo + want bool + }{ + { + name: "ipv4 match", + fields: fields{ + DB: db, + Country: "JP", + }, + host: HostInfo{ + IPv4: net.ParseIP("210.140.92.181"), + }, + want: true, + }, + { + name: "ipv6 match", + fields: fields{ + DB: db, + Country: "US", + }, + host: HostInfo{ + IPv6: net.ParseIP("2606:4700::6810:85e5"), + }, + want: true, + }, + { + name: "no match", + fields: fields{ + DB: db, + Country: "AU", + }, + host: HostInfo{ + IPv4: net.ParseIP("210.140.92.181"), + IPv6: net.ParseIP("2606:4700::6810:85e5"), + }, + want: false, + }, + { + name: "both nil", + fields: fields{ + DB: db, + Country: "KR", + }, + host: HostInfo{ + IPv4: nil, + IPv6: nil, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &geoIPMatcher{ + DB: tt.fields.DB, + Country: tt.fields.Country, + } + if got := m.Match(tt.host); got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/extras/outbounds/acl/parse.go b/extras/outbounds/acl/parse.go new file mode 100644 index 0000000..f9b4eef --- /dev/null +++ b/extras/outbounds/acl/parse.go @@ -0,0 +1,79 @@ +package acl + +import ( + "fmt" + "os" + "regexp" + "strings" +) + +var linePattern = regexp.MustCompile(`^(\w+)\s*\(([^,]+)(?:,([^,]+))?(?:,([^,]+))?\)$`) + +type InvalidSyntaxError struct { + Line string + LineNum int +} + +func (e *InvalidSyntaxError) Error() string { + return fmt.Sprintf("invalid syntax at line %d: %s", e.LineNum, e.Line) +} + +// TextRule is the struct representation of a (non-comment) line parsed from an ACL file. +// A line can be parsed into a TextRule as long as it matches one of the following patterns: +// +// outbound(address) +// outbound(address,protoPort) +// outbound(address,protoPort,hijackAddress) +// +// It does not check whether any of the fields is valid - it's up to the compiler to do so. +type TextRule struct { + Outbound string + Address string + ProtoPort string + HijackAddress string +} + +func parseLine(line string) *TextRule { + matches := linePattern.FindStringSubmatch(line) + if matches == nil { + return nil + } + return &TextRule{ + Outbound: matches[1], + Address: strings.TrimSpace(matches[2]), + ProtoPort: strings.TrimSpace(matches[3]), + HijackAddress: strings.TrimSpace(matches[4]), + } +} + +func ParseTextRules(text string) ([]TextRule, error) { + rules := make([]TextRule, 0) + lineNum := 0 + for _, line := range strings.Split(text, "\n") { + lineNum++ + // Remove comments + if i := strings.Index(line, "#"); i >= 0 { + line = line[:i] + } + line = strings.TrimSpace(line) + // Skip empty lines + if len(line) == 0 { + continue + } + // Parse line + rule := parseLine(line) + if rule == nil { + return nil, &InvalidSyntaxError{line, lineNum} + } + rules = append(rules, *rule) + } + return rules, nil +} + +func ParseTextRulesFile(filename string) ([]TextRule, error) { + bs, err := os.ReadFile(filename) + if err != nil { + return nil, err + } + return ParseTextRules(string(bs)) +} diff --git a/extras/outbounds/acl/parse_test.go b/extras/outbounds/acl/parse_test.go new file mode 100644 index 0000000..250e8a1 --- /dev/null +++ b/extras/outbounds/acl/parse_test.go @@ -0,0 +1,70 @@ +package acl + +import ( + "reflect" + "testing" +) + +func TestParseTextRules(t *testing.T) { + tests := []struct { + name string + text string + want []TextRule + wantErr bool + }{ + { + name: "empty", + text: "", + want: []TextRule{}, + wantErr: false, + }, + { + name: "ok", + text: ` +# just a comment + # another comment +direct(1.1.1.1) +direct(8.8.8.0/24) +reject(all, udp/443) # inline comment + reject(geoip:cn) + reject(*.v2ex.com) +my_custom_outbound1(9.9.9.9,*, 8.8.8.8) # bebop +my_custom_outbound2(all) +`, + want: []TextRule{ + {Outbound: "direct", Address: "1.1.1.1"}, + {Outbound: "direct", Address: "8.8.8.0/24"}, + {Outbound: "reject", Address: "all", ProtoPort: "udp/443"}, + {Outbound: "reject", Address: "geoip:cn"}, + {Outbound: "reject", Address: "*.v2ex.com"}, + {Outbound: "my_custom_outbound1", Address: "9.9.9.9", ProtoPort: "*", HijackAddress: "8.8.8.8"}, + {Outbound: "my_custom_outbound2", Address: "all"}, + }, + wantErr: false, + }, + { + name: "fail 1", + text: `boom()`, + want: nil, + wantErr: true, + }, + { + name: "fail 2", + text: `lol(1,1,1,1)`, + want: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTextRules(tt.text) + if (err != nil) != tt.wantErr { + t.Errorf("ParseTextRules() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("ParseTextRules() got = %v, want %v", got, tt.want) + } + }) + } +}