Merge pull request #291 from HyNetwork/domain-passthrough

feat: server SOCKS5 outbound domain passthrough
This commit is contained in:
Toby 2022-04-14 00:21:45 -07:00 committed by GitHub
commit 02937081bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 178 additions and 112 deletions

View File

@ -166,7 +166,10 @@ func server(config *serverConfig) {
// ACL // ACL
var aclEngine *acl.Engine var aclEngine *acl.Engine
if len(config.ACL) > 0 { if len(config.ACL) > 0 {
aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultServerTransport.ResolveIPAddr, aclEngine, err = acl.LoadFromFile(config.ACL, func(addr string) (*net.IPAddr, error) {
ipAddr, _, err := transport.DefaultServerTransport.ResolveIPAddr(addr)
return ipAddr, err
},
func() (*geoip2.Reader, error) { func() (*geoip2.Reader, error) {
if len(config.MMDB) > 0 { if len(config.MMDB) > 0 {
return loadMMDBReader(config.MMDB) return loadMMDBReader(config.MMDB)

View File

@ -4,6 +4,7 @@ import (
"bufio" "bufio"
lru "github.com/hashicorp/golang-lru" lru "github.com/hashicorp/golang-lru"
"github.com/oschwald/geoip2-golang" "github.com/oschwald/geoip2-golang"
"github.com/tobyxdd/hysteria/pkg/utils"
"net" "net"
"os" "os"
"strings" "strings"
@ -64,30 +65,31 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro
}, nil }, nil
} }
func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, error) { // action, arg, isDomain, resolvedIP, error
ip, zone := parseIPZone(host) func (e *Engine) ResolveAndMatch(host string) (Action, string, bool, *net.IPAddr, error) {
ip, zone := utils.ParseIPZone(host)
if ip == nil { if ip == nil {
// Domain // Domain
ipAddr, err := e.ResolveIPAddr(host) ipAddr, err := e.ResolveIPAddr(host)
if v, ok := e.Cache.Get(host); ok { if v, ok := e.Cache.Get(host); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheEntry)
return ce.Action, ce.Arg, ipAddr, err return ce.Action, ce.Arg, true, ipAddr, err
} }
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) { if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) {
e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, ipAddr, err return entry.Action, entry.ActionArg, true, ipAddr, err
} }
} }
e.Cache.Add(host, cacheEntry{e.DefaultAction, ""}) e.Cache.Add(host, cacheEntry{e.DefaultAction, ""})
return e.DefaultAction, "", ipAddr, err return e.DefaultAction, "", true, ipAddr, err
} else { } else {
// IP // IP
if v, ok := e.Cache.Get(ip.String()); ok { if v, ok := e.Cache.Get(ip.String()); ok {
// Cache hit // Cache hit
ce := v.(cacheEntry) ce := v.(cacheEntry)
return ce.Action, ce.Arg, &net.IPAddr{ return ce.Action, ce.Arg, false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,
}, nil }, nil
@ -95,14 +97,14 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro
for _, entry := range e.Entries { for _, entry := range e.Entries {
if entry.MatchIP(ip, e.GeoIPReader) { if entry.MatchIP(ip, e.GeoIPReader) {
e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg})
return entry.Action, entry.ActionArg, &net.IPAddr{ return entry.Action, entry.ActionArg, false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,
}, nil }, nil
} }
} }
e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""}) e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""})
return e.DefaultAction, "", &net.IPAddr{ return e.DefaultAction, "", false, &net.IPAddr{
IP: ip, IP: ip,
Zone: zone, Zone: zone,
}, nil }, nil

View File

@ -1,27 +0,0 @@
package acl
import "net"
func parseIPZone(s string) (net.IP, string) {
s, zone := splitHostZone(s)
return net.ParseIP(s), zone
}
func splitHostZone(s string) (host, zone string) {
if i := last(s, '%'); i > 0 {
host, zone = s[:i], s[i+1:]
} else {
host = s
}
return
}
func last(s string, b byte) int {
i := len(s)
for i--; i >= 0; i-- {
if s[i] == b {
break
}
}
return i
}

View File

