mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-09-09 18:01:47 +00:00
feat: WIP ACL
This commit is contained in:
223
extras/outbounds/acl/compile.go
Normal file
223
extras/outbounds/acl/compile.go
Normal file
@@ -0,0 +1,223 @@
|
||||
package acl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
lru "github.com/hashicorp/golang-lru/v2"
|
||||
"github.com/oschwald/geoip2-golang"
|
||||
)
|
||||
|
||||
type protocol int
|
||||
|
||||
const (
|
||||
protocolBoth protocol = iota
|
||||
protocolTCP
|
||||
protocolUDP
|
||||
)
|
||||
|
||||
type Outbound interface {
|
||||
any
|
||||
}
|
||||
|
||||
type HostInfo struct {
|
||||
Name string
|
||||
IPv4 net.IP
|
||||
IPv6 net.IP
|
||||
}
|
||||
|
||||
func (h HostInfo) String() string {
|
||||
return fmt.Sprintf("%s|%s|%s", h.Name, h.IPv4, h.IPv6)
|
||||
}
|
||||
|
||||
type CompiledRuleSet[O Outbound] interface {
|
||||
Match(host HostInfo, proto protocol, port uint16) (O, net.IP)
|
||||
}
|
||||
|
||||
type compiledRule[O Outbound] struct {
|
||||
Outbound O
|
||||
HostMatcher hostMatcher
|
||||
Protocol protocol
|
||||
Port uint16
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
func (r *compiledRule[O]) Match(host HostInfo, proto protocol, port uint16) bool {
|
||||
if r.Protocol != protocolBoth && r.Protocol != proto {
|
||||
return false
|
||||
}
|
||||
if r.Port != 0 && r.Port != port {
|
||||
return false
|
||||
}
|
||||
return r.HostMatcher.Match(host)
|
||||
}
|
||||
|
||||
type matchResult[O Outbound] struct {
|
||||
Outbound O
|
||||
HijackAddress net.IP
|
||||
}
|
||||
|
||||
type compiledRuleSetImpl[O Outbound] struct {
|
||||
Rules []compiledRule[O]
|
||||
Cache *lru.Cache[string, matchResult[O]] // key: HostInfo.String()
|
||||
}
|
||||
|
||||
func (s *compiledRuleSetImpl[O]) Match(host HostInfo, proto protocol, port uint16) (O, net.IP) {
|
||||
host.Name = strings.ToLower(host.Name) // Normalize host name to lower case
|
||||
key := host.String()
|
||||
if result, ok := s.Cache.Get(key); ok {
|
||||
return result.Outbound, result.HijackAddress
|
||||
}
|
||||
for _, rule := range s.Rules {
|
||||
if rule.Match(host, proto, port) {
|
||||
result := matchResult[O]{rule.Outbound, rule.HijackAddress}
|
||||
s.Cache.Add(key, result)
|
||||
return result.Outbound, result.HijackAddress
|
||||
}
|
||||
}
|
||||
// No match should also be cached
|
||||
var zero O
|
||||
s.Cache.Add(key, matchResult[O]{zero, nil})
|
||||
return zero, nil
|
||||
}
|
||||
|
||||
type CompilationError struct {
|
||||
Index int
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *CompilationError) Error() string {
|
||||
return fmt.Sprintf("error at index %d: %s", e.Index, e.Message)
|
||||
}
|
||||
|
||||
func Compile[O Outbound](rules []TextRule, outbounds map[string]O,
|
||||
cacheSize int, geoipFunc func() *geoip2.Reader,
|
||||
) (CompiledRuleSet[O], error) {
|
||||
compiledRules := make([]compiledRule[O], len(rules))
|
||||
for i, rule := range rules {
|
||||
outbound, ok := outbounds[rule.Outbound]
|
||||
if !ok {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("outbound %s not found", rule.Outbound)}
|
||||
}
|
||||
hm, errStr := compileHostMatcher(rule.Address, geoipFunc)
|
||||
if errStr != "" {
|
||||
return nil, &CompilationError{i, errStr}
|
||||
}
|
||||
proto, port, ok := parseProtoPort(rule.ProtoPort)
|
||||
if !ok {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("invalid protocol/port: %s", rule.ProtoPort)}
|
||||
}
|
||||
var hijackAddress net.IP
|
||||
if rule.HijackAddress != "" {
|
||||
hijackAddress = net.ParseIP(rule.HijackAddress)
|
||||
if hijackAddress == nil {
|
||||
return nil, &CompilationError{i, fmt.Sprintf("invalid hijack address (must be an IP address): %s", rule.HijackAddress)}
|
||||
}
|
||||
}
|
||||
compiledRules[i] = compiledRule[O]{outbound, hm, proto, port, hijackAddress}
|
||||
}
|
||||
cache, err := lru.New[string, matchResult[O]](cacheSize)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &compiledRuleSetImpl[O]{compiledRules, cache}, nil
|
||||
}
|
||||
|
||||
// parseProtoPort parses the protocol and port from a protoPort string.
|
||||
// protoPort must be in one of the following formats:
|
||||
//
|
||||
// proto/port
|
||||
// proto/*
|
||||
// proto
|
||||
// */port
|
||||
// */*
|
||||
// *
|
||||
// [empty] (same as *)
|
||||
//
|
||||
// proto must be either "tcp" or "udp", case-insensitive.
|
||||
func parseProtoPort(protoPort string) (protocol, uint16, bool) {
|
||||
protoPort = strings.ToLower(protoPort)
|
||||
if protoPort == "" || protoPort == "*" || protoPort == "*/*" {
|
||||
return protocolBoth, 0, true
|
||||
}
|
||||
parts := strings.SplitN(protoPort, "/", 2)
|
||||
if len(parts) == 1 {
|
||||
// No port, only protocol
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
return protocolTCP, 0, true
|
||||
case "udp":
|
||||
return protocolUDP, 0, true
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
} else {
|
||||
// Both protocol and port
|
||||
var proto protocol
|
||||
var port uint16
|
||||
switch parts[0] {
|
||||
case "tcp":
|
||||
proto = protocolTCP
|
||||
case "udp":
|
||||
proto = protocolUDP
|
||||
case "*":
|
||||
proto = protocolBoth
|
||||
default:
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
if parts[1] != "*" {
|
||||
p64, err := strconv.ParseUint(parts[1], 10, 16)
|
||||
if err != nil {
|
||||
return protocolBoth, 0, false
|
||||
}
|
||||
port = uint16(p64)
|
||||
}
|
||||
return proto, port, true
|
||||
}
|
||||
}
|
||||
|
||||
func compileHostMatcher(addr string, geoipFunc func() *geoip2.Reader) (hostMatcher, string) {
|
||||
addr = strings.ToLower(addr) // Normalize to lower case
|
||||
if addr == "*" || addr == "all" {
|
||||
// Match all hosts
|
||||
return &allMatcher{}, ""
|
||||
}
|
||||
if strings.HasPrefix(addr, "geoip:") {
|
||||
// GeoIP matcher
|
||||
country := strings.ToUpper(addr[6:])
|
||||
if len(country) != 2 {
|
||||
return nil, fmt.Sprintf("invalid country code: %s", country)
|
||||
}
|
||||
db := geoipFunc()
|
||||
if db == nil {
|
||||
return nil, "failed to load GeoIP database"
|
||||
}
|
||||
return &geoIPMatcher{db, country}, ""
|
||||
}
|
||||
if strings.Contains(addr, "/") {
|
||||
// CIDR matcher
|
||||
_, ipnet, err := net.ParseCIDR(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Sprintf("invalid CIDR address: %s", addr)
|
||||
}
|
||||
return &cidrMatcher{ipnet}, ""
|
||||
}
|
||||
if ip := net.ParseIP(addr); ip != nil {
|
||||
// Single IP matcher
|
||||
return &ipMatcher{ip}, ""
|
||||
}
|
||||
if strings.Contains(addr, "*") {
|
||||
// Wildcard domain matcher
|
||||
return &domainMatcher{
|
||||
Pattern: addr,
|
||||
Wildcard: true,
|
||||
}, ""
|
||||
}
|
||||
// Nothing else matched, treat it as a non-wildcard domain
|
||||
return &domainMatcher{
|
||||
Pattern: addr,
|
||||
Wildcard: false,
|
||||
}, ""
|
||||
}
|
Reference in New Issue
Block a user