ACL protocol & port support

This commit is contained in:
Toby 2022-05-11 17:26:39 -07:00
parent 7a0977023e
commit e9974b0398
7 changed files with 422 additions and 176 deletions

View File

@ -20,7 +20,13 @@ type Engine struct {
GeoIPReader *geoip2.Reader GeoIPReader *geoip2.Reader
} }
type cacheEntry struct { type cacheKey struct {
Host string
Port uint16
IsUDP bool
}
type cacheValue struct {
Action Action Action Action
Arg string Arg string
} }
@ -44,7 +50,7 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(entry.Country) > 0 && geoIPReader == nil { if _, ok := entry.Matcher.(*countryMatcher); ok && geoIPReader == nil {
geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed
if err != nil { if err != nil {
return nil, err return nil, err
@ -66,44 +72,69 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro
} }
// action, arg, isDomain, resolvedIP, error // action, arg, isDomain, resolvedIP, error
func (e *Engine) ResolveAndMatch(host string) (Action, string, bool, *net.IPAddr, error) { func (e *Engine) ResolveAndMatch(host string, port uint16, isUDP bool) (Action, string, bool, *net.IPAddr, error) {
ip, zone := utils.ParseIPZone(host) ip, zone := utils.ParseIPZone(host)
if ip == nil { if ip == nil {
// Domain // Domain
ipAddr, err := e.ResolveIPAddr(host) ipAddr, err := e.ResolveIPAddr(host)
if v, ok := e.Cache.Get(host); ok { if v, ok := e.Cache.Get(cacheKey{host, port, isUDP}); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheValue)
return ce.Action, ce.Arg, true, ipAddr, err return ce.Action, ce.Arg, true, ipAddr, err
} }
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) { mReq := MatchRequest{
e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) Domain: host,
Port: port,
DB: e.GeoIPReader,
}
if ipAddr != nil {
mReq.IP = ipAddr.IP
}
if isUDP {
mReq.Protocol = ProtocolUDP
} else {
mReq.Protocol = ProtocolTCP
}
if entry.Match(mReq) {
e.Cache.Add(cacheKey{host, port, isUDP},
cacheValue{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, true, ipAddr, err return entry.Action, entry.ActionArg, true, ipAddr, err
} }
} }
e.Cache.Add(host, cacheEntry{e.DefaultAction, ""}) e.Cache.Add(cacheKey{host, port, isUDP}, cacheValue{e.DefaultAction, ""})
return e.DefaultAction, "", true, ipAddr, err return e.DefaultAction, "", true, ipAddr, err
} else { } else {
// IP // IP
if v, ok := e.Cache.Get(ip.String()); ok { if v, ok := e.Cache.Get(cacheKey{ip.String(), port, isUDP}); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheValue)
return ce.Action, ce.Arg, false, &net.IPAddr{ return ce.Action, ce.Arg, false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,
}, nil }, nil
} }
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchIP(ip, e.GeoIPReader) { mReq := MatchRequest{
e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) IP: ip,
Port: port,
DB: e.GeoIPReader,
}
if isUDP {
mReq.Protocol = ProtocolUDP
} else {
mReq.Protocol = ProtocolTCP
}
if entry.Match(mReq) {
e.Cache.Add(cacheKey{ip.String(), port, isUDP},
cacheValue{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, false, &net.IPAddr{ return entry.Action, entry.ActionArg, false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,
}, nil }, nil
} }
} }
e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""}) e.Cache.Add(cacheKey{ip.String(), port, isUDP}, cacheValue{e.DefaultAction, ""})
return e.DefaultAction, "", false, &net.IPAddr{ return e.DefaultAction, "", false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,

View File

@ -1,113 +1,153 @@
package acl package acl
import ( import (
"errors"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"net" "net"
"strings"
"testing" "testing"
) )
func TestEngine_ResolveAndMatch(t *testing.T) { func TestEngine_ResolveAndMatch(t *testing.T) {
cache, _ := lru.NewARC(4) cache, _ := lru.NewARC(16)
e := &Engine{ e := &Engine{
DefaultAction: ActionDirect, DefaultAction: ActionDirect,
Entries: []Entry{ Entries: []Entry{
{ {
Net: nil,
Domain: "google.com",
Suffix: false,
All: false,
Action: ActionProxy, Action: ActionProxy,
ActionArg: "", ActionArg: "",
Matcher: &domainMatcher{
matcherBase: matcherBase{
Protocol: ProtocolTCP,
Port: 443,
},
Domain: "google.com",
Suffix: false,
},
}, },
{ {
Net: nil,
Domain: "evil.corp",
Suffix: true,
All: false,
Action: ActionHijack, Action: ActionHijack,
ActionArg: "good.org", ActionArg: "good.org",
Matcher: &domainMatcher{
matcherBase: matcherBase{},
Domain: "evil.corp",
Suffix: true,
},
}, },
{ {
Net: &net.IPNet{
IP: net.ParseIP("10.0.0.0"),
Mask: net.CIDRMask(8, 32),
},
Domain: "",
Suffix: false,
All: false,
Action: ActionProxy, Action: ActionProxy,
ActionArg: "", ActionArg: "",
Matcher: &netMatcher{
matcherBase: matcherBase{},
Net: &net.IPNet{
IP: net.ParseIP("10.0.0.0"),
Mask: net.CIDRMask(8, 32),
},
},
}, },
{ {
Net: nil,
Domain: "",
Suffix: false,
All: true,
Action: ActionBlock, Action: ActionBlock,
ActionArg: "", ActionArg: "",
Matcher: &allMatcher{},
}, },
}, },
Cache: cache, Cache: cache,
ResolveIPAddr: func(s string) (*net.IPAddr, error) {
if strings.Contains(s, "evil.corp") {
return nil, errors.New("resolve error")
}
return net.ResolveIPAddr("ip", s)
},
} }
tests := []struct { tests := []struct {
name string name string
addr string host string
want Action port uint16
want1 string isUDP bool
wantErr bool wantAction Action
wantArg string
wantErr bool
}{ }{
{ {
name: "domain direct", name: "domain proxy",
addr: "google.com", host: "google.com",
want: ActionProxy, port: 443,
want1: "", isUDP: false,
wantAction: ActionProxy,
wantArg: "",
}, },
{ {
name: "domain suffix 1", name: "domain block",
addr: "evil.corp", host: "google.com",
want: ActionHijack, port: 80,
want1: "good.org", isUDP: false,
wantErr: true, wantAction: ActionBlock,
wantArg: "",
}, },
{ {
name: "domain suffix 2", name: "domain suffix 1",
addr: "notevil.corp", host: "evil.corp",
want: ActionBlock, port: 8899,
want1: "", isUDP: true,
wantErr: true, wantAction: ActionHijack,
wantArg: "good.org",
wantErr: true,
}, },
{ {
name: "domain suffix 3", name: "domain suffix 2",
addr: "im.real.evil.corp", host: "notevil.corp",
want: ActionHijack, port: 22,
want1: "good.org", isUDP: false,
wantErr: true, wantAction: ActionBlock,
wantArg: "",
wantErr: true,
}, },
{ {
name: "ip match", name: "domain suffix 3",
addr: "10.2.3.4", host: "im.real.evil.corp",
want: ActionProxy, port: 443,
want1: "", isUDP: true,
wantAction: ActionHijack,
wantArg: "good.org",
wantErr: true,
}, },
{ {
name: "ip mismatch", name: "ip match",
addr: "100.5.6.0", host: "10.2.3.4",
want: ActionBlock, port: 80,
want1: "", isUDP: false,
wantAction: ActionProxy,
wantArg: "",
},
{
name: "ip mismatch",
host: "100.5.6.0",
port: 1234,
isUDP: false,
wantAction: ActionBlock,
wantArg: "",
},
{
name: "domain proxy cache",
host: "google.com",
port: 443,
isUDP: false,
wantAction: ActionProxy,
wantArg: "",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, got1, _, err := e.ResolveAndMatch(tt.addr) gotAction, gotArg, _, _, err := e.ResolveAndMatch(tt.host, tt.port, tt.isUDP)
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if got != tt.want { if gotAction != tt.wantAction {
t.Errorf("ResolveAndMatch() got = %v, want %v", got, tt.want) t.Errorf("ResolveAndMatch() gotAction = %v, wantAction %v", gotAction, tt.wantAction)
} }
if got1 != tt.want1 { if gotArg != tt.wantArg {
t.Errorf("ResolveAndMatch() got1 = %v, want %v", got1, tt.want1) t.Errorf("ResolveAndMatch() gotArg = %v, wantAction %v", gotArg, tt.wantArg)
} }
}) })
} }

View File

@ -5,10 +5,12 @@ import (
"fmt" "fmt"
"github.com/oschwald/geoip2-golang" "github.com/oschwald/geoip2-golang"
"net" "net"
"strconv"
"strings" "strings"
) )
type Action byte type Action byte
type Protocol byte
const ( const (
ActionDirect = Action(iota) ActionDirect = Action(iota)
@ -17,78 +19,134 @@ const (
ActionHijack ActionHijack
) )
const (
ProtocolAll = Protocol(iota)
ProtocolTCP
ProtocolUDP
)
type Entry struct { type Entry struct {
Net *net.IPNet
Domain string
Suffix bool
Country string
All bool
Action Action Action Action
ActionArg string ActionArg string
Matcher Matcher
} }
func (e Entry) MatchDomain(domain string) bool { type MatchRequest struct {
if e.All { IP net.IP
return true Domain string
}
if len(e.Domain) > 0 && len(domain) > 0 { Protocol Protocol
ld := strings.ToLower(domain) Port uint16
if e.Suffix {
return e.Domain == ld || strings.HasSuffix(ld, "."+e.Domain) DB *geoip2.Reader
} else {
return e.Domain == ld
}
}
return false
} }
func (e Entry) MatchIP(ip net.IP, db *geoip2.Reader) bool { type Matcher interface {
if e.All { Match(MatchRequest) bool
return true }
type matcherBase struct {
Protocol Protocol
Port uint16 // 0 for all ports
}
func (m *matcherBase) MatchProtocolPort(p Protocol, port uint16) bool {
return (m.Protocol == ProtocolAll || m.Protocol == p) && (m.Port == 0 || m.Port == port)
}
func parseProtocolPort(s string) (Protocol, uint16, error) {
if len(s) == 0 || s == "*" {
return ProtocolAll, 0, nil
} }
if ip == nil { parts := strings.Split(s, "/")
if len(parts) != 2 {
return ProtocolAll, 0, errors.New("invalid protocol/port syntax")
}
protocol := ProtocolAll
switch parts[0] {
case "tcp":
protocol = ProtocolTCP
case "udp":
protocol = ProtocolUDP
case "*":
protocol = ProtocolAll
default:
return ProtocolAll, 0, errors.New("invalid protocol")
}
if parts[1] == "*" {
return protocol, 0, nil
}
port, err := strconv.ParseUint(parts[1], 10, 16)
if err != nil {
return ProtocolAll, 0, errors.New("invalid port")
}
return protocol, uint16(port), nil
}
type netMatcher struct {
matcherBase
Net *net.IPNet
}
func (m *netMatcher) Match(r MatchRequest) bool {
if r.IP == nil {
return false return false
} }
if e.Net != nil { return m.Net.Contains(r.IP) && m.MatchProtocolPort(r.Protocol, r.Port)
return e.Net.Contains(ip) }
}
if len(e.Country) > 0 && db != nil { type domainMatcher struct {
country, err := db.Country(ip) matcherBase
if err != nil { Domain string
return false Suffix bool
} }
return country.Country.IsoCode == e.Country
} func (m *domainMatcher) Match(r MatchRequest) bool {
return false if len(r.Domain) == 0 {
return false
}
domain := strings.ToLower(r.Domain)
return (m.Domain == domain || (m.Suffix && strings.HasSuffix(domain, "."+m.Domain))) &&
m.MatchProtocolPort(r.Protocol, r.Port)
}
type countryMatcher struct {
matcherBase
Country string // ISO 3166-1 alpha-2 country code, upper case
}
func (m *countryMatcher) Match(r MatchRequest) bool {
if r.IP == nil || r.DB == nil {
return false
}
c, err := r.DB.Country(r.IP)
if err != nil {
return false
}
return c.Country.IsoCode == m.Country && m.MatchProtocolPort(r.Protocol, r.Port)
}
type allMatcher struct {
matcherBase
}
func (m *allMatcher) Match(r MatchRequest) bool {
return m.MatchProtocolPort(r.Protocol, r.Port)
}
func (e Entry) Match(r MatchRequest) bool {
return e.Matcher.Match(r)
} }
// Format: action cond_type cond arg
// Examples:
// proxy domain-suffix google.com
// block ip 8.8.8.8
// hijack cidr 192.168.1.1/24 127.0.0.1
func ParseEntry(s string) (Entry, error) { func ParseEntry(s string) (Entry, error) {
fields := strings.Fields(s) fields := strings.Fields(s)
if len(fields) < 2 { if len(fields) < 2 {
return Entry{}, fmt.Errorf("expecting at least 2 fields, got %d", len(fields)) return Entry{}, fmt.Errorf("expected at least 2 fields, got %d", len(fields))
} }
args := fields[1:] e := Entry{}
if len(args) == 1 { action := fields[0]
// Make sure there are at least 2 args conds := fields[1:]
args = append(args, "") switch strings.ToLower(action) {
}
ipNet, domain, suffix, country, all, err := parseCond(args[0], args[1])
if err != nil {
return Entry{}, err
}
e := Entry{
Net: ipNet,
Domain: domain,
Suffix: suffix,
Country: country,
All: all,
}
switch strings.ToLower(fields[0]) {
case "direct": case "direct":
e.Action = ActionDirect e.Action = ActionDirect
case "proxy": case "proxy":
@ -96,59 +154,159 @@ func ParseEntry(s string) (Entry, error) {
case "block": case "block":
e.Action = ActionBlock e.Action = ActionBlock
case "hijack": case "hijack":
if len(args) < 3 { if len(conds) < 2 {
return Entry{}, fmt.Errorf("no hijack destination for %s %s", args[0], args[1]) return Entry{}, fmt.Errorf("hijack requires at least 3 fields, got %d", len(fields))
} }
e.Action = ActionHijack e.Action = ActionHijack
e.ActionArg = args[2] e.ActionArg = conds[len(conds)-1]
conds = conds[:len(conds)-1]
default: default:
return Entry{}, fmt.Errorf("invalid action %s", fields[0]) return Entry{}, fmt.Errorf("invalid action %s", fields[0])
} }
m, err := condsToMatcher(conds)
if err != nil {
return Entry{}, err
}
e.Matcher = m
return e, nil return e, nil
} }
func parseCond(typ, cond string) (*net.IPNet, string, bool, string, bool, error) { func condsToMatcher(conds []string) (Matcher, error) {
if len(conds) < 1 {
return nil, errors.New("no condition specified")
}
typ, args := conds[0], conds[1:]
switch strings.ToLower(typ) { switch strings.ToLower(typ) {
case "domain": case "domain":
if len(cond) == 0 { // domain <domain> <optional: protocol/port>
return nil, "", false, "", false, errors.New("empty domain") if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for domain: %d, expected 1 or 2", len(args))
} }
return nil, strings.ToLower(cond), false, "", false, nil mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &domainMatcher{
matcherBase: mb,
Domain: args[0],
Suffix: false,
}, nil
case "domain-suffix": case "domain-suffix":
if len(cond) == 0 { // domain-suffix <domain> <optional: protocol/port>
return nil, "", false, "", false, errors.New("empty domain suffix") if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for domain-suffix: %d, expected 1 or 2", len(args))
} }
return nil, strings.ToLower(cond), true, "", false, nil mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &domainMatcher{
matcherBase: mb,
Domain: args[0],
Suffix: true,
}, nil
case "cidr": case "cidr":
_, ipNet, err := net.ParseCIDR(cond) // cidr <cidr> <optional: protocol/port>
if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for cidr: %d, expected 1 or 2", len(args))
}
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
_, ipNet, err := net.ParseCIDR(args[0])
if err != nil { if err != nil {
return nil, "", false, "", false, err return nil, err
} }
return ipNet, "", false, "", false, nil return &netMatcher{
matcherBase: mb,
Net: ipNet,
}, nil
case "ip": case "ip":
ip := net.ParseIP(cond) // ip <ip> <optional: protocol/port>
if ip == nil { if len(args) == 0 || len(args) > 2 {
return nil, "", false, "", false, fmt.Errorf("invalid ip %s", cond) return nil, fmt.Errorf("invalid number of arguments for ip: %d, expected 1 or 2", len(args))
} }
mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
ip := net.ParseIP(args[0])
if ip == nil {
return nil, fmt.Errorf("invalid ip: %s", args[0])
}
var ipNet *net.IPNet
if ip.To4() != nil { if ip.To4() != nil {
return &net.IPNet{ ipNet = &net.IPNet{
IP: ip, IP: ip,
Mask: net.CIDRMask(32, 32), Mask: net.CIDRMask(32, 32),
}, "", false, "", false, nil }
} else { } else {
return &net.IPNet{ ipNet = &net.IPNet{
IP: ip, IP: ip,
Mask: net.CIDRMask(128, 128), Mask: net.CIDRMask(128, 128),
}, "", false, "", false, nil }
} }
return &netMatcher{
matcherBase: mb,
Net: ipNet,
}, nil
case "country": case "country":
if len(cond) == 0 { // country <country> <optional: protocol/port>
return nil, "", false, "", false, errors.New("empty country") if len(args) == 0 || len(args) > 2 {
return nil, fmt.Errorf("invalid number of arguments for country: %d, expected 1 or 2", len(args))
} }
return nil, "", false, strings.ToUpper(cond), false, nil mb := matcherBase{}
if len(args) == 2 {
protocol, port, err := parseProtocolPort(args[1])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &countryMatcher{
matcherBase: mb,
Country: strings.ToUpper(args[0]),
}, nil
case "all": case "all":
return nil, "", false, "", true, nil // all <optional: protocol/port>
if len(args) > 1 {
return nil, fmt.Errorf("invalid number of arguments for all: %d, expected 0 or 1", len(args))
}
mb := matcherBase{}
if len(args) == 1 {
protocol, port, err := parseProtocolPort(args[0])
if err != nil {
return nil, err
}
mb.Protocol = protocol
mb.Port = port
}
return &allMatcher{
matcherBase: mb,
}, nil
default: default:
return nil, "", false, "", false, fmt.Errorf("invalid condition type %s", typ) return nil, fmt.Errorf("invalid condition type: %s", typ)
} }
} }

