diff --git a/app/cmd/server.go b/app/cmd/server.go index 2e15a42..0d3dd15 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -143,9 +143,10 @@ type serverConfigResolver struct { } type serverConfigACL struct { - File string `mapstructure:"file"` - Inline []string `mapstructure:"inline"` - GeoIP string `mapstructure:"geoip"` + File string `mapstructure:"file"` + Inline []string `mapstructure:"inline"` + GeoIP string `mapstructure:"geoip"` + GeoSite string `mapstructure:"geosite"` } type serverConfigOutboundDirect struct { @@ -460,21 +461,22 @@ func (c *serverConfig) fillOutboundConfig(hyConfig *server.Config) error { if c.ACL.File != "" && len(c.ACL.Inline) > 0 { return configError{Field: "acl", Err: errors.New("cannot set both acl.file and acl.inline")} } - gLoader := &utils.GeoIPLoader{ - Filename: c.ACL.GeoIP, - DownloadFunc: geoipDownloadFunc, - DownloadErrFunc: geoipDownloadErrFunc, + gLoader := &utils.GeoLoader{ + GeoIPFilename: c.ACL.GeoIP, + GeoSiteFilename: c.ACL.GeoSite, + DownloadFunc: geoDownloadFunc, + DownloadErrFunc: geoDownloadErrFunc, } if c.ACL.File != "" { hasACL = true - acl, err := outbounds.NewACLEngineFromFile(c.ACL.File, obs, gLoader.Load) + acl, err := outbounds.NewACLEngineFromFile(c.ACL.File, obs, gLoader) if err != nil { return configError{Field: "acl.file", Err: err} } uOb = acl } else if len(c.ACL.Inline) > 0 { hasACL = true - acl, err := outbounds.NewACLEngineFromString(strings.Join(c.ACL.Inline, "\n"), obs, gLoader.Load) + acl, err := outbounds.NewACLEngineFromString(strings.Join(c.ACL.Inline, "\n"), obs, gLoader) if err != nil { return configError{Field: "acl.inline", Err: err} } @@ -764,13 +766,13 @@ func runMasqTCPServer(s *masq.MasqTCPServer, httpAddr, httpsAddr string) { } } -func geoipDownloadFunc(filename, url string) { - logger.Info("downloading GeoIP database", zap.String("filename", filename), zap.String("url", url)) +func geoDownloadFunc(filename, url string) { + logger.Info("downloading database", zap.String("filename", filename), zap.String("url", url)) } -func geoipDownloadErrFunc(err error) { +func geoDownloadErrFunc(err error) { if err != nil { - logger.Error("failed to download GeoIP database", zap.Error(err)) + logger.Error("failed to download database", zap.Error(err)) } } diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index 611778c..30b088b 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -101,7 +101,8 @@ func TestServerConfig(t *testing.T) { "lmao(ok)", "kek(cringe,boba,tea)", }, - GeoIP: "fake.mmdb", + GeoIP: "some.dat", + GeoSite: "some_site.dat", }, Outbounds: []serverConfigOutboundEntry{ { diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index f78b936..fa5418b 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -75,7 +75,8 @@ acl: inline: - lmao(ok) - kek(cringe,boba,tea) - geoip: fake.mmdb + geoip: some.dat + geosite: some_site.dat outbounds: - name: goodstuff diff --git a/app/go.mod b/app/go.mod index 00fb452..9129a9f 100644 --- a/app/go.mod +++ b/app/go.mod @@ -9,7 +9,6 @@ require ( github.com/caddyserver/certmagic v0.17.2 github.com/mdp/qrterminal/v3 v3.1.1 github.com/mholt/acmez v1.0.4 - github.com/oschwald/geoip2-golang v1.9.0 github.com/spf13/cobra v1.7.0 github.com/spf13/viper v1.15.0 github.com/stretchr/testify v1.8.4 @@ -33,7 +32,6 @@ require ( github.com/miekg/dns v1.1.55 // indirect github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect - github.com/oschwald/maxminddb-golang v1.11.0 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.0.6 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect @@ -56,6 +54,7 @@ require ( golang.org/x/sys v0.13.0 // indirect golang.org/x/text v0.13.0 // indirect golang.org/x/tools v0.11.1 // indirect + google.golang.org/protobuf v1.28.1 // indirect gopkg.in/ini.v1 v1.67.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect rsc.io/qr v0.2.0 // indirect diff --git a/app/go.sum b/app/go.sum index 9e8ddbc..1470f6c 100644 --- a/app/go.sum +++ b/app/go.sum @@ -102,6 +102,7 @@ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvq github.com/golang/protobuf v1.4.1/go.mod h1:U8fpvMrcmy5pZrNK1lt4xCsGvpyWQ/VVv6QDs8UjoX8= github.com/golang/protobuf v1.4.2/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= @@ -115,6 +116,7 @@ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= @@ -179,10 +181,6 @@ github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= -github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc= -github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y= -github.com/oschwald/maxminddb-golang v1.11.0 h1:aSXMqYR/EPNjGE8epgqwDay+P30hCBZIveY0WZbAWh0= -github.com/oschwald/maxminddb-golang v1.11.0/go.mod h1:YmVI+H0zh3ySFR3w+oz8PCfglAFj3PuCmui13+P9zDg= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pelletier/go-toml/v2 v2.0.6 h1:nrzqCb7j9cDFj2coyLNLaZuJTLjWjlaz6nvTvIwycIU= @@ -575,6 +573,7 @@ google.golang.org/protobuf v1.23.0/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2 google.golang.org/protobuf v1.23.1-0.20200526195155-81db48ad09cc/go.mod h1:EGpADcykh3NcUnDUJcl1+ZksZNG86OlYog2l/sGQquU= google.golang.org/protobuf v1.24.0/go.mod h1:r/3tXBNzIEhYS9I1OUVjXDlt8tc493IdKGjtUeSXeh4= google.golang.org/protobuf v1.25.0/go.mod h1:9JNX74DMeImyA3h4bdi1ymwjUzf21/xIlbajtzgsN7c= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/app/internal/utils/geoip.go b/app/internal/utils/geoip.go deleted file mode 100644 index 2144ecb..0000000 --- a/app/internal/utils/geoip.go +++ /dev/null @@ -1,70 +0,0 @@ -package utils - -import ( - "io" - "net/http" - "os" - - "github.com/oschwald/geoip2-golang" -) - -const ( - geoipDefaultFilename = "GeoLite2-Country.mmdb" - geoipDownloadURL = "https://git.io/GeoLite2-Country.mmdb" -) - -// GeoIPLoader provides the on-demand GeoIP database loading function required by the ACL engine. -type GeoIPLoader struct { - Filename string - DownloadFunc func(filename, url string) // Called when downloading the GeoIP database. - DownloadErrFunc func(err error) // Called when downloading the GeoIP database succeeds/fails. - - db *geoip2.Reader -} - -func (l *GeoIPLoader) download() error { - resp, err := http.Get(geoipDownloadURL) - if err != nil { - return err - } - defer resp.Body.Close() - - f, err := os.Create(geoipDefaultFilename) - if err != nil { - return err - } - defer f.Close() - - _, err = io.Copy(f, resp.Body) - return err -} - -func (l *GeoIPLoader) Load() *geoip2.Reader { - if l.db == nil { - if l.Filename == "" { - // Filename not specified, try default. - if _, err := os.Stat(geoipDefaultFilename); err == nil { - // Default already exists, just use it. - l.Filename = geoipDefaultFilename - } else if os.IsNotExist(err) { - // Default doesn't exist, download it. - l.DownloadFunc(geoipDefaultFilename, geoipDownloadURL) - err := l.download() - l.DownloadErrFunc(err) - if err != nil { - return nil - } - l.Filename = geoipDefaultFilename - } else { - // Other error - return nil - } - } - db, err := geoip2.Open(l.Filename) - if err != nil { - return nil - } - l.db = db - } - return l.db -} diff --git a/app/internal/utils/geoloader.go b/app/internal/utils/geoloader.go new file mode 100644 index 0000000..ed8b993 --- /dev/null +++ b/app/internal/utils/geoloader.go @@ -0,0 +1,107 @@ +package utils + +import ( + "io" + "net/http" + "os" + + "github.com/apernet/hysteria/extras/outbounds/acl" + "github.com/apernet/hysteria/extras/outbounds/acl/v2geo" +) + +const ( + geoipFilename = "geoip.dat" + geoipURL = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geoip.dat" + geositeFilename = "geosite.dat" + geositeURL = "https://cdn.jsdelivr.net/gh/Loyalsoldier/v2ray-rules-dat@release/geosite.dat" +) + +var _ acl.GeoLoader = (*GeoLoader)(nil) + +// GeoLoader provides the on-demand GeoIP/GeoSite database +// loading functionality required by the ACL engine. +// Empty filenames = automatic download from built-in URLs. +type GeoLoader struct { + GeoIPFilename string + GeoSiteFilename string + + DownloadFunc func(filename, url string) + DownloadErrFunc func(err error) + + geoipMap map[string]*v2geo.GeoIP + geositeMap map[string]*v2geo.GeoSite +} + +func (l *GeoLoader) download(filename, url string) error { + l.DownloadFunc(filename, url) + + resp, err := http.Get(url) + if err != nil { + l.DownloadErrFunc(err) + return err + } + defer resp.Body.Close() + + f, err := os.Create(filename) + if err != nil { + l.DownloadErrFunc(err) + return err + } + defer f.Close() + + _, err = io.Copy(f, resp.Body) + l.DownloadErrFunc(err) + return err +} + +func (l *GeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { + if l.geoipMap != nil { + return l.geoipMap, nil + } + autoDL := false + filename := l.GeoIPFilename + if filename == "" { + autoDL = true + filename = geoipFilename + } + m, err := v2geo.LoadGeoIP(filename) + if os.IsNotExist(err) && autoDL { + // It's ok, we will download it. + err = l.download(filename, geoipURL) + if err != nil { + return nil, err + } + m, err = v2geo.LoadGeoIP(filename) + } + if err != nil { + return nil, err + } + l.geoipMap = m + return m, nil +} + +func (l *GeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { + if l.geositeMap != nil { + return l.geositeMap, nil + } + autoDL := false + filename := l.GeoSiteFilename + if filename == "" { + autoDL = true + filename = geositeFilename + } + m, err := v2geo.LoadGeoSite(filename) + if os.IsNotExist(err) && autoDL { + // It's ok, we will download it. + err = l.download(filename, geositeURL) + if err != nil { + return nil, err + } + m, err = v2geo.LoadGeoSite(filename) + } + if err != nil { + return nil, err + } + l.geositeMap = m + return m, nil +} diff --git a/extras/go.mod b/extras/go.mod index 5fc8cc1..3c67510 100644 --- a/extras/go.mod +++ b/extras/go.mod @@ -7,10 +7,10 @@ require ( github.com/babolivier/go-doh-client v0.0.0-20201028162107-a76cff4cb8b6 github.com/hashicorp/golang-lru/v2 v2.0.5 github.com/miekg/dns v1.1.55 - github.com/oschwald/geoip2-golang v1.9.0 github.com/stretchr/testify v1.8.4 github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 golang.org/x/crypto v0.14.0 + google.golang.org/protobuf v1.28.1 ) require ( @@ -19,7 +19,6 @@ require ( github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect github.com/onsi/ginkgo/v2 v2.9.5 // indirect - github.com/oschwald/maxminddb-golang v1.11.0 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/quic-go/qpack v0.4.0 // indirect diff --git a/extras/go.sum b/extras/go.sum index a8cafb3..3803e21 100644 --- a/extras/go.sum +++ b/extras/go.sum @@ -12,8 +12,10 @@ github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ= github.com/go-logr/logr v1.2.4/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI= github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE= @@ -32,10 +34,6 @@ github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q= github.com/onsi/ginkgo/v2 v2.9.5/go.mod h1:tvAoo1QUJwNEU2ITftXTpR7R1RbCzoZUOs3RonqW57k= github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE= github.com/onsi/gomega v1.27.6/go.mod h1:PIQNjfQwkP3aQAH7lf7j87O/5FiNr+ZR8+ipb+qQlhg= -github.com/oschwald/geoip2-golang v1.9.0 h1:uvD3O6fXAXs+usU+UGExshpdP13GAqp4GBrzN7IgKZc= -github.com/oschwald/geoip2-golang v1.9.0/go.mod h1:BHK6TvDyATVQhKNbQBdrj9eAvuwOMi2zSFXizL3K81Y= -github.com/oschwald/maxminddb-golang v1.11.0 h1:aSXMqYR/EPNjGE8epgqwDay+P30hCBZIveY0WZbAWh0= -github.com/oschwald/maxminddb-golang v1.11.0/go.mod h1:YmVI+H0zh3ySFR3w+oz8PCfglAFj3PuCmui13+P9zDg= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= @@ -108,6 +106,8 @@ golang.org/x/tools v0.3.0/go.mod h1:/rWhSS2+zyEVwoJf8YAX6L2f0ntZ7Kn/mGgAWcipA5k= golang.org/x/tools v0.11.1 h1:ojD5zOW8+7dOGzdnNgersm8aPfcDjhMp12UfG93NIMc= golang.org/x/tools v0.11.1/go.mod h1:anzJrxPjNtfgiYQYirP2CPGzGLxrH2u2QBhn6Bf3qY8= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.28.1 h1:d0NfwRgPtno5B1Wa6L2DAG+KivqkdutMf1UhdNx175w= google.golang.org/protobuf v1.28.1/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= diff --git a/extras/outbounds/acl.go b/extras/outbounds/acl.go index d8df63d..a4dd21c 100644 --- a/extras/outbounds/acl.go +++ b/extras/outbounds/acl.go @@ -7,7 +7,6 @@ import ( "strings" "github.com/apernet/hysteria/extras/outbounds/acl" - "github.com/oschwald/geoip2-golang" ) const ( @@ -34,25 +33,25 @@ type OutboundEntry struct { Outbound PluggableOutbound } -func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) { +func NewACLEngineFromString(rules string, outbounds []OutboundEntry, geoLoader acl.GeoLoader) (PluggableOutbound, error) { trs, err := acl.ParseTextRules(rules) if err != nil { return nil, err } obMap := outboundsToMap(outbounds) - rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoipFunc) + rs, err := acl.Compile[PluggableOutbound](trs, obMap, aclCacheSize, geoLoader) if err != nil { return nil, err } return &aclEngine{rs, obMap["default"]}, nil } -func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoipFunc func() *geoip2.Reader) (PluggableOutbound, error) { +func NewACLEngineFromFile(filename string, outbounds []OutboundEntry, geoLoader acl.GeoLoader) (PluggableOutbound, error) { bs, err := os.ReadFile(filename) if err != nil { return nil, err } - return NewACLEngineFromString(string(bs), outbounds, geoipFunc) + return NewACLEngineFromString(string(bs), outbounds, geoLoader) } func outboundsToMap(outbounds []OutboundEntry) map[string]PluggableOutbound { diff --git a/extras/outbounds/acl/GeoLite2-Country.mmdb b/extras/outbounds/acl/GeoLite2-Country.mmdb deleted file mode 100644 index 0eb0e50..0000000 Binary files a/extras/outbounds/acl/GeoLite2-Country.mmdb and /dev/null differ diff --git a/extras/outbounds/acl/compile.go b/extras/outbounds/acl/compile.go index e68f3e0..3fe02e6 100644 --- a/extras/outbounds/acl/compile.go +++ b/extras/outbounds/acl/compile.go @@ -6,8 +6,9 @@ import ( "strconv" "strings" + "github.com/apernet/hysteria/extras/outbounds/acl/v2geo" + lru "github.com/hashicorp/golang-lru/v2" - "github.com/oschwald/geoip2-golang" ) type Protocol int @@ -92,6 +93,11 @@ func (e *CompilationError) Error() string { return fmt.Sprintf("error at line %d: %s", e.LineNum, e.Message) } +type GeoLoader interface { + LoadGeoIP() (map[string]*v2geo.GeoIP, error) + LoadGeoSite() (map[string]*v2geo.GeoSite, error) +} + // Compile compiles TextRules into a CompiledRuleSet. // Names in the outbounds map MUST be in all lower case. // geoipFunc is a function that returns the GeoIP database needed by the GeoIP matcher. @@ -99,7 +105,7 @@ func (e *CompilationError) Error() string { // be called if there is no GeoIP rule. We use a function here so that database loading // is on-demand (only required if used by rules). func Compile[O Outbound](rules []TextRule, outbounds map[string]O, - cacheSize int, geoipFunc func() *geoip2.Reader, + cacheSize int, geoLoader GeoLoader, ) (CompiledRuleSet[O], error) { compiledRules := make([]compiledRule[O], len(rules)) for i, rule := range rules { @@ -107,7 +113,7 @@ func Compile[O Outbound](rules []TextRule, outbounds map[string]O, if !ok { return nil, &CompilationError{rule.LineNum, fmt.Sprintf("outbound %s not found", rule.Outbound)} } - hm, errStr := compileHostMatcher(rule.Address, geoipFunc) + hm, errStr := compileHostMatcher(rule.Address, geoLoader) if errStr != "" { return nil, &CompilationError{rule.LineNum, errStr} } @@ -184,7 +190,7 @@ func parseProtoPort(protoPort string) (Protocol, uint16, bool) { } } -func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatcher, string) { +func compileHostMatcher(addr string, geoLoader GeoLoader) (hostMatcher, string) { addr = strings.ToLower(addr) // Normalize to lower case if addr == "*" || addr == "all" { // Match all hosts @@ -192,15 +198,43 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch } if strings.HasPrefix(addr, "geoip:") { // GeoIP matcher - country := strings.ToUpper(addr[6:]) - if len(country) != 2 { - return nil, fmt.Sprintf("invalid country code: %s", country) + country := addr[6:] + if len(country) == 0 { + return nil, "empty GeoIP country code" } - db := geoipFunc() - if db == nil { - return nil, "failed to load GeoIP database" + gMap, err := geoLoader.LoadGeoIP() + if err != nil { + return nil, err.Error() } - return &geoipMatcher{db, country}, "" + list, ok := gMap[country] + if !ok || list == nil { + return nil, fmt.Sprintf("GeoIP country code %s not found", country) + } + m, err := newGeoIPMatcher(list) + if err != nil { + return nil, err.Error() + } + return m, "" + } + if strings.HasPrefix(addr, "geosite:") { + // GeoSite matcher + name, attrs := parseGeoSiteName(addr[8:]) + if len(name) == 0 { + return nil, "empty GeoSite name" + } + gMap, err := geoLoader.LoadGeoSite() + if err != nil { + return nil, err.Error() + } + list, ok := gMap[name] + if !ok || list == nil { + return nil, fmt.Sprintf("GeoSite name %s not found", name) + } + m, err := newGeositeMatcher(list, attrs) + if err != nil { + return nil, err.Error() + } + return m, "" } if strings.Contains(addr, "/") { // CIDR matcher @@ -227,3 +261,13 @@ func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatch Wildcard: false, }, "" } + +func parseGeoSiteName(s string) (string, []string) { + parts := strings.Split(s, "@") + base := strings.TrimSpace(parts[0]) + attrs := parts[1:] + for i := range attrs { + attrs[i] = strings.TrimSpace(attrs[i]) + } + return base, attrs +} diff --git a/extras/outbounds/acl/compile_test.go b/extras/outbounds/acl/compile_test.go index 8f61229..58c10a9 100644 --- a/extras/outbounds/acl/compile_test.go +++ b/extras/outbounds/acl/compile_test.go @@ -4,12 +4,25 @@ import ( "net" "testing" - "github.com/oschwald/geoip2-golang" + "github.com/apernet/hysteria/extras/outbounds/acl/v2geo" + "github.com/stretchr/testify/assert" ) +var _ GeoLoader = (*testGeoLoader)(nil) + +type testGeoLoader struct{} + +func (l *testGeoLoader) LoadGeoIP() (map[string]*v2geo.GeoIP, error) { + return v2geo.LoadGeoIP("v2geo/geoip.dat") +} + +func (l *testGeoLoader) LoadGeoSite() (map[string]*v2geo.GeoSite, error) { + return v2geo.LoadGeoSite("v2geo/geosite.dat") +} + func TestCompile(t *testing.T) { - ob1, ob2, ob3 := 1, 2, 3 + ob1, ob2, ob3, ob4 := 1, 2, 3, 4 rules := []TextRule{ { Outbound: "ob1", @@ -59,12 +72,25 @@ func TestCompile(t *testing.T) { ProtoPort: "*/*", HijackAddress: "", }, + { + Outbound: "ob4", + Address: "geosite:4chan", + ProtoPort: "*/*", + HijackAddress: "", + }, + { + Outbound: "ob4", + Address: "geosite:google @cn", + ProtoPort: "*/*", + HijackAddress: "", + }, } - reader, err := geoip2.Open("GeoLite2-Country.mmdb") - assert.NoError(t, err) - comp, err := Compile[int](rules, map[string]int{"ob1": ob1, "ob2": ob2, "ob3": ob3}, 100, func() *geoip2.Reader { - return reader - }) + comp, err := Compile[int](rules, map[string]int{ + "ob1": ob1, + "ob2": ob2, + "ob3": ob3, + "ob4": ob4, + }, 100, &testGeoLoader{}) assert.NoError(t, err) tests := []struct { @@ -146,6 +172,42 @@ func TestCompile(t *testing.T) { wantOutbound: ob2, wantIP: nil, }, + { + host: HostInfo{ + IPv4: net.ParseIP("175.45.176.73"), + }, + proto: ProtocolTCP, + port: 80, + wantOutbound: 0, // no match default + wantIP: nil, + }, + { + host: HostInfo{ + Name: "boards.4channel.org", + }, + proto: ProtocolTCP, + port: 443, + wantOutbound: ob4, + wantIP: nil, + }, + { + host: HostInfo{ + Name: "gstatic-cn.com", + }, + proto: ProtocolUDP, + port: 9999, + wantOutbound: ob4, + wantIP: nil, + }, + { + host: HostInfo{ + Name: "hoho.waymo.com", + }, + proto: ProtocolUDP, + port: 9999, + wantOutbound: 0, // no match default + wantIP: nil, + }, } for _, test := range tests { @@ -154,3 +216,56 @@ func TestCompile(t *testing.T) { assert.Equal(t, test.wantIP, gotIP) } } + +func Test_parseGeoSiteName(t *testing.T) { + tests := []struct { + name string + s string + want string + want1 []string + }{ + { + name: "no attrs", + s: "pornhub", + want: "pornhub", + want1: []string{}, + }, + { + name: "one attr 1", + s: "xiaomi@cn", + want: "xiaomi", + want1: []string{"cn"}, + }, + { + name: "one attr 2", + s: " google @jp ", + want: "google", + want1: []string{"jp"}, + }, + { + name: "two attrs 1", + s: "netflix@jp@kr", + want: "netflix", + want1: []string{"jp", "kr"}, + }, + { + name: "two attrs 2", + s: "netflix @xixi @haha ", + want: "netflix", + want1: []string{"xixi", "haha"}, + }, + { + name: "empty", + s: "", + want: "", + want1: []string{}, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, got1 := parseGeoSiteName(tt.s) + assert.Equalf(t, tt.want, got, "parseGeoSiteName(%v)", tt.s) + assert.Equalf(t, tt.want1, got1, "parseGeoSiteName(%v)", tt.s) + }) + } +} diff --git a/extras/outbounds/acl/matchers.go b/extras/outbounds/acl/matchers.go index 5f9b3d2..65b2458 100644 --- a/extras/outbounds/acl/matchers.go +++ b/extras/outbounds/acl/matchers.go @@ -2,8 +2,6 @@ package acl import ( "net" - - "github.com/oschwald/geoip2-golang" ) type hostMatcher interface { @@ -55,27 +53,6 @@ func deepMatchRune(str, pattern []rune) bool { return len(str) == 0 && len(pattern) == 0 } -type geoipMatcher struct { - DB *geoip2.Reader - Country string // must be uppercase ISO 3166-1 alpha-2 code -} - -func (m *geoipMatcher) Match(host HostInfo) bool { - if host.IPv4 != nil { - record, err := m.DB.Country(host.IPv4) - if err == nil && record.Country.IsoCode == m.Country { - return true - } - } - if host.IPv6 != nil { - record, err := m.DB.Country(host.IPv6) - if err == nil && record.Country.IsoCode == m.Country { - return true - } - } - return false -} - type allMatcher struct{} func (m *allMatcher) Match(host HostInfo) bool { diff --git a/extras/outbounds/acl/matchers_test.go b/extras/outbounds/acl/matchers_test.go index 871b265..e1b437d 100644 --- a/extras/outbounds/acl/matchers_test.go +++ b/extras/outbounds/acl/matchers_test.go @@ -3,9 +3,6 @@ package acl import ( "net" "testing" - - "github.com/oschwald/geoip2-golang" - "github.com/stretchr/testify/assert" ) func Test_ipMatcher_Match(t *testing.T) { @@ -233,78 +230,3 @@ func Test_domainMatcher_Match(t *testing.T) { }) } } - -func Test_geoipMatcher_Match(t *testing.T) { - db, err := geoip2.Open("GeoLite2-Country.mmdb") - assert.NoError(t, err) - defer db.Close() - - type fields struct { - DB *geoip2.Reader - Country string - } - tests := []struct { - name string - fields fields - host HostInfo - want bool - }{ - { - name: "ipv4 match", - fields: fields{ - DB: db, - Country: "JP", - }, - host: HostInfo{ - IPv4: net.ParseIP("210.140.92.181"), - }, - want: true, - }, - { - name: "ipv6 match", - fields: fields{ - DB: db, - Country: "US", - }, - host: HostInfo{ - IPv6: net.ParseIP("2606:4700::6810:85e5"), - }, - want: true, - }, - { - name: "no match", - fields: fields{ - DB: db, - Country: "AU", - }, - host: HostInfo{ - IPv4: net.ParseIP("210.140.92.181"), - IPv6: net.ParseIP("2606:4700::6810:85e5"), - }, - want: false, - }, - { - name: "both nil", - fields: fields{ - DB: db, - Country: "KR", - }, - host: HostInfo{ - IPv4: nil, - IPv6: nil, - }, - want: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - m := &geoipMatcher{ - DB: tt.fields.DB, - Country: tt.fields.Country, - } - if got := m.Match(tt.host); got != tt.want { - t.Errorf("Match() = %v, want %v", got, tt.want) - } - }) - } -} diff --git a/extras/outbounds/acl/matchers_v2geo.go b/extras/outbounds/acl/matchers_v2geo.go new file mode 100644 index 0000000..df83105 --- /dev/null +++ b/extras/outbounds/acl/matchers_v2geo.go @@ -0,0 +1,213 @@ +package acl + +import ( + "bytes" + "errors" + "net" + "regexp" + "sort" + "strings" + + "github.com/apernet/hysteria/extras/outbounds/acl/v2geo" +) + +var _ hostMatcher = (*geoipMatcher)(nil) + +type geoipMatcher struct { + N4 []*net.IPNet // sorted + N6 []*net.IPNet // sorted + Inverse bool +} + +// matchIP tries to match the given IP address with the corresponding IPNets. +// Note that this function does NOT handle the Inverse flag. +func (m *geoipMatcher) matchIP(ip net.IP) bool { + var n []*net.IPNet + if ip4 := ip.To4(); ip4 != nil { + // N4 stores IPv4 addresses in 4-byte form. + // Make sure we use it here too, otherwise bytes.Compare will fail. + ip = ip4 + n = m.N4 + } else { + n = m.N6 + } + left, right := 0, len(n)-1 + for left <= right { + mid := (left + right) / 2 + if n[mid].Contains(ip) { + return true + } else if bytes.Compare(n[mid].IP, ip) < 0 { + left = mid + 1 + } else { + right = mid - 1 + } + } + return false +} + +func (m *geoipMatcher) Match(host HostInfo) bool { + if host.IPv4 != nil { + if m.matchIP(host.IPv4) { + return !m.Inverse + } + } + if host.IPv6 != nil { + if m.matchIP(host.IPv6) { + return !m.Inverse + } + } + return m.Inverse +} + +func newGeoIPMatcher(list *v2geo.GeoIP) (*geoipMatcher, error) { + n4 := make([]*net.IPNet, 0) + n6 := make([]*net.IPNet, 0) + for _, cidr := range list.Cidr { + if len(cidr.Ip) == 4 { + // IPv4 + n4 = append(n4, &net.IPNet{ + IP: cidr.Ip, + Mask: net.CIDRMask(int(cidr.Prefix), 32), + }) + } else if len(cidr.Ip) == 16 { + // IPv6 + n6 = append(n6, &net.IPNet{ + IP: cidr.Ip, + Mask: net.CIDRMask(int(cidr.Prefix), 128), + }) + } else { + return nil, errors.New("invalid IP length") + } + } + // Sort the IPNets, so we can do binary search later. + sort.Slice(n4, func(i, j int) bool { + return bytes.Compare(n4[i].IP, n4[j].IP) < 0 + }) + sort.Slice(n6, func(i, j int) bool { + return bytes.Compare(n6[i].IP, n6[j].IP) < 0 + }) + return &geoipMatcher{ + N4: n4, + N6: n6, + Inverse: list.InverseMatch, + }, nil +} + +var _ hostMatcher = (*geositeMatcher)(nil) + +type geositeDomainType int + +const ( + geositeDomainPlain geositeDomainType = iota + geositeDomainRegex + geositeDomainRoot + geositeDomainFull +) + +type geositeDomain struct { + Type geositeDomainType + Value string + Regex *regexp.Regexp + Attrs map[string]bool +} + +type geositeMatcher struct { + Domains []geositeDomain + // Attributes are matched using "and" logic - if you have multiple attributes here, + // a domain must have all of those attributes to be considered a match. + Attrs []string +} + +func (m *geositeMatcher) matchDomain(domain geositeDomain, host HostInfo) bool { + // Match attributes first + if len(m.Attrs) > 0 { + if len(domain.Attrs) == 0 { + return false + } + for _, attr := range m.Attrs { + if !domain.Attrs[attr] { + return false + } + } + } + + switch domain.Type { + case geositeDomainPlain: + return strings.Contains(host.Name, domain.Value) + case geositeDomainRegex: + if domain.Regex != nil { + return domain.Regex.MatchString(host.Name) + } + case geositeDomainFull: + return host.Name == domain.Value + case geositeDomainRoot: + if host.Name == domain.Value { + return true + } + return strings.HasSuffix(host.Name, "."+domain.Value) + default: + return false + } + return false +} + +func (m *geositeMatcher) Match(host HostInfo) bool { + for _, domain := range m.Domains { + if m.matchDomain(domain, host) { + return true + } + } + return false +} + +func newGeositeMatcher(list *v2geo.GeoSite, attrs []string) (*geositeMatcher, error) { + domains := make([]geositeDomain, len(list.Domain)) + for i, domain := range list.Domain { + switch domain.Type { + case v2geo.Domain_Plain: + domains[i] = geositeDomain{ + Type: geositeDomainPlain, + Value: domain.Value, + Attrs: domainAttributeToMap(domain.Attribute), + } + case v2geo.Domain_Regex: + regex, err := regexp.Compile(domain.Value) + if err != nil { + return nil, err + } + domains[i] = geositeDomain{ + Type: geositeDomainRegex, + Regex: regex, + Attrs: domainAttributeToMap(domain.Attribute), + } + case v2geo.Domain_Full: + domains[i] = geositeDomain{ + Type: geositeDomainFull, + Value: domain.Value, + Attrs: domainAttributeToMap(domain.Attribute), + } + case v2geo.Domain_RootDomain: + domains[i] = geositeDomain{ + Type: geositeDomainRoot, + Value: domain.Value, + Attrs: domainAttributeToMap(domain.Attribute), + } + default: + return nil, errors.New("unsupported domain type") + } + } + return &geositeMatcher{ + Domains: domains, + Attrs: attrs, + }, nil +} + +func domainAttributeToMap(attrs []*v2geo.Domain_Attribute) map[string]bool { + m := make(map[string]bool) + for _, attr := range attrs { + // Supposedly there are also int attributes, + // but nobody seems to use them, so we treat everything as boolean for now. + m[attr.Key] = true + } + return m +} diff --git a/extras/outbounds/acl/matchers_v2geo_test.go b/extras/outbounds/acl/matchers_v2geo_test.go new file mode 100644 index 0000000..4970baf --- /dev/null +++ b/extras/outbounds/acl/matchers_v2geo_test.go @@ -0,0 +1,141 @@ +package acl + +import ( + "net" + "testing" + + "github.com/apernet/hysteria/extras/outbounds/acl/v2geo" + "github.com/stretchr/testify/assert" +) + +func Test_geoipMatcher_Match(t *testing.T) { + geoipMap, err := v2geo.LoadGeoIP("v2geo/geoip.dat") + assert.NoError(t, err) + m, err := newGeoIPMatcher(geoipMap["us"]) + assert.NoError(t, err) + + tests := []struct { + name string + host HostInfo + want bool + }{ + { + name: "IPv4 match", + host: HostInfo{ + IPv4: net.ParseIP("73.222.1.100"), + }, + want: true, + }, + { + name: "IPv4 no match", + host: HostInfo{ + IPv4: net.ParseIP("123.123.123.123"), + }, + want: false, + }, + { + name: "IPv6 match", + host: HostInfo{ + IPv6: net.ParseIP("2607:f8b0:4005:80c::2004"), + }, + want: true, + }, + { + name: "IPv6 no match", + host: HostInfo{ + IPv6: net.ParseIP("240e:947:6001::1f8"), + }, + want: false, + }, + { + name: "both nil", + host: HostInfo{ + IPv4: nil, + IPv6: nil, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, m.Match(tt.host), "Match(%v)", tt.host) + }) + } +} + +func Test_geositeMatcher_Match(t *testing.T) { + geositeMap, err := v2geo.LoadGeoSite("v2geo/geosite.dat") + assert.NoError(t, err) + m, err := newGeositeMatcher(geositeMap["apple"], nil) + assert.NoError(t, err) + + tests := []struct { + name string + attrs []string + host HostInfo + want bool + }{ + { + name: "subdomain", + attrs: nil, + host: HostInfo{ + Name: "poop.i-book.com", + }, + want: true, + }, + { + name: "subdomain root", + attrs: nil, + host: HostInfo{ + Name: "applepaycash.net", + }, + want: true, + }, + { + name: "full", + attrs: nil, + host: HostInfo{ + Name: "courier-push-apple.com.akadns.net", + }, + want: true, + }, + { + name: "regexp", + attrs: nil, + host: HostInfo{ + Name: "cdn4.apple-mapkit.com", + }, + want: true, + }, + { + name: "attr match", + attrs: []string{"cn"}, + host: HostInfo{ + Name: "bag.itunes.apple.com", + }, + want: true, + }, + { + name: "attr multi no match", + attrs: []string{"cn", "haha"}, + host: HostInfo{ + Name: "bag.itunes.apple.com", + }, + want: false, + }, + { + name: "attr no match", + attrs: []string{"cn"}, + host: HostInfo{ + Name: "mr-apple.com.tw", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m.Attrs = tt.attrs + assert.Equalf(t, tt.want, m.Match(tt.host), "Match(%v)", tt.host) + }) + } +} diff --git a/extras/outbounds/acl/v2geo/load.go b/extras/outbounds/acl/v2geo/load.go index 4b91f23..2dd918c 100644 --- a/extras/outbounds/acl/v2geo/load.go +++ b/extras/outbounds/acl/v2geo/load.go @@ -7,13 +7,9 @@ import ( "google.golang.org/protobuf/proto" ) -type GeoIPMap map[string]*GeoIP - -type GeoSiteMap map[string]*GeoSite - // LoadGeoIP loads a GeoIP data file and converts it to a map. // The keys of the map (country codes) are all normalized to lowercase. -func LoadGeoIP(filename string) (GeoIPMap, error) { +func LoadGeoIP(filename string) (map[string]*GeoIP, error) { bs, err := os.ReadFile(filename) if err != nil { return nil, err @@ -22,7 +18,7 @@ func LoadGeoIP(filename string) (GeoIPMap, error) { if err := proto.Unmarshal(bs, &list); err != nil { return nil, err } - m := make(GeoIPMap) + m := make(map[string]*GeoIP) for _, entry := range list.Entry { m[strings.ToLower(entry.CountryCode)] = entry } @@ -31,7 +27,7 @@ func LoadGeoIP(filename string) (GeoIPMap, error) { // LoadGeoSite loads a GeoSite data file and converts it to a map. // The keys of the map (site keys) are all normalized to lowercase. -func LoadGeoSite(filename string) (GeoSiteMap, error) { +func LoadGeoSite(filename string) (map[string]*GeoSite, error) { bs, err := os.ReadFile(filename) if err != nil { return nil, err @@ -40,7 +36,7 @@ func LoadGeoSite(filename string) (GeoSiteMap, error) { if err := proto.Unmarshal(bs, &list); err != nil { return nil, err } - m := make(GeoSiteMap) + m := make(map[string]*GeoSite) for _, entry := range list.Entry { m[strings.ToLower(entry.CountryCode)] = entry }