feat: ACL country support

This commit is contained in:
Toby 2022-01-09 18:11:52 -08:00
parent c3b76a5b44
commit 89452dd9c5
9 changed files with 136 additions and 29 deletions

View File

@ -3,6 +3,7 @@ package main
import (
"crypto/tls"
"crypto/x509"
"github.com/oschwald/geoip2-golang"
"io"
"io/ioutil"
"net"
@ -93,7 +94,13 @@ func client(config *clientConfig) {
var aclEngine *acl.Engine
if len(config.ACL) > 0 {
var err error
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport)
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
})
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,

View File

@ -16,6 +16,8 @@ const (
DefaultMaxIncomingStreams = 1024
DefaultALPN = "hysteria"
DefaultMMDBFilename = "GeoLite2-Country.mmdb"
)
type serverConfig struct {
@ -36,6 +38,7 @@ type serverConfig struct {
DownMbps int `json:"down_mbps"`
DisableUDP bool `json:"disable_udp"`
ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"`
Auth struct {
Mode string `json:"mode"`
@ -137,6 +140,7 @@ type clientConfig struct {
Timeout int `json:"timeout"`
} `json:"tproxy_udp"`
ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"`
Auth []byte `json:"auth"`
AuthString string `json:"auth_str"`

51
cmd/mmdb.go Normal file
View File

@ -0,0 +1,51 @@
package main
import (
"github.com/oschwald/geoip2-golang"
"github.com/sirupsen/logrus"
"io"
"net/http"
"os"
)
const (
mmdbURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb"
)
func downloadMMDB(filename string) error {
resp, err := http.Get(mmdbURL)
if err != nil {
return err
}
defer resp.Body.Close()
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
return err
}
func loadMMDBReader(filename string) (*geoip2.Reader, error) {
if _, err := os.Stat(filename); err != nil {
if os.IsNotExist(err) {
logrus.Info("GeoLite2 database not found, downloading...")
if err := downloadMMDB(filename); err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"file": filename,
}).Info("GeoLite2 database downloaded")
return geoip2.Open(filename)
} else {
// some other error
return nil, err
}
} else {
// file exists, just open it
return geoip2.Open(filename)
}
}

View File

@ -4,6 +4,7 @@ import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion"
"github.com/oschwald/geoip2-golang"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus"
@ -139,7 +140,13 @@ func server(config *serverConfig) {
// ACL
var aclEngine *acl.Engine
if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport)
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
})
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,

1
go.mod
View File