View File

@ -7,7 +7,7 @@ import (
) )
func TestParseEntry(t *testing.T) { func TestParseEntry(t *testing.T) {
_, ok4ipnet, _ := net.ParseCIDR("8.8.8.0/24") _, ok3net, _ := net.ParseCIDR("8.8.8.0/24")
type args struct { type args struct {
s string s string
@ -20,28 +20,45 @@ func TestParseEntry(t *testing.T) {
}{ }{
{name: "empty", args: args{""}, want: Entry{}, wantErr: true}, {name: "empty", args: args{""}, want: Entry{}, wantErr: true},
{name: "ok 1", args: args{"direct domain-suffix google.com"}, {name: "ok 1", args: args{"direct domain-suffix google.com"},
want: Entry{nil, "google.com", true, "", false, ActionDirect, ""}, want: Entry{ActionDirect, "", &domainMatcher{
matcherBase: matcherBase{},
Domain: "google.com",
Suffix: true,
}},
wantErr: false}, wantErr: false},
{name: "ok 2", args: args{"proxy ip 8.8.8.8"}, {name: "ok 2", args: args{"proxy domain shithole"},
want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)}, want: Entry{ActionProxy, "", &domainMatcher{
"", false, "", false, ActionProxy, ""}, wantErr: false}, matcherBase: matcherBase{},
{name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"}, Domain: "shithole",
want: Entry{nil, "mad.bad", false, "", false, ActionHijack, "127.0.0.1"}, Suffix: false,
}},
wantErr: false}, wantErr: false},
{name: "ok 4", args: args{"block cidr 8.8.8.0/24"}, {name: "ok 3", args: args{"block cidr 8.8.8.0/24 */53"},
want: Entry{ok4ipnet, "", false, "", false, ActionBlock, ""}, want: Entry{ActionBlock, "", &netMatcher{
matcherBase: matcherBase{ProtocolAll, 53},
Net: ok3net,
}},
wantErr: false}, wantErr: false},
{name: "ok 5", args: args{"block all"}, {name: "ok 4", args: args{"hijack all udp/* udpblackhole.net"},
want: Entry{nil, "", false, "", true, ActionBlock, ""}, want: Entry{ActionHijack, "udpblackhole.net", &allMatcher{
matcherBase: matcherBase{ProtocolUDP, 0},
}},
wantErr: false}, wantErr: false},
{name: "ok 6", args: args{"block country cn"}, {name: "err 1", args: args{"what the heck"},
want: Entry{nil, "", false, "CN", false, ActionBlock, ""}, want: Entry{},
wantErr: false}, wantErr: true},
{name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true}, {name: "err 2", args: args{"proxy sucks ass"},
{name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true}, want: Entry{},
{name: "invalid 3", args: args{"hijack ip 1.1.1.1"}, want: Entry{}, wantErr: true}, wantErr: true},
{name: "invalid 4", args: args{"direct cidr"}, want: Entry{}, wantErr: true}, {name: "err 3", args: args{"block ip 999.999.999.999"},
{name: "invalid 5", args: args{"oxy ip 8.8.8.8"}, want: Entry{}, wantErr: true}, want: Entry{},
wantErr: true},
{name: "err 4", args: args{"hijack domain google.com"},
want: Entry{},
wantErr: true},
{name: "err 5", args: args{"hijack domain google.com bing.com 123"},
want: Entry{},
wantErr: true},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -51,7 +68,7 @@ func TestParseEntry(t *testing.T) {
return return
} }
if !reflect.DeepEqual(got, tt.want) { if !reflect.DeepEqual(got, tt.want) {
t.Errorf("ParseEntry() got = %v, want %v", got, tt.want) t.Errorf("ParseEntry() got = %v, wantAction %v", got, tt.want)
} }
}) })
} }

View File

@ -158,7 +158,7 @@ func (c *serverClient) handleMessage(msg []byte) {
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var err error var err error
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host, dfMsg.Port, true)
} else { } else {
ipAddr, isDomain, err = c.Transport.ResolveIPAddr(dfMsg.Host) ipAddr, isDomain, err = c.Transport.ResolveIPAddr(dfMsg.Host)
} }
@ -208,7 +208,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var err error var err error
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host, port, false)
} else { } else {
ipAddr, isDomain, err = c.Transport.ResolveIPAddr(host) ipAddr, isDomain, err = c.Transport.ResolveIPAddr(host)
} }

View File

@ -34,7 +34,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport *transport.ClientTransp
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if aclEngine != nil { if aclEngine != nil {
action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host, port, false)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
newDialFunc(addr, action, arg) newDialFunc(addr, action, arg)

View File

@ -171,7 +171,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if s.ACLEngine != nil { if s.ACLEngine != nil {
action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, false)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg) s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg)
@ -377,7 +377,7 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn,
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if s.ACLEngine != nil && localRelayConn != nil { if s.ACLEngine != nil && localRelayConn != nil {
action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host, port, true)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
// Handle according to the action // Handle according to the action