@ -154,36 +154,43 @@ func (c *serverClient) handleMessage(msg []byte) {
if ok { if ok {
// Session found, send the message // Session found, send the message
action, arg := acl.ActionDirect, "" action, arg := acl.ActionDirect, ""
var isDomain bool
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var err error var err error
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host)
} else { } else {
ipAddr, err = c.Transport.ResolveIPAddr(dfMsg.Host) ipAddr, isDomain, err = c.Transport.ResolveIPAddr(dfMsg.Host)
} }
if err != nil { if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound
return return
} }
switch action { switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
_, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ addrEx := &transport.AddrEx{
IP: ipAddr.IP, IPAddr: ipAddr,
Port: int(dfMsg.Port), Port: int(dfMsg.Port),
Zone: ipAddr.Zone, }
}) if isDomain {
addrEx.Domain = dfMsg.Host
}
_, _ = conn.WriteToUDP(dfMsg.Data, addrEx)
if c.UpCounter != nil { if c.UpCounter != nil {
c.UpCounter.Add(float64(len(dfMsg.Data))) c.UpCounter.Add(float64(len(dfMsg.Data)))
} }
case acl.ActionBlock: case acl.ActionBlock:
// Do nothing // Do nothing
case acl.ActionHijack: case acl.ActionHijack:
hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) hijackIPAddr, isDomain, err := c.Transport.ResolveIPAddr(arg)
if err == nil { if err == nil || (isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound
_, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ addrEx := &transport.AddrEx{
IP: hijackIPAddr.IP, IPAddr: hijackIPAddr,
Port: int(dfMsg.Port), Port: int(dfMsg.Port),
Zone: hijackIPAddr.Zone, }
}) if isDomain {
addrEx.Domain = arg
}
_, _ = conn.WriteToUDP(dfMsg.Data, addrEx)
if c.UpCounter != nil { if c.UpCounter != nil {
c.UpCounter.Add(float64(len(dfMsg.Data))) c.UpCounter.Add(float64(len(dfMsg.Data)))
} }
@ -197,14 +204,15 @@ func (c *serverClient) handleMessage(msg []byte) {
func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
addrStr := net.JoinHostPort(host, strconv.Itoa(int(port))) addrStr := net.JoinHostPort(host, strconv.Itoa(int(port)))
action, arg := acl.ActionDirect, "" action, arg := acl.ActionDirect, ""
var isDomain bool
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var err error var err error
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host)
} else { } else {
ipAddr, err = c.Transport.ResolveIPAddr(host) ipAddr, isDomain, err = c.Transport.ResolveIPAddr(host)
} }
if err != nil { if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,
Message: "host resolution failure", Message: "host resolution failure",
@ -217,11 +225,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
var conn net.Conn // Connection to be piped var conn net.Conn // Connection to be piped
switch action { switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = c.Transport.DialTCP(&net.TCPAddr{ addrEx := &transport.AddrEx{
IP: ipAddr.IP, IPAddr: ipAddr,
Port: int(port), Port: int(port),
Zone: ipAddr.Zone, }
}) if isDomain {
addrEx.Domain = host
}
conn, err = c.Transport.DialTCP(addrEx)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,
@ -237,8 +248,8 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
}) })
return return
case acl.ActionHijack: case acl.ActionHijack:
hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) hijackIPAddr, isDomain, err := c.Transport.ResolveIPAddr(arg)
if err != nil { if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,
Message: err.Error(), Message: err.Error(),
@ -246,11 +257,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) {
c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err)
return return
} }
conn, err = c.Transport.DialTCP(&net.TCPAddr{ addrEx := &transport.AddrEx{
IP: hijackIPAddr.IP, IPAddr: hijackIPAddr,
Port: int(port), Port: int(port),
Zone: hijackIPAddr.Zone, }
}) if isDomain {
addrEx.Domain = arg
}
conn, err = c.Transport.DialTCP(addrEx)
if err != nil { if err != nil {
_ = struc.Pack(stream, &serverResponse{ _ = struc.Pack(stream, &serverResponse{
OK: false, OK: false,

View File

@ -34,7 +34,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport *transport.ClientTransp
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if aclEngine != nil { if aclEngine != nil {
action, arg, ipAddr, resErr = aclEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
newDialFunc(addr, action, arg) newDialFunc(addr, action, arg)

View File

@ -171,7 +171,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error {
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if s.ACLEngine != nil { if s.ACLEngine != nil {
action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg) s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg)
@ -377,7 +377,7 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn,
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var resErr error var resErr error
if s.ACLEngine != nil && localRelayConn != nil { if s.ACLEngine != nil && localRelayConn != nil {
action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host)
// Doesn't always matter if the resolution fails, as we may send it through HyClient // Doesn't always matter if the resolution fails, as we may send it through HyClient
} }
// Handle according to the action // Handle according to the action

View File

@ -8,7 +8,9 @@ import (
"github.com/tobyxdd/hysteria/pkg/conns/udp" "github.com/tobyxdd/hysteria/pkg/conns/udp"
"github.com/tobyxdd/hysteria/pkg/conns/wechat" "github.com/tobyxdd/hysteria/pkg/conns/wechat"
"github.com/tobyxdd/hysteria/pkg/obfs" "github.com/tobyxdd/hysteria/pkg/obfs"
"github.com/tobyxdd/hysteria/pkg/utils"
"net" "net"
"strconv"
"time" "time"
) )
@ -20,12 +22,51 @@ type ServerTransport struct {
PrefExclusive bool PrefExclusive bool
} }
// AddrEx is like net.TCPAddr or net.UDPAddr, but with additional domain information for SOCKS5.
// At least one of Domain and IPAddr must be non-empty.
type AddrEx struct {
Domain string
IPAddr *net.IPAddr
Port int
}
func (a *AddrEx) String() string {
if a == nil {
return "<nil>"
}
var ip string
if a.IPAddr != nil {
ip = a.IPAddr.String()
}
return net.JoinHostPort(ip, strconv.Itoa(a.Port))
}
type PUDPConn interface { type PUDPConn interface {
ReadFromUDP([]byte) (int, *net.UDPAddr, error) ReadFromUDP([]byte) (int, *net.UDPAddr, error)
WriteToUDP([]byte, *net.UDPAddr) (int, error) WriteToUDP([]byte, *AddrEx) (int, error)
Close() error Close() error
} }
type udpConnPUDPConn struct {
Conn *net.UDPConn
}
func (c *udpConnPUDPConn) ReadFromUDP(bytes []byte) (int, *net.UDPAddr, error) {
return c.Conn.ReadFromUDP(bytes)
}
func (c *udpConnPUDPConn) WriteToUDP(bytes []byte, ex *AddrEx) (int, error) {
return c.Conn.WriteToUDP(bytes, &net.UDPAddr{
IP: ex.IPAddr.IP,
Port: ex.Port,
Zone: ex.IPAddr.Zone,
})
}
func (c *udpConnPUDPConn) Close() error {
return c.Conn.Close()
}
var DefaultServerTransport = &ServerTransport{ var DefaultServerTransport = &ServerTransport{
Dialer: &net.Dialer{ Dialer: &net.Dialer{
Timeout: 8 * time.Second, Timeout: 8 * time.Second,
@ -80,8 +121,8 @@ func (st *ServerTransport) quicPacketConn(proto string, laddr string, obfs obfs.
} }
} }
func (ct *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator) (quic.Listener, error) { func (st *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator) (quic.Listener, error) {
pktConn, err := ct.quicPacketConn(proto, listen, obfs) pktConn, err := st.quicPacketConn(proto, listen, obfs)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -93,19 +134,25 @@ func (ct *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tl
return l, nil return l, nil
} }
func (ct *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, error) { func (st *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, bool, error) {
if ct.PrefEnabled { ip, zone := utils.ParseIPZone(address)
return resolveIPAddrWithPreference(address, ct.PrefIPv6, ct.PrefExclusive) if ip != nil {
return &net.IPAddr{IP: ip, Zone: zone}, false, nil
}
if st.PrefEnabled {
ipAddr, err := resolveIPAddrWithPreference(address, st.PrefIPv6, st.PrefExclusive)
return ipAddr, true, err
} else { } else {
return net.ResolveIPAddr("ip", address) ipAddr, err := net.ResolveIPAddr("ip", address)
return ipAddr, true, err
} }
} }
func (ct *ServerTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { func (st *ServerTransport) DialTCP(raddr *AddrEx) (*net.TCPConn, error) {
if ct.SOCKS5Client != nil { if st.SOCKS5Client != nil {
return ct.SOCKS5Client.DialTCP(raddr) return st.SOCKS5Client.DialTCP(raddr)
} else { } else {
conn, err := ct.Dialer.Dial("tcp", raddr.String()) conn, err := st.Dialer.Dial("tcp", raddr.String())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -113,10 +160,20 @@ func (ct *ServerTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) {
} }
} }
func (ct *ServerTransport) ListenUDP() (PUDPConn, error) { func (st *ServerTransport) ListenUDP() (PUDPConn, error) {
if ct.SOCKS5Client != nil { if st.SOCKS5Client != nil {
return ct.SOCKS5Client.ListenUDP() return st.SOCKS5Client.ListenUDP()
} else { } else {
return net.ListenUDP("udp", nil) conn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
return &udpConnPUDPConn{
Conn: conn,
}, nil
} }
} }
func (st *ServerTransport) SOCKS5Enabled() bool {
return st.SOCKS5Client != nil
}

View File

@ -73,7 +73,7 @@ func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Re
return reply, nil return reply, nil
} }
func (c *SOCKS5Client) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) {
conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -87,7 +87,7 @@ func (c *SOCKS5Client) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
atyp, addr, port, err := addrToSOCKS5Addr(raddr) atyp, addr, port, err := addrExToSOCKS5Addr(raddr)
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
@ -191,8 +191,8 @@ func (c *socks5UDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) {
return n, addr, nil return n, addr, nil
} }
func (c *socks5UDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { func (c *socks5UDPConn) WriteToUDP(b []byte, addr *AddrEx) (int, error) {
atyp, dstAddr, dstPort, err := addrToSOCKS5Addr(addr) atyp, dstAddr, dstPort, err := addrExToSOCKS5Addr(addr)
if err != nil { if err != nil {
return 0, err return 0, err
} }
@ -247,30 +247,23 @@ func socks5AddrToUDPAddr(atyp byte, addr []byte, port []byte) (*net.UDPAddr, err
} }
} }
func addrToSOCKS5Addr(addr net.Addr) (byte, []byte, []byte, error) { func addrExToSOCKS5Addr(addr *AddrEx) (byte, []byte, []byte, error) {
var addrIP net.IP sport := make([]byte, 2)
var addrPort int binary.BigEndian.PutUint16(sport, uint16(addr.Port))
if tcpAddr, ok := addr.(*net.TCPAddr); ok { if len(addr.Domain) > 0 {
addrIP = tcpAddr.IP return socks5.ATYPDomain, []byte(addr.Domain), sport, nil
addrPort = tcpAddr.Port
} else if udpAddr, ok := addr.(*net.UDPAddr); ok {
addrIP = udpAddr.IP
addrPort = udpAddr.Port
} else { } else {
return 0, nil, nil, errors.New("unsupported address type")
}
var atyp byte var atyp byte
var saddr, sport []byte var saddr []byte
if ip4 := addrIP.To4(); ip4 != nil { if ip4 := addr.IPAddr.IP.To4(); ip4 != nil {
atyp = socks5.ATYPIPv4 atyp = socks5.ATYPIPv4
saddr = ip4 saddr = ip4
} else if ip6 := addrIP.To16(); ip6 != nil { } else if ip6 := addr.IPAddr.IP.To16(); ip6 != nil {
atyp = socks5.ATYPIPv6 atyp = socks5.ATYPIPv6
saddr = ip6 saddr = ip6
} else { } else {
return 0, nil, nil, errors.New("unsupported address type") return 0, nil, nil, errors.New("unsupported address type")
} }
sport = make([]byte, 2)
binary.BigEndian.PutUint16(sport, uint16(addrPort))
return atyp, saddr, sport, nil return atyp, saddr, sport, nil
}
} }

View File

@ -16,3 +16,27 @@ func SplitHostPort(hostport string) (string, uint16, error) {
} }
return host, uint16(portUint), err return host, uint16(portUint), err
} }
func ParseIPZone(s string) (net.IP, string) {
s, zone := splitHostZone(s)
return net.ParseIP(s), zone
}
func splitHostZone(s string) (host, zone string) {
if i := last(s, '%'); i > 0 {
host, zone = s[:i], s[i+1:]
} else {
host = s
}
return
}
func last(s string, b byte) int {
i := len(s)
for i--; i >= 0; i-- {
if s[i] == b {
break
}
}
return i
}