Merge pull request #197 from HyNetwork/wip-geoip

ACL GeoIP Country
This commit is contained in:
Toby 2022-01-09 19:47:14 -08:00 committed by GitHub
commit 1e6328936c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 148 additions and 7086 deletions

8
ACL.md
View File

@ -13,14 +13,12 @@ Example:
direct domain evil.corp direct domain evil.corp
proxy domain-suffix google.com proxy domain-suffix google.com
block ip 1.2.3.4 block ip 1.2.3.4
block country cn
hijack cidr 192.168.1.1/24 127.0.0.1 hijack cidr 192.168.1.1/24 127.0.0.1
direct all direct all
``` ```
A real-life ACL example of directly connecting to all China IPs (and its generator Python
script) [can be found here](docs/acl).
Hysteria acts according to the first matching rule in the file for each request. When there is no match, the default Hysteria acts according to the first matching rule in the file for each request. When there is no match, the default
behavior is to proxy all connections. You can override this by adding a rule at the end of the file with the condition behavior is to proxy all connections. You can override this by adding a rule at the end of the file with the condition
`all`. `all`.
@ -35,7 +33,7 @@ behavior is to proxy all connections. You can override this by adding a rule at
`hijack` - hijack the connection to another target address (must be specified in the argument) `hijack` - hijack the connection to another target address (must be specified in the argument)
5 condition types: 6 condition types:
`domain` - match a specific domain (does NOT match subdomains! e.g. `apple.com` will not match `cdn.apple.com`) `domain` - match a specific domain (does NOT match subdomains! e.g. `apple.com` will not match `cdn.apple.com`)
@ -45,6 +43,8 @@ behavior is to proxy all connections. You can override this by adding a rule at
`ip` - IPv4 or IPv6 address `ip` - IPv4 or IPv6 address
`country` - match IP by ISO 3166-1 alpha-2 country code
`all` - match anything (usually placed at the end of the file as a default rule) `all` - match anything (usually placed at the end of the file as a default rule)
For domain requests, Hysteria will try to resolve the domains and match both domain & IP rules. In other words, an IP For domain requests, Hysteria will try to resolve the domains and match both domain & IP rules. In other words, an IP

View File

