diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index c74d8b4..4b758a6 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -20,7 +20,13 @@ type Engine struct { GeoIPReader *geoip2.Reader } -type cacheEntry struct { +type cacheKey struct { + Host string + Port uint16 + IsUDP bool +} + +type cacheValue struct { Action Action Arg string } @@ -44,7 +50,7 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro if err != nil { return nil, err } - if len(entry.Country) > 0 && geoIPReader == nil { + if _, ok := entry.Matcher.(*countryMatcher); ok && geoIPReader == nil { geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed if err != nil { return nil, err @@ -66,44 +72,69 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro } // action, arg, isDomain, resolvedIP, error -func (e *Engine) ResolveAndMatch(host string) (Action, string, bool, *net.IPAddr, error) { +func (e *Engine) ResolveAndMatch(host string, port uint16, isUDP bool) (Action, string, bool, *net.IPAddr, error) { ip, zone := utils.ParseIPZone(host) if ip == nil { // Domain ipAddr, err := e.ResolveIPAddr(host) - if v, ok := e.Cache.Get(host); ok { + if v, ok := e.Cache.Get(cacheKey{host, port, isUDP}); ok { // Cache hit - ce := v.(cacheEntry) + ce := v.(cacheValue) return ce.Action, ce.Arg, true, ipAddr, err } for _, entry := range e.Entries { - if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) { - e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) + mReq := MatchRequest{ + Domain: host, + Port: port, + DB: e.GeoIPReader, + } + if ipAddr != nil { + mReq.IP = ipAddr.IP + } + if isUDP { + mReq.Protocol = ProtocolUDP + } else { + mReq.Protocol = ProtocolTCP + } + if entry.Match(mReq) { + e.Cache.Add(cacheKey{host, port, isUDP}, + cacheValue{entry.Action, entry.ActionArg}) return entry.Action, entry.ActionArg, true, ipAddr, err } } - e.Cache.Add(host, cacheEntry{e.DefaultAction, ""}) + e.Cache.Add(cacheKey{host, port, isUDP}, cacheValue{e.DefaultAction, ""}) return e.DefaultAction, "", true, ipAddr, err } else { // IP - if v, ok := e.Cache.Get(ip.String()); ok { + if v, ok := e.Cache.Get(cacheKey{ip.String(), port, isUDP}); ok { // Cache hit - ce := v.(cacheEntry) + ce := v.(cacheValue) return ce.Action, ce.Arg, false, &net.IPAddr{ IP: ip, Zone: zone, }, nil } for _, entry := range e.Entries { - if entry.MatchIP(ip, e.GeoIPReader) { - e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) + mReq := MatchRequest{ + IP: ip, + Port: port, + DB: e.GeoIPReader, + } + if isUDP { + mReq.Protocol = ProtocolUDP + } else { + mReq.Protocol = ProtocolTCP + } + if entry.Match(mReq) { + e.Cache.Add(cacheKey{ip.String(), port, isUDP}, + cacheValue{entry.Action, entry.ActionArg}) return entry.Action, entry.ActionArg, false, &net.IPAddr{ IP: ip, Zone: zone, }, nil } } - e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""}) + e.Cache.Add(cacheKey{ip.String(), port, isUDP}, cacheValue{e.DefaultAction, ""}) return e.DefaultAction, "", false, &net.IPAddr{ IP: ip, Zone: zone, diff --git a/pkg/acl/engine_test.go b/pkg/acl/engine_test.go index f7b8587..4c30884 100644 --- a/pkg/acl/engine_test.go +++ b/pkg/acl/engine_test.go @@ -1,113 +1,153 @@ package acl import ( + "errors" lru "github.com/hashicorp/golang-lru" "net" + "strings" "testing" ) func TestEngine_ResolveAndMatch(t *testing.T) { - cache, _ := lru.NewARC(4) + cache, _ := lru.NewARC(16) e := &Engine{ DefaultAction: ActionDirect, Entries: []Entry{ { - Net: nil, - Domain: "google.com", - Suffix: false, - All: false, Action: ActionProxy, ActionArg: "", + Matcher: &domainMatcher{ + matcherBase: matcherBase{ + Protocol: ProtocolTCP, + Port: 443, + }, + Domain: "google.com", + Suffix: false, + }, }, { - Net: nil, - Domain: "evil.corp", - Suffix: true, - All: false, Action: ActionHijack, ActionArg: "good.org", + Matcher: &domainMatcher{ + matcherBase: matcherBase{}, + Domain: "evil.corp", + Suffix: true, + }, }, { - Net: &net.IPNet{ - IP: net.ParseIP("10.0.0.0"), - Mask: net.CIDRMask(8, 32), - }, - Domain: "", - Suffix: false, - All: false, Action: ActionProxy, ActionArg: "", + Matcher: &netMatcher{ + matcherBase: matcherBase{}, + Net: &net.IPNet{ + IP: net.ParseIP("10.0.0.0"), + Mask: net.CIDRMask(8, 32), + }, + }, }, { - Net: nil, - Domain: "", - Suffix: false, - All: true, Action: ActionBlock, ActionArg: "", + Matcher: &allMatcher{}, }, }, Cache: cache, + ResolveIPAddr: func(s string) (*net.IPAddr, error) { + if strings.Contains(s, "evil.corp") { + return nil, errors.New("resolve error") + } + return net.ResolveIPAddr("ip", s) + }, } tests := []struct { - name string - addr string - want Action - want1 string - wantErr bool + name string + host string + port uint16 + isUDP bool + wantAction Action + wantArg string + wantErr bool }{ { - name: "domain direct", - addr: "google.com", - want: ActionProxy, - want1: "", + name: "domain proxy", + host: "google.com", + port: 443, + isUDP: false, + wantAction: ActionProxy, + wantArg: "", }, { - name: "domain suffix 1", - addr: "evil.corp", - want: ActionHijack, - want1: "good.org", - wantErr: true, + name: "domain block", + host: "google.com", + port: 80, + isUDP: false, + wantAction: ActionBlock, + wantArg: "", }, { - name: "domain suffix 2", - addr: "notevil.corp", - want: ActionBlock, - want1: "", - wantErr: true, + name: "domain suffix 1", + host: "evil.corp", + port: 8899, + isUDP: true, + wantAction: ActionHijack, + wantArg: "good.org", + wantErr: true, }, { - name: "domain suffix 3", - addr: "im.real.evil.corp", - want: ActionHijack, - want1: "good.org", - wantErr: true, + name: "domain suffix 2", + host: "notevil.corp", + port: 22, + isUDP: false, + wantAction: ActionBlock, + wantArg: "", + wantErr: true, }, { - name: "ip match", - addr: "10.2.3.4", - want: ActionProxy, - want1: "", + name: "domain suffix 3", + host: "im.real.evil.corp", + port: 443, + isUDP: true, + wantAction: ActionHijack, + wantArg: "good.org", + wantErr: true, }, { - name: "ip mismatch", - addr: "100.5.6.0", - want: ActionBlock, - want1: "", + name: "ip match", + host: "10.2.3.4", + port: 80, + isUDP: false, + wantAction: ActionProxy, + wantArg: "", + }, + { + name: "ip mismatch", + host: "100.5.6.0", + port: 1234, + isUDP: false, + wantAction: ActionBlock, + wantArg: "", + }, + { + name: "domain proxy cache", + host: "google.com", + port: 443, + isUDP: false, + wantAction: ActionProxy, + wantArg: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, got1, _, err := e.ResolveAndMatch(tt.addr) + gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP) if (err != nil) != tt.wantErr { t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr) return } - if got != tt.want { - t.Errorf("ResolveAndMatch() got = %v, want %v", got, tt.want) + if gotAction != tt.wantAction { + t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction) } - if got1 != tt.want1 { - t.Errorf("ResolveAndMatch() got1 = %v, want %v", got1, tt.want1) + if gotArg != tt.wantArg { + t.Errorf("ResolveAndMatch() gotArg = %v, wantAction %v", gotArg, tt.wantArg) } }) } diff --git a/pkg/acl/entry.go b/pkg/acl/entry.go index 637e5ec..db23128 100644 --- a/pkg/acl/entry.go +++ b/pkg/acl/entry.go @@ -5,10 +5,12 @@ import ( "fmt" "github.com/oschwald/geoip2-golang" "net" + "strconv" "strings" ) type Action byte +type Protocol byte const ( ActionDirect = Action(iota) @@ -17,78 +19,134 @@ const ( ActionHijack ) +const ( + ProtocolAll = Protocol(iota) + ProtocolTCP + ProtocolUDP +) + type Entry struct { - Net *net.IPNet - Domain string - Suffix bool - Country string - All bool Action Action ActionArg string + Matcher Matcher } -func (e Entry) MatchDomain(domain string) bool { - if e.All { - return true - } - if len(e.Domain) > 0 && len(domain) > 0 { - ld := strings.ToLower(domain) - if e.Suffix { - return e.Domain == ld || strings.HasSuffix(ld, "."+e.Domain) - } else { - return e.Domain == ld - } - } - return false +type MatchRequest struct { + IP net.IP + Domain string + + Protocol Protocol + Port uint16 + + DB *geoip2.Reader } -func (e Entry) MatchIP(ip net.IP, db *geoip2.Reader) bool { - if e.All { - return true +type Matcher interface { + Match(MatchRequest) bool +} + +type matcherBase struct { + Protocol Protocol + Port uint16 // 0 for all ports +} + +func (m *matcherBase) MatchProtocolPort(p Protocol, port uint16) bool { + return (m.Protocol == ProtocolAll || m.Protocol == p) && (m.Port == 0 || m.Port == port) +} + +func parseProtocolPort(s string) (Protocol, uint16, error) { + if len(s) == 0 || s == "*" { + return ProtocolAll, 0, nil } - if ip == nil { + parts := strings.Split(s, "/") + if len(parts) != 2 { + return ProtocolAll, 0, errors.New("invalid protocol/port syntax") + } + protocol := ProtocolAll + switch parts[0] { + case "tcp": + protocol = ProtocolTCP + case "udp": + protocol = ProtocolUDP + case "*": + protocol = ProtocolAll + default: + return ProtocolAll, 0, errors.New("invalid protocol") + } + if parts[1] == "*" { + return protocol, 0, nil + } + port, err := strconv.ParseUint(parts[1], 10, 16) + if err != nil { + return ProtocolAll, 0, errors.New("invalid port") + } + return protocol, uint16(port), nil +} + +type netMatcher struct { + matcherBase + Net *net.IPNet +} + +func (m *netMatcher) Match(r MatchRequest) bool { + if r.IP == nil { return false } - if e.Net != nil { - return e.Net.Contains(ip) - } - if len(e.Country) > 0 && db != nil { - country, err := db.Country(ip) - if err != nil { - return false - } - return country.Country.IsoCode == e.Country - } - return false + return m.Net.Contains(r.IP) && m.MatchProtocolPort(r.Protocol, r.Port) +} + +type domainMatcher struct { + matcherBase + Domain string + Suffix bool +} + +func (m *domainMatcher) Match(r MatchRequest) bool { + if len(r.Domain) == 0 { + return false + } + domain := strings.ToLower(r.Domain) + return (m.Domain == domain || (m.Suffix && strings.HasSuffix(domain, "."+m.Domain))) && + m.MatchProtocolPort(r.Protocol, r.Port) +} + +type countryMatcher struct { + matcherBase + Country string // ISO 3166-1 alpha-2 country code, upper case +} + +func (m *countryMatcher) Match(r MatchRequest) bool { + if r.IP == nil || r.DB == nil { + return false + } + c, err := r.DB.Country(r.IP) + if err != nil { + return false + } + return c.Country.IsoCode == m.Country && m.MatchProtocolPort(r.Protocol, r.Port) +} + +type allMatcher struct { + matcherBase +} + +func (m *allMatcher) Match(r MatchRequest) bool { + return m.MatchProtocolPort(r.Protocol, r.Port) +} + +func (e Entry) Match(r MatchRequest) bool { + return e.Matcher.Match(r) } -// Format: action cond_type cond arg -// Examples: -// proxy domain-suffix google.com -// block ip 8.8.8.8 -// hijack cidr 192.168.1.1/24 127.0.0.1 func ParseEntry(s string) (Entry, error) { fields := strings.Fields(s) if len(fields) < 2 { - return Entry{}, fmt.Errorf("expecting at least 2 fields, got %d", len(fields)) + return Entry{}, fmt.Errorf("expected at least 2 fields, got %d", len(fields)) } - args := fields[1:] - if len(args) == 1 { - // Make sure there are at least 2 args - args = append(args, "") - } - ipNet, domain, suffix, country, all, err := parseCond(args[0], args[1]) - if err != nil { - return Entry{}, err - } - e := Entry{ - Net: ipNet, - Domain: domain, - Suffix: suffix, - Country: country, - All: all, - } - switch strings.ToLower(fields[0]) { + e := Entry{} + action := fields[0] + conds := fields[1:] + switch strings.ToLower(action) { case "direct": e.Action = ActionDirect case "proxy": @@ -96,59 +154,159 @@ func ParseEntry(s string) (Entry, error) { case "block": e.Action = ActionBlock case "hijack": - if len(args) < 3 { - return Entry{}, fmt.Errorf("no hijack destination for %s %s", args[0], args[1]) + if len(conds) < 2 { + return Entry{}, fmt.Errorf("hijack requires at least 3 fields, got %d", len(fields)) } e.Action = ActionHijack - e.ActionArg = args[2] + e.ActionArg = conds[len(conds)-1] + conds = conds[:len(conds)-1] default: return Entry{}, fmt.Errorf("invalid action %s", fields[0]) } + m, err := condsToMatcher(conds) + if err != nil { + return Entry{}, err + } + e.Matcher = m return e, nil } -func parseCond(typ, cond string) (*net.IPNet, string, bool, string, bool, error) { +func condsToMatcher(conds []string) (Matcher, error) { + if len(conds) < 1 { + return nil, errors.New("no condition specified") + } + typ, args := conds[0], conds[1:] switch strings.ToLower(typ) { case "domain": - if len(cond) == 0 { - return nil, "", false, "", false, errors.New("empty domain") + // domain + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments for domain: %d, expected 1 or 2", len(args)) } - return nil, strings.ToLower(cond), false, "", false, nil + mb := matcherBase{} + if len(args) == 2 { + protocol, port, err := parseProtocolPort(args[1]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + return &domainMatcher{ + matcherBase: mb, + Domain: args[0], + Suffix: false, + }, nil case "domain-suffix": - if len(cond) == 0 { - return nil, "", false, "", false, errors.New("empty domain suffix") + // domain-suffix + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments for domain-suffix: %d, expected 1 or 2", len(args)) } - return nil, strings.ToLower(cond), true, "", false, nil + mb := matcherBase{} + if len(args) == 2 { + protocol, port, err := parseProtocolPort(args[1]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + return &domainMatcher{ + matcherBase: mb, + Domain: args[0], + Suffix: true, + }, nil case "cidr": - _, ipNet, err := net.ParseCIDR(cond) + // cidr + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments for cidr: %d, expected 1 or 2", len(args)) + } + mb := matcherBase{} + if len(args) == 2 { + protocol, port, err := parseProtocolPort(args[1]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + _, ipNet, err := net.ParseCIDR(args[0]) if err != nil { - return nil, "", false, "", false, err + return nil, err } - return ipNet, "", false, "", false, nil + return &netMatcher{ + matcherBase: mb, + Net: ipNet, + }, nil case "ip": - ip := net.ParseIP(cond) - if ip == nil { - return nil, "", false, "", false, fmt.Errorf("invalid ip %s", cond) + // ip + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments for ip: %d, expected 1 or 2", len(args)) } + mb := matcherBase{} + if len(args) == 2 { + protocol, port, err := parseProtocolPort(args[1]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + ip := net.ParseIP(args[0]) + if ip == nil { + return nil, fmt.Errorf("invalid ip: %s", args[0]) + } + var ipNet *net.IPNet if ip.To4() != nil { - return &net.IPNet{ + ipNet = &net.IPNet{ IP: ip, Mask: net.CIDRMask(32, 32), - }, "", false, "", false, nil + } } else { - return &net.IPNet{ + ipNet = &net.IPNet{ IP: ip, Mask: net.CIDRMask(128, 128), - }, "", false, "", false, nil + } } + return &netMatcher{ + matcherBase: mb, + Net: ipNet, + }, nil case "country": - if len(cond) == 0 { - return nil, "", false, "", false, errors.New("empty country") + // country + if len(args) == 0 || len(args) > 2 { + return nil, fmt.Errorf("invalid number of arguments for country: %d, expected 1 or 2", len(args)) } - return nil, "", false, strings.ToUpper(cond), false, nil + mb := matcherBase{} + if len(args) == 2 { + protocol, port, err := parseProtocolPort(args[1]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + return &countryMatcher{ + matcherBase: mb, + Country: strings.ToUpper(args[0]), + }, nil case "all": - return nil, "", false, "", true, nil + // all + if len(args) > 1 { + return nil, fmt.Errorf("invalid number of arguments for all: %d, expected 0 or 1", len(args)) + } + mb := matcherBase{} + if len(args) == 1 { + protocol, port, err := parseProtocolPort(args[0]) + if err != nil { + return nil, err + } + mb.Protocol = protocol + mb.Port = port + } + return &allMatcher{ + matcherBase: mb, + }, nil default: - return nil, "", false, "", false, fmt.Errorf("invalid condition type %s", typ) + return nil, fmt.Errorf("invalid condition type: %s", typ) } } diff --git a/pkg/acl/entry_test.go b/pkg/acl/entry_test.go index f4f551b..37b8807 100644 --- a/pkg/acl/entry_test.go +++ b/pkg/acl/entry_test.go @@ -7,7 +7,7 @@ import ( ) func TestParseEntry(t *testing.T) { - _, ok4ipnet, _ := net.ParseCIDR("8.8.8.0/24") + _, ok3net, _ := net.ParseCIDR("8.8.8.0/24") type args struct { s string @@ -20,28 +20,45 @@ func TestParseEntry(t *testing.T) { }{ {name: "empty", args: args{""}, want: Entry{}, wantErr: true}, {name: "ok 1", args: args{"direct domain-suffix google.com"}, - want: Entry{nil, "google.com", true, "", false, ActionDirect, ""}, + want: Entry{ActionDirect, "", &domainMatcher{ + matcherBase: matcherBase{}, + Domain: "google.com", + Suffix: true, + }}, wantErr: false}, - {name: "ok 2", args: args{"proxy ip 8.8.8.8"}, - want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)}, - "", false, "", false, ActionProxy, ""}, wantErr: false}, - {name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"}, - want: Entry{nil, "mad.bad", false, "", false, ActionHijack, "127.0.0.1"}, + {name: "ok 2", args: args{"proxy domain shithole"}, + want: Entry{ActionProxy, "", &domainMatcher{ + matcherBase: matcherBase{}, + Domain: "shithole", + Suffix: false, + }}, wantErr: false}, - {name: "ok 4", args: args{"block cidr 8.8.8.0/24"}, - want: Entry{ok4ipnet, "", false, "", false, ActionBlock, ""}, + {name: "ok 3", args: args{"block cidr 8.8.8.0/24 */53"}, + want: Entry{ActionBlock, "", &netMatcher{ + matcherBase: matcherBase{ProtocolAll, 53}, + Net: ok3net, + }}, wantErr: false}, - {name: "ok 5", args: args{"block all"}, - want: Entry{nil, "", false, "", true, ActionBlock, ""}, + {name: "ok 4", args: args{"hijack all udp/* udpblackhole.net"}, + want: Entry{ActionHijack, "udpblackhole.net", &allMatcher{ + matcherBase: matcherBase{ProtocolUDP, 0}, + }}, wantErr: false}, - {name: "ok 6", args: args{"block country cn"}, - want: Entry{nil, "", false, "CN", false, ActionBlock, ""}, - wantErr: false}, - {name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true}, - {name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true}, - {name: "invalid 3", args: args{"hijack ip 1.1.1.1"}, want: Entry{}, wantErr: true}, - {name: "invalid 4", args: args{"direct cidr"}, want: Entry{}, wantErr: true}, - {name: "invalid 5", args: args{"oxy ip 8.8.8.8"}, want: Entry{}, wantErr: true}, + {name: "err 1", args: args{"what the heck"}, + want: Entry{}, + wantErr: true}, + {name: "err 2", args: args{"proxy sucks ass"}, + want: Entry{}, + wantErr: true}, + {name: "err 3", args: args{"block ip 999.999.999.999"}, + want: Entry{}, + wantErr: true}, + {name: "err 4", args: args{"hijack domain google.com"}, + want: Entry{}, + wantErr: true}, + {name: "err 5", args: args{"hijack domain google.com bing.com 123"}, + want: Entry{}, + wantErr: true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -51,7 +68,7 @@ func TestParseEntry(t *testing.T) { return } if !reflect.DeepEqual(got, tt.want) { - t.Errorf("ParseEntry() got = %v, want %v", got, tt.want) + t.Errorf("ParseEntry() got = %v, wantAction %v", got, tt.want) } }) } diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index f69e667..6f1a19e 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -158,7 +158,7 @@ func (c *serverClient) handleMessage(msg []byte) { var ipAddr *net.IPAddr var err error if c.ACLEngine != nil { - action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) + action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host, dfMsg.Port, true) } else { ipAddr, isDomain, err = c.Transport.ResolveIPAddr(dfMsg.Host) } @@ -208,7 +208,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { var ipAddr *net.IPAddr var err error if c.ACLEngine != nil { - action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) + action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host, port, false) } else { ipAddr, isDomain, err = c.Transport.ResolveIPAddr(host) } diff --git a/pkg/http/server.go b/pkg/http/server.go index 650f20a..5afef7a 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -34,7 +34,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport *transport.ClientTransp var ipAddr *net.IPAddr var resErr error if aclEngine != nil { - action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host, port, false) // Doesn't always matter if the resolution fails, as we may send it through HyClient } newDialFunc(addr, action, arg) diff --git a/pkg/socks5/server.go b/pkg/socks5/server.go index c679455..3c50c4b 100644 --- a/pkg/socks5/server.go +++ b/pkg/socks5/server.go @@ -171,7 +171,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { var ipAddr *net.IPAddr var resErr error if s.ACLEngine != nil { - action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, false) // Doesn't always matter if the resolution fails, as we may send it through HyClient } s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg) @@ -377,7 +377,7 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, var ipAddr *net.IPAddr var resErr error if s.ACLEngine != nil && localRelayConn != nil { - action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, true) // Doesn't always matter if the resolution fails, as we may send it through HyClient } // Handle according to the action