@ -16,6 +16,7 @@ require (
github.com/hashicorp/golang-lru v0.5.4
github.com/lucas-clemente/quic-go v0.24.0
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
github.com/oschwald/geoip2-golang v1.5.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/prometheus/client_golang v1.11.0
github.com/russross/blackfriday/v2 v2.1.0 // indirect

6
go.sum
View File

@ -173,6 +173,10 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak=
github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/oschwald/geoip2-golang v1.5.0 h1:igg2yQIrrcRccB1ytFXqBfOHCjXWIoMv85lVJ1ONZzw=
github.com/oschwald/geoip2-golang v1.5.0/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s=
github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk=
github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -248,6 +252,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
@ -352,6 +357,7 @@ golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -3,6 +3,7 @@ package acl
import (
"bufio"
lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/geoip2-golang"
"github.com/tobyxdd/hysteria/pkg/transport"
"net"
"os"
@ -16,6 +17,7 @@ type Engine struct {
Entries []Entry
Cache *lru.ARCCache
Transport transport.Transport
GeoIPReader *geoip2.Reader
}
type cacheEntry struct {
@ -23,7 +25,7 @@ type cacheEntry struct {
Arg string
}
func LoadFromFile(filename string, transport transport.Transport) (*Engine, error) {
func LoadFromFile(filename string, transport transport.Transport, geoIPLoadFunc func() (*geoip2.Reader, error)) (*Engine, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
@ -31,6 +33,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
defer f.Close()
scanner := bufio.NewScanner(f)
entries := make([]Entry, 0, 1024)
var geoIPReader *geoip2.Reader
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if len(line) == 0 || strings.HasPrefix(line, "#") {
@ -41,6 +44,12 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
if err != nil {
return nil, err
}
if len(entry.Country) > 0 && geoIPReader == nil {
geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed
if err != nil {
return nil, err
}
}
entries = append(entries, entry)
}
cache, err := lru.NewARC(entryCacheSize)
@ -52,6 +61,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
Entries: entries,
Cache: cache,
Transport: transport,
GeoIPReader: geoIPReader,
}, nil
}
@ -66,7 +76,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
return ce.Action, ce.Arg, ipAddr, err
}
for _, entry := range e.Entries {
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP)) {
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) {
e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, ipAddr, err
}
@ -84,7 +94,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
}, nil
}
for _, entry := range e.Entries {
if entry.MatchIP(ip) {
if entry.MatchIP(ip, e.GeoIPReader) {
e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, &net.IPAddr{
IP: ip,

View File

@ -3,6 +3,7 @@ package acl
import (
"errors"
"fmt"
"github.com/oschwald/geoip2-golang"
"net"
"strings"
)
@ -20,6 +21,7 @@ type Entry struct {
Net *net.IPNet
Domain string
Suffix bool
Country string
All bool
Action Action
ActionArg string
@ -40,13 +42,23 @@ func (e Entry) MatchDomain(domain string) bool {
return false
}
func (e Entry) MatchIP(ip net.IP) bool {
func (e Entry) MatchIP(ip net.IP, db *geoip2.Reader) bool {
if e.All {
return true
}
if e.Net != nil && ip != nil {
if ip == nil {
return false
}
if e.Net != nil {
return e.Net.Contains(ip)
}
if len(e.Country) > 0 && db != nil {
country, err := db.Country(ip)
if err != nil {
return false
}
return country.Country.IsoCode == e.Country
}
return false
}
@ -65,15 +77,16 @@ func ParseEntry(s string) (Entry, error) {
// Make sure there are at least 2 args
args = append(args, "")
}
ipNet, domain, suffix, all, err := parseCond(args[0], args[1])
ipNet, domain, suffix, country, all, err := parseCond(args[0], args[1])
if err != nil {
return Entry{}, err
}
e := Entry{
Net: ipNet,
Domain: domain,
Suffix: suffix,
All: all,
Net: ipNet,
Domain: domain,
Suffix: suffix,
Country: country,
All: all,
}
switch strings.ToLower(fields[0]) {
case "direct":
@ -94,43 +107,48 @@ func ParseEntry(s string) (Entry, error) {
return e, nil
}
func parseCond(typ, cond string) (*net.IPNet, string, bool, bool, error) {
func parseCond(typ, cond string) (*net.IPNet, string, bool, string, bool, error) {
switch strings.ToLower(typ) {
case "domain":
if len(cond) == 0 {
return nil, "", false, false, errors.New("empty domain")
return nil, "", false, "", false, errors.New("empty domain")
}
return nil, strings.ToLower(cond), false, false, nil
return nil, strings.ToLower(cond), false, "", false, nil
case "domain-suffix":
if len(cond) == 0 {
return nil, "", false, false, errors.New("empty domain suffix")
return nil, "", false, "", false, errors.New("empty domain suffix")
}
return nil, strings.ToLower(cond), true, false, nil
return nil, strings.ToLower(cond), true, "", false, nil
case "cidr":
_, ipNet, err := net.ParseCIDR(cond)
if err != nil {
return nil, "", false, false, err
return nil, "", false, "", false, err
}
return ipNet, "", false, false, nil
return ipNet, "", false, "", false, nil
case "ip":
ip := net.ParseIP(cond)
if ip == nil {
return nil, "", false, false, fmt.Errorf("invalid ip %s", cond)
return nil, "", false, "", false, fmt.Errorf("invalid ip %s", cond)
}
if ip.To4() != nil {
return &net.IPNet{
IP: ip,
Mask: net.CIDRMask(32, 32),
}, "", false, false, nil
}, "", false, "", false, nil
} else {
return &net.IPNet{
IP: ip,
Mask: net.CIDRMask(128, 128),
}, "", false, false, nil
}, "", false, "", false, nil
}
case "country":
if len(cond) == 0 {
return nil, "", false, "", false, errors.New("empty country")
}
return nil, "", false, strings.ToUpper(cond), false, nil
case "all":
return nil, "", false, true, nil
return nil, "", false, "", true, nil
default:
return nil, "", false, false, fmt.Errorf("invalid condition type %s", typ)
return nil, "", false, "", false, fmt.Errorf("invalid condition type %s", typ)
}
}

View File

@ -20,19 +20,22 @@ func TestParseEntry(t *testing.T) {
}{
{name: "empty", args: args{""}, want: Entry{}, wantErr: true},
{name: "ok 1", args: args{"direct domain-suffix google.com"},
want: Entry{nil, "google.com", true, false, ActionDirect, ""},
want: Entry{nil, "google.com", true, "", false, ActionDirect, ""},
wantErr: false},
{name: "ok 2", args: args{"proxy ip 8.8.8.8"},
want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)},
"", false, false, ActionProxy, ""}, wantErr: false},
"", false, "", false, ActionProxy, ""}, wantErr: false},
{name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"},
want: Entry{nil, "mad.bad", false, false, ActionHijack, "127.0.0.1"},
want: Entry{nil, "mad.bad", false, "", false, ActionHijack, "127.0.0.1"},
wantErr: false},
{name: "ok 4", args: args{"block cidr 8.8.8.0/24"},
want: Entry{ok4ipnet, "", false, false, ActionBlock, ""},
want: Entry{ok4ipnet, "", false, "", false, ActionBlock, ""},
wantErr: false},
{name: "ok 5", args: args{"block all"},
want: Entry{nil, "", false, true, ActionBlock, ""},
want: Entry{nil, "", false, "", true, ActionBlock, ""},
wantErr: false},
{name: "ok 6", args: args{"block country cn"},
want: Entry{nil, "", false, "CN", false, ActionBlock, ""},
wantErr: false},
{name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true},
{name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true},