@ -12,13 +12,12 @@ ACL 文件描述如何处理传入请求。服务器和客户端都支持 ACL
direct domain evil.corp direct domain evil.corp
proxy domain-suffix google.com proxy domain-suffix google.com
block ip 1.2.3.4 block ip 1.2.3.4
block country cn
hijack cidr 192.168.1.1/24 127.0.0.1 hijack cidr 192.168.1.1/24 127.0.0.1
direct all direct all
``` ```
一个直连所有中国 IP 的规则和 Python 生成脚本 [在这里](docs/acl)。
Hysteria 根据文件中第一个匹配到规则对每个请求进行操作。当没有匹配时默认的行为是代理连接。可以通过在文件的末尾添加一个规则加上条件 "all" 来设置默认行为。 Hysteria 根据文件中第一个匹配到规则对每个请求进行操作。当没有匹配时默认的行为是代理连接。可以通过在文件的末尾添加一个规则加上条件 "all" 来设置默认行为。
4 种处理方式: 4 种处理方式:
@ -31,7 +30,7 @@ Hysteria 根据文件中第一个匹配到规则对每个请求进行操作。
`hijack` - 把连接劫持到另一个目的地 (必须在参数中指定) `hijack` - 把连接劫持到另一个目的地 (必须在参数中指定)
5 种条件类型: 6 种条件类型:
`domain` - 匹配特定的域名(不匹配子域名!例如:`apple.com` 不匹配 `cdn.apple.com` `domain` - 匹配特定的域名(不匹配子域名!例如:`apple.com` 不匹配 `cdn.apple.com`
@ -41,6 +40,8 @@ Hysteria 根据文件中第一个匹配到规则对每个请求进行操作。
`ip` - IPv4 / IPv6 地址 `ip` - IPv4 / IPv6 地址
`country` - 匹配国家 IPISO 两位字母国家代码
`all` - 匹配所有地址 (通常放在文件尾作为默认规则) `all` - 匹配所有地址 (通常放在文件尾作为默认规则)
对于域名请求Hysteria 将尝试解析域名并同时匹配域名规则和 IP 规则。换句话说IP 规则能覆盖到所有连接,无论客户端是用 IP 还是域名请求。 对于域名请求Hysteria 将尝试解析域名并同时匹配域名规则和 IP 规则。换句话说IP 规则能覆盖到所有连接,无论客户端是用 IP 还是域名请求。

View File

@ -174,6 +174,7 @@ encryption. If you need a proxy, just use our proxy modes.
"down_mbps": 100, // Max download Mbps per client "down_mbps": 100, // Max download Mbps per client
"disable_udp": false, // Disable UDP support "disable_udp": false, // Disable UDP support
"acl": "my_list.acl", // See ACL below "acl": "my_list.acl", // See ACL below
"mmdb": "GeoLite2-Country.mmdb", // MaxMind database for ACL country lookups
"obfs": "AMOGUS", // Obfuscation password "obfs": "AMOGUS", // Obfuscation password
"auth": { // Authentication "auth": { // Authentication
"mode": "password", // Mode, supports "password" "none" and "external" for now "mode": "password", // Mode, supports "password" "none" and "external" for now
@ -318,6 +319,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452
"timeout": 60 // UDP session timeout in seconds "timeout": 60 // UDP session timeout in seconds
}, },
"acl": "my_list.acl", // See ACL below "acl": "my_list.acl", // See ACL below
"mmdb": "GeoLite2-Country.mmdb", // MaxMind database for ACL country lookups
"obfs": "AMOGUS", // Obfuscation password "obfs": "AMOGUS", // Obfuscation password
"auth": "[BASE64]", // Authentication payload in Base64 "auth": "[BASE64]", // Authentication payload in Base64
"auth_str": "yubiyubi", // Authentication payload in string, mutually exclusive with the option above "auth_str": "yubiyubi", // Authentication payload in string, mutually exclusive with the option above

View File

@ -160,6 +160,7 @@ Hysteria 是一个功能丰富的,专为恶劣网络环境进行优化的网
"down_mbps": 100, // 单客户端最大下载速度 "down_mbps": 100, // 单客户端最大下载速度
"disable_udp": false, // 禁用 UDP 支持 "disable_udp": false, // 禁用 UDP 支持
"acl": "my_list.acl", // 见下文 ACL "acl": "my_list.acl", // 见下文 ACL
"mmdb": "GeoLite2-Country.mmdb", // MaxMind IP 库 (ACL)
"obfs": "AMOGUS", // 混淆密码 "obfs": "AMOGUS", // 混淆密码
"auth": { // 验证 "auth": { // 验证
"mode": "password", // 验证模式,暂时只支持 "password" 与 "none" "mode": "password", // 验证模式,暂时只支持 "password" 与 "none"
@ -303,6 +304,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452
"timeout": 60 // UDP 超时秒数 "timeout": 60 // UDP 超时秒数
}, },
"acl": "my_list.acl", // 见下文 ACL "acl": "my_list.acl", // 见下文 ACL
"mmdb": "GeoLite2-Country.mmdb", // MaxMind IP 库 (ACL)
"obfs": "AMOGUS", // 混淆密码 "obfs": "AMOGUS", // 混淆密码
"auth": "[BASE64]", // Base64 验证密钥 "auth": "[BASE64]", // Base64 验证密钥
"auth_str": "yubiyubi", // 字符串验证密钥,和上面的选项二选一 "auth_str": "yubiyubi", // 字符串验证密钥,和上面的选项二选一

View File

