diff --git a/cmd/client.go b/cmd/client.go index 4fe1d12..088a865 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -3,6 +3,7 @@ package main import ( "crypto/tls" "crypto/x509" + "github.com/oschwald/geoip2-golang" "io" "io/ioutil" "net" @@ -93,7 +94,13 @@ func client(config *clientConfig) { var aclEngine *acl.Engine if len(config.ACL) > 0 { var err error - aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport) + aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) { + if len(config.MMDB) > 0 { + return loadMMDBReader(config.MMDB) + } else { + return loadMMDBReader(DefaultMMDBFilename) + } + }) if err != nil { logrus.WithFields(logrus.Fields{ "error": err, diff --git a/cmd/config.go b/cmd/config.go index 24b538b..742e4d3 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -16,6 +16,8 @@ const ( DefaultMaxIncomingStreams = 1024 DefaultALPN = "hysteria" + + DefaultMMDBFilename = "GeoLite2-Country.mmdb" ) type serverConfig struct { @@ -36,6 +38,7 @@ type serverConfig struct { DownMbps int `json:"down_mbps"` DisableUDP bool `json:"disable_udp"` ACL string `json:"acl"` + MMDB string `json:"mmdb"` Obfs string `json:"obfs"` Auth struct { Mode string `json:"mode"` @@ -137,6 +140,7 @@ type clientConfig struct { Timeout int `json:"timeout"` } `json:"tproxy_udp"` ACL string `json:"acl"` + MMDB string `json:"mmdb"` Obfs string `json:"obfs"` Auth []byte `json:"auth"` AuthString string `json:"auth_str"` diff --git a/cmd/mmdb.go b/cmd/mmdb.go new file mode 100644 index 0000000..b512fd6 --- /dev/null +++ b/cmd/mmdb.go @@ -0,0 +1,51 @@ +package main + +import ( + "github.com/oschwald/geoip2-golang" + "github.com/sirupsen/logrus" + "io" + "net/http" + "os" +) + +const ( + mmdbURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb" +) + +func downloadMMDB(filename string) error { + resp, err := http.Get(mmdbURL) + if err != nil { + return err + } + defer resp.Body.Close() + + file, err := os.Create(filename) + if err != nil { + return err + } + defer file.Close() + + _, err = io.Copy(file, resp.Body) + return err +} + +func loadMMDBReader(filename string) (*geoip2.Reader, error) { + if _, err := os.Stat(filename); err != nil { + if os.IsNotExist(err) { + logrus.Info("GeoLite2 database not found, downloading...") + if err := downloadMMDB(filename); err != nil { + return nil, err + } + logrus.WithFields(logrus.Fields{ + "file": filename, + }).Info("GeoLite2 database downloaded") + return geoip2.Open(filename) + } else { + // some other error + return nil, err + } + } else { + // file exists, just open it + return geoip2.Open(filename) + } +} diff --git a/cmd/server.go b/cmd/server.go index b977adb..a659391 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/congestion" + "github.com/oschwald/geoip2-golang" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/sirupsen/logrus" @@ -139,7 +140,13 @@ func server(config *serverConfig) { // ACL var aclEngine *acl.Engine if len(config.ACL) > 0 { - aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport) + aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) { + if len(config.MMDB) > 0 { + return loadMMDBReader(config.MMDB) + } else { + return loadMMDBReader(DefaultMMDBFilename) + } + }) if err != nil { logrus.WithFields(logrus.Fields{ "error": err, diff --git a/go.mod b/go.mod index 0008a8e..a9b64bc 100644 --- a/go.mod +++ b/go.mod @@ -16,6 +16,7 @@ require ( github.com/hashicorp/golang-lru v0.5.4 github.com/lucas-clemente/quic-go v0.24.0 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 + github.com/oschwald/geoip2-golang v1.5.0 // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/prometheus/client_golang v1.11.0 github.com/russross/blackfriday/v2 v2.1.0 // indirect diff --git a/go.sum b/go.sum index 56f7ee4..72cdb4f 100644 --- a/go.sum +++ b/go.sum @@ -173,6 +173,10 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= +github.com/oschwald/geoip2-golang v1.5.0 h1:igg2yQIrrcRccB1ytFXqBfOHCjXWIoMv85lVJ1ONZzw= +github.com/oschwald/geoip2-golang v1.5.0/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s= +github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk= +github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= @@ -248,6 +252,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= @@ -352,6 +357,7 @@ golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index 5fbaa40..ead693b 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -3,6 +3,7 @@ package acl import ( "bufio" lru "github.com/hashicorp/golang-lru" + "github.com/oschwald/geoip2-golang" "github.com/tobyxdd/hysteria/pkg/transport" "net" "os" @@ -16,6 +17,7 @@ type Engine struct { Entries []Entry Cache *lru.ARCCache Transport transport.Transport + GeoIPReader *geoip2.Reader } type cacheEntry struct { @@ -23,7 +25,7 @@ type cacheEntry struct { Arg string } -func LoadFromFile(filename string, transport transport.Transport) (*Engine, error) { +func LoadFromFile(filename string, transport transport.Transport, geoIPLoadFunc func() (*geoip2.Reader, error)) (*Engine, error) { f, err := os.Open(filename) if err != nil { return nil, err @@ -31,6 +33,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro defer f.Close() scanner := bufio.NewScanner(f) entries := make([]Entry, 0, 1024) + var geoIPReader *geoip2.Reader for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) if len(line) == 0 || strings.HasPrefix(line, "#") { @@ -41,6 +44,12 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro if err != nil { return nil, err } + if len(entry.Country) > 0 && geoIPReader == nil { + geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed + if err != nil { + return nil, err + } + } entries = append(entries, entry) } cache, err := lru.NewARC(entryCacheSize) @@ -52,6 +61,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro Entries: entries, Cache: cache, Transport: transport, + GeoIPReader: geoIPReader, }, nil } @@ -66,7 +76,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro return ce.Action, ce.Arg, ipAddr, err } for _, entry := range e.Entries { - if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP)) { + if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) { e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) return entry.Action, entry.ActionArg, ipAddr, err } @@ -84,7 +94,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro }, nil } for _, entry := range e.Entries { - if entry.MatchIP(ip) { + if entry.MatchIP(ip, e.GeoIPReader) { e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) return entry.Action, entry.ActionArg, &net.IPAddr{ IP: ip, diff --git a/pkg/acl/entry.go b/pkg/acl/entry.go index d6fb731..637e5ec 100644 --- a/pkg/acl/entry.go +++ b/pkg/acl/entry.go @@ -3,6 +3,7 @@ package acl import ( "errors" "fmt" + "github.com/oschwald/geoip2-golang" "net" "strings" ) @@ -20,6 +21,7 @@ type Entry struct { Net *net.IPNet Domain string Suffix bool + Country string All bool Action Action ActionArg string @@ -40,13 +42,23 @@ func (e Entry) MatchDomain(domain string) bool { return false } -func (e Entry) MatchIP(ip net.IP) bool { +func (e Entry) MatchIP(ip net.IP, db *geoip2.Reader) bool { if e.All { return true } - if e.Net != nil && ip != nil { + if ip == nil { + return false + } + if e.Net != nil { return e.Net.Contains(ip) } + if len(e.Country) > 0 && db != nil { + country, err := db.Country(ip) + if err != nil { + return false + } + return country.Country.IsoCode == e.Country + } return false } @@ -65,15 +77,16 @@ func ParseEntry(s string) (Entry, error) { // Make sure there are at least 2 args args = append(args, "") } - ipNet, domain, suffix, all, err := parseCond(args[0], args[1]) + ipNet, domain, suffix, country, all, err := parseCond(args[0], args[1]) if err != nil { return Entry{}, err } e := Entry{ - Net: ipNet, - Domain: domain, - Suffix: suffix, - All: all, + Net: ipNet, + Domain: domain, + Suffix: suffix, + Country: country, + All: all, } switch strings.ToLower(fields[0]) { case "direct": @@ -94,43 +107,48 @@ func ParseEntry(s string) (Entry, error) { return e, nil } -func parseCond(typ, cond string) (*net.IPNet, string, bool, bool, error) { +func parseCond(typ, cond string) (*net.IPNet, string, bool, string, bool, error) { switch strings.ToLower(typ) { case "domain": if len(cond) == 0 { - return nil, "", false, false, errors.New("empty domain") + return nil, "", false, "", false, errors.New("empty domain") } - return nil, strings.ToLower(cond), false, false, nil + return nil, strings.ToLower(cond), false, "", false, nil case "domain-suffix": if len(cond) == 0 { - return nil, "", false, false, errors.New("empty domain suffix") + return nil, "", false, "", false, errors.New("empty domain suffix") } - return nil, strings.ToLower(cond), true, false, nil + return nil, strings.ToLower(cond), true, "", false, nil case "cidr": _, ipNet, err := net.ParseCIDR(cond) if err != nil { - return nil, "", false, false, err + return nil, "", false, "", false, err } - return ipNet, "", false, false, nil + return ipNet, "", false, "", false, nil case "ip": ip := net.ParseIP(cond) if ip == nil { - return nil, "", false, false, fmt.Errorf("invalid ip %s", cond) + return nil, "", false, "", false, fmt.Errorf("invalid ip %s", cond) } if ip.To4() != nil { return &net.IPNet{ IP: ip, Mask: net.CIDRMask(32, 32), - }, "", false, false, nil + }, "", false, "", false, nil } else { return &net.IPNet{ IP: ip, Mask: net.CIDRMask(128, 128), - }, "", false, false, nil + }, "", false, "", false, nil } + case "country": + if len(cond) == 0 { + return nil, "", false, "", false, errors.New("empty country") + } + return nil, "", false, strings.ToUpper(cond), false, nil case "all": - return nil, "", false, true, nil + return nil, "", false, "", true, nil default: - return nil, "", false, false, fmt.Errorf("invalid condition type %s", typ) + return nil, "", false, "", false, fmt.Errorf("invalid condition type %s", typ) } } diff --git a/pkg/acl/entry_test.go b/pkg/acl/entry_test.go index 326046b..f4f551b 100644 --- a/pkg/acl/entry_test.go +++ b/pkg/acl/entry_test.go @@ -20,19 +20,22 @@ func TestParseEntry(t *testing.T) { }{ {name: "empty", args: args{""}, want: Entry{}, wantErr: true}, {name: "ok 1", args: args{"direct domain-suffix google.com"}, - want: Entry{nil, "google.com", true, false, ActionDirect, ""}, + want: Entry{nil, "google.com", true, "", false, ActionDirect, ""}, wantErr: false}, {name: "ok 2", args: args{"proxy ip 8.8.8.8"}, want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)}, - "", false, false, ActionProxy, ""}, wantErr: false}, + "", false, "", false, ActionProxy, ""}, wantErr: false}, {name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"}, - want: Entry{nil, "mad.bad", false, false, ActionHijack, "127.0.0.1"}, + want: Entry{nil, "mad.bad", false, "", false, ActionHijack, "127.0.0.1"}, wantErr: false}, {name: "ok 4", args: args{"block cidr 8.8.8.0/24"}, - want: Entry{ok4ipnet, "", false, false, ActionBlock, ""}, + want: Entry{ok4ipnet, "", false, "", false, ActionBlock, ""}, wantErr: false}, {name: "ok 5", args: args{"block all"}, - want: Entry{nil, "", false, true, ActionBlock, ""}, + want: Entry{nil, "", false, "", true, ActionBlock, ""}, + wantErr: false}, + {name: "ok 6", args: args{"block country cn"}, + want: Entry{nil, "", false, "CN", false, ActionBlock, ""}, wantErr: false}, {name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true}, {name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true},