diff --git a/extras/outbounds/acl/compile.go b/extras/outbounds/acl/compile.go index 40d0069..cdd0a94 100644 --- a/extras/outbounds/acl/compile.go +++ b/extras/outbounds/acl/compile.go @@ -41,7 +41,8 @@ type compiledRule[O Outbound] struct { Outbound O HostMatcher hostMatcher Protocol Protocol - Port uint16 + StartPort uint16 + EndPort uint16 HijackAddress net.IP } @@ -49,7 +50,7 @@ func (r *compiledRule[O]) Match(host HostInfo, proto Protocol, port uint16) bool if r.Protocol != ProtocolBoth && r.Protocol != proto { return false } - if r.Port != 0 && r.Port != port { + if r.StartPort != 0 && (port < r.StartPort || port > r.EndPort) { return false } return r.HostMatcher.Match(host) @@ -100,10 +101,9 @@ type GeoLoader interface { // Compile compiles TextRules into a CompiledRuleSet. // Names in the outbounds map MUST be in all lower case. -// geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher. -// It will be called every time a GeoIP matcher is used during compilation, but won't -// be called if there is no GeoIP rule. We use a function here so that database loading -// is on-demand (only required if used by rules). +// We want on-demand loading of GeoIP/GeoSite databases, so instead of passing the +// databases directly, we use a GeoLoader interface to load them only when needed +// by at least one rule. func Compile[O Outbound](rules []TextRule, outbounds map[string]O, cacheSize int, geoLoader GeoLoader, ) (CompiledRuleSet[O], error) { @@ -117,7 +117,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, if errStr != "" { return nil, &CompilationError{rule.LineNum, errStr} } - proto, port, ok := parseProtoPort(rule.ProtoPort) + proto, startPort, endPort, ok := parseProtoPort(rule.ProtoPort) if !ok { return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)} } @@ -128,7 +128,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, return nil, &CompilationError{rule.LineNum, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)} } } - compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress} + compiledRules[i] = compiledRule[O]{outbound, hm, proto, startPort, endPort, hijackAddress} } cache, err := lru.New[string, matchResult[O]](cacheSize) if err != nil { @@ -149,26 +149,26 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, // [empty] (same as *) // // proto must be either "tcp" or "udp", case-insensitive. -func parseProtoPort(protoPort string) (Protocol, uint16, bool) { +func parseProtoPort(protoPort string) (Protocol, uint16, uint16, bool) { protoPort = strings.ToLower(protoPort) if protoPort == "" || protoPort == "*" || protoPort == "*/*" { - return ProtocolBoth, 0, true + return ProtocolBoth, 0, 0, true } parts := strings.SplitN(protoPort, "/", 2) if len(parts) == 1 { // No port, only protocol switch parts[0] { case "tcp": - return ProtocolTCP, 0, true + return ProtocolTCP, 0, 0, true case "udp": - return ProtocolUDP, 0, true + return ProtocolUDP, 0, 0, true default: - return ProtocolBoth, 0, false + return ProtocolBoth, 0, 0, false } } else { // Both protocol and port var proto Protocol - var port uint16 + var startPort, endPort uint16 switch parts[0] { case "tcp": proto = ProtocolTCP @@ -177,16 +177,35 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) { case "*": proto = ProtocolBoth default: - return ProtocolBoth, 0, false + return ProtocolBoth, 0, 0, false } if parts[1] != "*" { - p64, err := strconv.ParseUint(parts[1], 10, 16) - if err != nil { - return ProtocolBoth, 0, false + // We allow either a single port or a range (e.g. "1000-2000") + ports := strings.SplitN(strings.TrimSpace(parts[1]), "-", 2) + if len(ports) == 1 { + p64, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + startPort = uint16(p64) + endPort = startPort + } else { + p64, err := strconv.ParseUint(ports[0], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + startPort = uint16(p64) + p64, err = strconv.ParseUint(ports[1], 10, 16) + if err != nil { + return ProtocolBoth, 0, 0, false + } + endPort = uint16(p64) + if startPort > endPort { + return ProtocolBoth, 0, 0, false + } } - port = uint16(p64) } - return proto, port, true + return proto, startPort, endPort, true } } diff --git a/extras/outbounds/acl/compile_test.go b/extras/outbounds/acl/compile_test.go index 772f8b6..a11c4e8 100644 --- a/extras/outbounds/acl/compile_test.go +++ b/extras/outbounds/acl/compile_test.go @@ -22,7 +22,7 @@ func (l *testGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { } func TestCompile(t *testing.T) { - ob1, ob2, ob3, ob4, ob5 := 1, 2, 3, 4, 5 + ob1, ob2, ob3, ob4, ob5, ob6 := 1, 2, 3, 4, 5, 6 rules := []TextRule{ { Outbound: "ob1", @@ -90,6 +90,12 @@ func TestCompile(t *testing.T) { ProtoPort: "*/*", HijackAddress: "", }, + { + Outbound: "ob6", + Address: "all", + ProtoPort: "tcp/6881-6889", + HijackAddress: "", + }, } comp, err := Compile[int](rules, map[string]int{ "ob1": ob1, @@ -97,6 +103,7 @@ func TestCompile(t *testing.T) { "ob3": ob3, "ob4": ob4, "ob5": ob5, + "ob6": ob6, }, 100, &testGeoLoader{}) assert.NoError(t, err) @@ -242,6 +249,15 @@ func TestCompile(t *testing.T) { wantOutbound: 0, // no match default wantIP: nil, }, + { + host: HostInfo{ + IPv4: net.ParseIP("223.1.1.1"), + }, + proto: ProtocolTCP, + port: 6883, + wantOutbound: ob6, // match range port rule 6881-6889 + wantIP: nil, + }, } for _, test := range tests { @@ -249,6 +265,22 @@ func TestCompile(t *testing.T) { assert.Equal(t, test.wantOutbound, gotOutbound) assert.Equal(t, test.wantIP, gotIP) } + + // Test Invalid Port Range Rule + eb1 := 1 + invalidRules := []TextRule{ + { + Outbound: "eb1", + Address: "1.1.2.0/24", + ProtoPort: "*/3-1", + HijackAddress: "", + }, + } + + _, err = Compile[int](invalidRules, map[string]int{ + "eb1": eb1, + }, 100, &testGeoLoader{}) + assert.Error(t, err) } func Test_parseGeoSiteName(t *testing.T) {