@ -3,6 +3,7 @@ package main
import ( import (
"crypto/tls" "crypto/tls"
"crypto/x509" "crypto/x509"
"github.com/oschwald/geoip2-golang"
"io" "io"
"io/ioutil" "io/ioutil"
"net" "net"
@ -93,7 +94,13 @@ func client(config *clientConfig) {
var aclEngine *acl.Engine var aclEngine *acl.Engine
if len(config.ACL) > 0 { if len(config.ACL) > 0 {
var err error var err error
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport) aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
})
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,

View File

@ -16,6 +16,8 @@ const (
DefaultMaxIncomingStreams = 1024 DefaultMaxIncomingStreams = 1024
DefaultALPN = "hysteria" DefaultALPN = "hysteria"
DefaultMMDBFilename = "GeoLite2-Country.mmdb"
) )
type serverConfig struct { type serverConfig struct {
@ -36,6 +38,7 @@ type serverConfig struct {
DownMbps int `json:"down_mbps"` DownMbps int `json:"down_mbps"`
DisableUDP bool `json:"disable_udp"` DisableUDP bool `json:"disable_udp"`
ACL string `json:"acl"` ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"` Obfs string `json:"obfs"`
Auth struct { Auth struct {
Mode string `json:"mode"` Mode string `json:"mode"`
@ -137,6 +140,7 @@ type clientConfig struct {
Timeout int `json:"timeout"` Timeout int `json:"timeout"`
} `json:"tproxy_udp"` } `json:"tproxy_udp"`
ACL string `json:"acl"` ACL string `json:"acl"`
MMDB string `json:"mmdb"`
Obfs string `json:"obfs"` Obfs string `json:"obfs"`
Auth []byte `json:"auth"` Auth []byte `json:"auth"`
AuthString string `json:"auth_str"` AuthString string `json:"auth_str"`

51
cmd/mmdb.go Normal file
View File

@ -0,0 +1,51 @@
package main
import (
"github.com/oschwald/geoip2-golang"
"github.com/sirupsen/logrus"
"io"
"net/http"
"os"
)
const (
mmdbURL = "https://github.com/P3TERX/GeoLite.mmdb/raw/download/GeoLite2-Country.mmdb"
)
func downloadMMDB(filename string) error {
resp, err := http.Get(mmdbURL)
if err != nil {
return err
}
defer resp.Body.Close()
file, err := os.Create(filename)
if err != nil {
return err
}
defer file.Close()
_, err = io.Copy(file, resp.Body)
return err
}
func loadMMDBReader(filename string) (*geoip2.Reader, error) {
if _, err := os.Stat(filename); err != nil {
if os.IsNotExist(err) {
logrus.Info("GeoLite2 database not found, downloading...")
if err := downloadMMDB(filename); err != nil {
return nil, err
}
logrus.WithFields(logrus.Fields{
"file": filename,
}).Info("GeoLite2 database downloaded")
return geoip2.Open(filename)
} else {
// some other error
return nil, err
}
} else {
// file exists, just open it
return geoip2.Open(filename)
}
}

View File

@ -4,6 +4,7 @@ import (
"crypto/tls" "crypto/tls"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lucas-clemente/quic-go/congestion" "github.com/lucas-clemente/quic-go/congestion"
"github.com/oschwald/geoip2-golang"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promhttp" "github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -139,7 +140,13 @@ func server(config *serverConfig) {
// ACL // ACL
var aclEngine *acl.Engine var aclEngine *acl.Engine
if len(config.ACL) > 0 { if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport) aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultTransport, func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB)
} else {
return loadMMDBReader(DefaultMMDBFilename)
}
})
if err != nil { if err != nil {
logrus.WithFields(logrus.Fields{ logrus.WithFields(logrus.Fields{
"error": err, "error": err,

File diff suppressed because it is too large Load Diff

View File

@ -1,20 +0,0 @@
#! /usr/bin/env python3
import urllib.request
from itertools import chain
from datetime import date
data_ipv4 = urllib.request.urlopen(
'http://www.ipdeny.com/ipblocks/data/aggregated/cn-aggregated.zone')
data_ipv6 = urllib.request.urlopen(
'http://www.ipdeny.com/ipv6/ipaddresses/aggregated/cn-aggregated.zone')
data = chain(data_ipv4, data_ipv6)
with open('chnroutes.acl', 'w') as out:
out.write('# chnroutes\n# Generated on %s\n\n' %
date.today().strftime("%B %d, %Y"))
for l in data:
ls = str(l, 'UTF8').strip()
if ls:
out.write('direct cidr %s\n' % ls)

1
go.mod
View File

@ -16,6 +16,7 @@ require (
github.com/hashicorp/golang-lru v0.5.4 github.com/hashicorp/golang-lru v0.5.4
github.com/lucas-clemente/quic-go v0.24.0 github.com/lucas-clemente/quic-go v0.24.0
github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40
github.com/oschwald/geoip2-golang v1.5.0 // indirect
github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect
github.com/prometheus/client_golang v1.11.0 github.com/prometheus/client_golang v1.11.0
github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect

6
go.sum
View File

@ -173,6 +173,10 @@ github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1y
github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak= github.com/onsi/gomega v1.13.0 h1:7lLHu94wT9Ij0o6EWWclhu0aOh32VxhkwEJvzuWPeak=
github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY= github.com/onsi/gomega v1.13.0/go.mod h1:lRk9szgn8TxENtWd0Tp4c3wjlRfMTMH27I+3Je41yGY=
github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8= github.com/openzipkin/zipkin-go v0.1.1/go.mod h1:NtoC/o8u3JlF1lSlyPNswIbeQH9bJTmOf0Erfk+hxe8=
github.com/oschwald/geoip2-golang v1.5.0 h1:igg2yQIrrcRccB1ytFXqBfOHCjXWIoMv85lVJ1ONZzw=
github.com/oschwald/geoip2-golang v1.5.0/go.mod h1:xdvYt5xQzB8ORWFqPnqMwZpCpgNagttWdoZLlJQzg7s=
github.com/oschwald/maxminddb-golang v1.8.0 h1:Uh/DSnGoxsyp/KYbY1AuP0tYEwfs0sCph9p/UMXK/Hk=
github.com/oschwald/maxminddb-golang v1.8.0/go.mod h1:RXZtst0N6+FY/3qCNmZMBApR19cdQj43/NM9VkrNAis=
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
@ -248,6 +252,7 @@ github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXf
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA=
@ -352,6 +357,7 @@ golang.org/x/sys v0.0.0-20190904154756-749cb33beabd/go.mod h1:h1NjWce9XRLGQEsW7w
golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191120155948-bd437916bb0e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191224085550-c709ea063b76/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200106162015-b016eb3dc98e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=

View File

@ -3,6 +3,7 @@ package acl
import ( import (
"bufio" "bufio"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/geoip2-golang"
"github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/transport"
"net" "net"
"os" "os"
@ -16,6 +17,7 @@ type Engine struct {
Entries []Entry Entries []Entry
Cache *lru.ARCCache Cache *lru.ARCCache
Transport transport.Transport Transport transport.Transport
GeoIPReader *geoip2.Reader
} }
type cacheEntry struct { type cacheEntry struct {
@ -23,7 +25,7 @@ type cacheEntry struct {
Arg string Arg string
} }
func LoadFromFile(filename string, transport transport.Transport) (*Engine, error) { func LoadFromFile(filename string, transport transport.Transport, geoIPLoadFunc func() (*geoip2.Reader, error)) (*Engine, error) {
f, err := os.Open(filename) f, err := os.Open(filename)
if err != nil { if err != nil {
return nil, err return nil, err
@ -31,6 +33,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
defer f.Close() defer f.Close()
scanner := bufio.NewScanner(f) scanner := bufio.NewScanner(f)
entries := make([]Entry, 0, 1024) entries := make([]Entry, 0, 1024)
var geoIPReader *geoip2.Reader
for scanner.Scan() { for scanner.Scan() {
line := strings.TrimSpace(scanner.Text()) line := strings.TrimSpace(scanner.Text())
if len(line) == 0 || strings.HasPrefix(line, "#") { if len(line) == 0 || strings.HasPrefix(line, "#") {
@ -41,6 +44,12 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
if err != nil { if err != nil {
return nil, err return nil, err
} }
if len(entry.Country) > 0 && geoIPReader == nil {
geoIPReader, err = geoIPLoadFunc() // lazy load GeoIP reader only when needed
if err != nil {
return nil, err
}
}
entries = append(entries, entry) entries = append(entries, entry)
} }
cache, err := lru.NewARC(entryCacheSize) cache, err := lru.NewARC(entryCacheSize)
@ -52,6 +61,7 @@ func LoadFromFile(filename string, transport transport.Transport) (*Engine, erro
Entries: entries, Entries: entries,
Cache: cache, Cache: cache,
Transport: transport, Transport: transport,
GeoIPReader: geoIPReader,
}, nil }, nil
} }
@ -66,7 +76,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
return ce.Action, ce.Arg, ipAddr, err return ce.Action, ce.Arg, ipAddr, err
} }
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP)) { if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) {
e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, ipAddr, err return entry.Action, entry.ActionArg, ipAddr, err
} }
@ -84,7 +94,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
}, nil }, nil
} }
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchIP(ip) { if entry.MatchIP(ip, e.GeoIPReader) {
e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, &net.IPAddr{ return entry.Action, entry.ActionArg, &net.IPAddr{
IP: ip, IP: ip,

View File

@ -3,6 +3,7 @@ package acl
import ( import (
"errors" "errors"
"fmt" "fmt"
"github.com/oschwald/geoip2-golang"
"net" "net"
"strings" "strings"
) )
@ -20,6 +21,7 @@ type Entry struct {
Net *net.IPNet Net *net.IPNet
Domain string Domain string
Suffix bool Suffix bool
Country string
All bool All bool
Action Action Action Action
ActionArg string ActionArg string
@ -40,13 +42,23 @@ func (e Entry) MatchDomain(domain string) bool {
return false return false
} }
func (e Entry) MatchIP(ip net.IP) bool { func (e Entry) MatchIP(ip net.IP, db *geoip2.Reader) bool {
if e.All { if e.All {
return true return true
} }
if e.Net != nil && ip != nil { if ip == nil {
return false
}
if e.Net != nil {
return e.Net.Contains(ip) return e.Net.Contains(ip)
} }
if len(e.Country) > 0 && db != nil {
country, err := db.Country(ip)
if err != nil {
return false
}
return country.Country.IsoCode == e.Country
}
return false return false
} }
@ -65,7 +77,7 @@ func ParseEntry(s string) (Entry, error) {
// Make sure there are at least 2 args // Make sure there are at least 2 args
args = append(args, "") args = append(args, "")
} }
ipNet, domain, suffix, all, err := parseCond(args[0], args[1]) ipNet, domain, suffix, country, all, err := parseCond(args[0], args[1])
if err != nil { if err != nil {
return Entry{}, err return Entry{}, err
} }
@ -73,6 +85,7 @@ func ParseEntry(s string) (Entry, error) {
Net: ipNet, Net: ipNet,
Domain: domain, Domain: domain,
Suffix: suffix, Suffix: suffix,
Country: country,
All: all, All: all,
} }
switch strings.ToLower(fields[0]) { switch strings.ToLower(fields[0]) {
@ -94,43 +107,48 @@ func ParseEntry(s string) (Entry, error) {
return e, nil return e, nil
} }
func parseCond(typ, cond string) (*net.IPNet, string, bool, bool, error) { func parseCond(typ, cond string) (*net.IPNet, string, bool, string, bool, error) {
switch strings.ToLower(typ) { switch strings.ToLower(typ) {
case "domain": case "domain":
if len(cond) == 0 { if len(cond) == 0 {
return nil, "", false, false, errors.New("empty domain") return nil, "", false, "", false, errors.New("empty domain")
} }
return nil, strings.ToLower(cond), false, false, nil return nil, strings.ToLower(cond), false, "", false, nil
case "domain-suffix": case "domain-suffix":
if len(cond) == 0 { if len(cond) == 0 {
return nil, "", false, false, errors.New("empty domain suffix") return nil, "", false, "", false, errors.New("empty domain suffix")
} }
return nil, strings.ToLower(cond), true, false, nil return nil, strings.ToLower(cond), true, "", false, nil
case "cidr": case "cidr":
_, ipNet, err := net.ParseCIDR(cond) _, ipNet, err := net.ParseCIDR(cond)
if err != nil { if err != nil {
return nil, "", false, false, err return nil, "", false, "", false, err
} }
return ipNet, "", false, false, nil return ipNet, "", false, "", false, nil
case "ip": case "ip":
ip := net.ParseIP(cond) ip := net.ParseIP(cond)
if ip == nil { if ip == nil {
return nil, "", false, false, fmt.Errorf("invalid ip %s", cond) return nil, "", false, "", false, fmt.Errorf("invalid ip %s", cond)
} }
if ip.To4() != nil { if ip.To4() != nil {
return &net.IPNet{ return &net.IPNet{
IP: ip, IP: ip,
Mask: net.CIDRMask(32, 32), Mask: net.CIDRMask(32, 32),
}, "", false, false, nil }, "", false, "", false, nil
} else { } else {
return &net.IPNet{ return &net.IPNet{
IP: ip, IP: ip,
Mask: net.CIDRMask(128, 128), Mask: net.CIDRMask(128, 128),
}, "", false, false, nil }, "", false, "", false, nil
} }
case "country":
if len(cond) == 0 {
return nil, "", false, "", false, errors.New("empty country")
}
return nil, "", false, strings.ToUpper(cond), false, nil
case "all": case "all":
return nil, "", false, true, nil return nil, "", false, "", true, nil
default: default:
return nil, "", false, false, fmt.Errorf("invalid condition type %s", typ) return nil, "", false, "", false, fmt.Errorf("invalid condition type %s", typ)
} }
} }

View File

@ -20,19 +20,22 @@ func TestParseEntry(t *testing.T) {
}{ }{
{name: "empty", args: args{""}, want: Entry{}, wantErr: true}, {name: "empty", args: args{""}, want: Entry{}, wantErr: true},
{name: "ok 1", args: args{"direct domain-suffix google.com"}, {name: "ok 1", args: args{"direct domain-suffix google.com"},
want: Entry{nil, "google.com", true, false, ActionDirect, ""}, want: Entry{nil, "google.com", true, "", false, ActionDirect, ""},
wantErr: false}, wantErr: false},
{name: "ok 2", args: args{"proxy ip 8.8.8.8"}, {name: "ok 2", args: args{"proxy ip 8.8.8.8"},
want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)}, want: Entry{&net.IPNet{net.ParseIP("8.8.8.8"), net.CIDRMask(32, 32)},
"", false, false, ActionProxy, ""}, wantErr: false}, "", false, "", false, ActionProxy, ""}, wantErr: false},
{name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"}, {name: "ok 3", args: args{"hijack domain mad.bad 127.0.0.1"},
want: Entry{nil, "mad.bad", false, false, ActionHijack, "127.0.0.1"}, want: Entry{nil, "mad.bad", false, "", false, ActionHijack, "127.0.0.1"},
wantErr: false}, wantErr: false},
{name: "ok 4", args: args{"block cidr 8.8.8.0/24"}, {name: "ok 4", args: args{"block cidr 8.8.8.0/24"},
want: Entry{ok4ipnet, "", false, false, ActionBlock, ""}, want: Entry{ok4ipnet, "", false, "", false, ActionBlock, ""},
wantErr: false}, wantErr: false},
{name: "ok 5", args: args{"block all"}, {name: "ok 5", args: args{"block all"},
want: Entry{nil, "", false, true, ActionBlock, ""}, want: Entry{nil, "", false, "", true, ActionBlock, ""},
wantErr: false},
{name: "ok 6", args: args{"block country cn"},
want: Entry{nil, "", false, "CN", false, ActionBlock, ""},
wantErr: false}, wantErr: false},
{name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true}, {name: "invalid 1", args: args{"proxy domain"}, want: Entry{}, wantErr: true},
{name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true}, {name: "invalid 2", args: args{"proxy dom google.com"}, want: Entry{}, wantErr: true},