From fba6cf7a1c85020de3211a5a570012e5e76cc726 Mon Sep 17 00:00:00 2001 From: Toby Date: Thu, 14 Apr 2022 00:11:44 -0700 Subject: [PATCH] feat: server SOCKS5 outbound domain passthrough --- cmd/server.go | 5 ++- pkg/acl/engine.go | 18 ++++---- pkg/acl/ip.go | 27 ------------ pkg/core/server_client.go | 74 +++++++++++++++++++-------------- pkg/http/server.go | 2 +- pkg/socks5/server.go | 4 +- pkg/transport/server.go | 87 ++++++++++++++++++++++++++++++++------- pkg/transport/socks5.go | 49 ++++++++++------------ pkg/utils/misc.go | 24 +++++++++++ 9 files changed, 178 insertions(+), 112 deletions(-) delete mode 100644 pkg/acl/ip.go diff --git a/cmd/server.go b/cmd/server.go index 2d36772..a738e1d 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -166,7 +166,10 @@ func server(config *serverConfig) { // ACL var aclEngine *acl.Engine if len(config.ACL) > 0 { - aclEngine, err = acl.LoadFromFile(config.ACL, transport.DefaultServerTransport.ResolveIPAddr, + aclEngine, err = acl.LoadFromFile(config.ACL, func(addr string) (*net.IPAddr, error) { + ipAddr, _, err := transport.DefaultServerTransport.ResolveIPAddr(addr) + return ipAddr, err + }, func() (*geoip2.Reader, error) { if len(config.MMDB) > 0 { return loadMMDBReader(config.MMDB) diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index 151df18..c74d8b4 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -4,6 +4,7 @@ import ( "bufio" lru "github.com/hashicorp/golang-lru" "github.com/oschwald/geoip2-golang" + "github.com/tobyxdd/hysteria/pkg/utils" "net" "os" "strings" @@ -64,30 +65,31 @@ func LoadFromFile(filename string, resolveIPAddr func(string) (*net.IPAddr, erro }, nil } -func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, error) { - ip, zone := parseIPZone(host) +// action, arg, isDomain, resolvedIP, error +func (e *Engine) ResolveAndMatch(host string) (Action, string, bool, *net.IPAddr, error) { + ip, zone := utils.ParseIPZone(host) if ip == nil { // Domain ipAddr, err := e.ResolveIPAddr(host) if v, ok := e.Cache.Get(host); ok { // Cache hit ce := v.(cacheEntry) - return ce.Action, ce.Arg, ipAddr, err + return ce.Action, ce.Arg, true, ipAddr, err } for _, entry := range e.Entries { if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP, e.GeoIPReader)) { e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) - return entry.Action, entry.ActionArg, ipAddr, err + return entry.Action, entry.ActionArg, true, ipAddr, err } } e.Cache.Add(host, cacheEntry{e.DefaultAction, ""}) - return e.DefaultAction, "", ipAddr, err + return e.DefaultAction, "", true, ipAddr, err } else { // IP if v, ok := e.Cache.Get(ip.String()); ok { // Cache hit ce := v.(cacheEntry) - return ce.Action, ce.Arg, &net.IPAddr{ + return ce.Action, ce.Arg, false, &net.IPAddr{ IP: ip, Zone: zone, }, nil @@ -95,14 +97,14 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro for _, entry := range e.Entries { if entry.MatchIP(ip, e.GeoIPReader) { e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) - return entry.Action, entry.ActionArg, &net.IPAddr{ + return entry.Action, entry.ActionArg, false, &net.IPAddr{ IP: ip, Zone: zone, }, nil } } e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""}) - return e.DefaultAction, "", &net.IPAddr{ + return e.DefaultAction, "", false, &net.IPAddr{ IP: ip, Zone: zone, }, nil diff --git a/pkg/acl/ip.go b/pkg/acl/ip.go deleted file mode 100644 index fe31923..0000000 --- a/pkg/acl/ip.go +++ /dev/null @@ -1,27 +0,0 @@ -package acl - -import "net" - -func parseIPZone(s string) (net.IP, string) { - s, zone := splitHostZone(s) - return net.ParseIP(s), zone -} - -func splitHostZone(s string) (host, zone string) { - if i := last(s, '%'); i > 0 { - host, zone = s[:i], s[i+1:] - } else { - host = s - } - return -} - -func last(s string, b byte) int { - i := len(s) - for i--; i >= 0; i-- { - if s[i] == b { - break - } - } - return i -} diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 73a1ba3..98da1b1 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -154,36 +154,43 @@ func (c *serverClient) handleMessage(msg []byte) { if ok { // Session found, send the message action, arg := acl.ActionDirect, "" + var isDomain bool var ipAddr *net.IPAddr var err error if c.ACLEngine != nil { - action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) + action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) } else { - ipAddr, err = c.Transport.ResolveIPAddr(dfMsg.Host) + ipAddr, isDomain, err = c.Transport.ResolveIPAddr(dfMsg.Host) } - if err != nil { + if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound return } switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ - IP: ipAddr.IP, - Port: int(dfMsg.Port), - Zone: ipAddr.Zone, - }) + addrEx := &transport.AddrEx{ + IPAddr: ipAddr, + Port: int(dfMsg.Port), + } + if isDomain { + addrEx.Domain = dfMsg.Host + } + _, _ = conn.WriteToUDP(dfMsg.Data, addrEx) if c.UpCounter != nil { c.UpCounter.Add(float64(len(dfMsg.Data))) } case acl.ActionBlock: // Do nothing case acl.ActionHijack: - hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) - if err == nil { - _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ - IP: hijackIPAddr.IP, - Port: int(dfMsg.Port), - Zone: hijackIPAddr.Zone, - }) + hijackIPAddr, isDomain, err := c.Transport.ResolveIPAddr(arg) + if err == nil || (isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound + addrEx := &transport.AddrEx{ + IPAddr: hijackIPAddr, + Port: int(dfMsg.Port), + } + if isDomain { + addrEx.Domain = arg + } + _, _ = conn.WriteToUDP(dfMsg.Data, addrEx) if c.UpCounter != nil { c.UpCounter.Add(float64(len(dfMsg.Data))) } @@ -197,14 +204,15 @@ func (c *serverClient) handleMessage(msg []byte) { func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { addrStr := net.JoinHostPort(host, strconv.Itoa(int(port))) action, arg := acl.ActionDirect, "" + var isDomain bool var ipAddr *net.IPAddr var err error if c.ACLEngine != nil { - action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) + action, arg, isDomain, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) } else { - ipAddr, err = c.Transport.ResolveIPAddr(host) + ipAddr, isDomain, err = c.Transport.ResolveIPAddr(host) } - if err != nil { + if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound _ = struc.Pack(stream, &serverResponse{ OK: false, Message: "host resolution failure", @@ -217,11 +225,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { var conn net.Conn // Connection to be piped switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - conn, err = c.Transport.DialTCP(&net.TCPAddr{ - IP: ipAddr.IP, - Port: int(port), - Zone: ipAddr.Zone, - }) + addrEx := &transport.AddrEx{ + IPAddr: ipAddr, + Port: int(port), + } + if isDomain { + addrEx.Domain = host + } + conn, err = c.Transport.DialTCP(addrEx) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, @@ -237,8 +248,8 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { }) return case acl.ActionHijack: - hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) - if err != nil { + hijackIPAddr, isDomain, err := c.Transport.ResolveIPAddr(arg) + if err != nil && !(isDomain && c.Transport.SOCKS5Enabled()) { // Special case for domain requests + SOCKS5 outbound _ = struc.Pack(stream, &serverResponse{ OK: false, Message: err.Error(), @@ -246,11 +257,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) return } - conn, err = c.Transport.DialTCP(&net.TCPAddr{ - IP: hijackIPAddr.IP, - Port: int(port), - Zone: hijackIPAddr.Zone, - }) + addrEx := &transport.AddrEx{ + IPAddr: hijackIPAddr, + Port: int(port), + } + if isDomain { + addrEx.Domain = arg + } + conn, err = c.Transport.DialTCP(addrEx) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, diff --git a/pkg/http/server.go b/pkg/http/server.go index b4d899e..650f20a 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -34,7 +34,7 @@ func NewProxyHTTPServer(hyClient *core.Client, transport *transport.ClientTransp var ipAddr *net.IPAddr var resErr error if aclEngine != nil { - action, arg, ipAddr, resErr = aclEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = aclEngine.ResolveAndMatch(host) // Doesn't always matter if the resolution fails, as we may send it through HyClient } newDialFunc(addr, action, arg) diff --git a/pkg/socks5/server.go b/pkg/socks5/server.go index 795032d..c679455 100644 --- a/pkg/socks5/server.go +++ b/pkg/socks5/server.go @@ -171,7 +171,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { var ipAddr *net.IPAddr var resErr error if s.ACLEngine != nil { - action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) // Doesn't always matter if the resolution fails, as we may send it through HyClient } s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg) @@ -377,7 +377,7 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, var ipAddr *net.IPAddr var resErr error if s.ACLEngine != nil && localRelayConn != nil { - action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + action, arg, _, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) // Doesn't always matter if the resolution fails, as we may send it through HyClient } // Handle according to the action diff --git a/pkg/transport/server.go b/pkg/transport/server.go index bf03e0f..c771dfc 100644 --- a/pkg/transport/server.go +++ b/pkg/transport/server.go @@ -8,7 +8,9 @@ import ( "github.com/tobyxdd/hysteria/pkg/conns/udp" "github.com/tobyxdd/hysteria/pkg/conns/wechat" "github.com/tobyxdd/hysteria/pkg/obfs" + "github.com/tobyxdd/hysteria/pkg/utils" "net" + "strconv" "time" ) @@ -20,12 +22,51 @@ type ServerTransport struct { PrefExclusive bool } +// AddrEx is like net.TCPAddr or net.UDPAddr, but with additional domain information for SOCKS5. +// At least one of Domain and IPAddr must be non-empty. +type AddrEx struct { + Domain string + IPAddr *net.IPAddr + Port int +} + +func (a *AddrEx) String() string { + if a == nil { + return "" + } + var ip string + if a.IPAddr != nil { + ip = a.IPAddr.String() + } + return net.JoinHostPort(ip, strconv.Itoa(a.Port)) +} + type PUDPConn interface { ReadFromUDP([]byte) (int, *net.UDPAddr, error) - WriteToUDP([]byte, *net.UDPAddr) (int, error) + WriteToUDP([]byte, *AddrEx) (int, error) Close() error } +type udpConnPUDPConn struct { + Conn *net.UDPConn +} + +func (c *udpConnPUDPConn) ReadFromUDP(bytes []byte) (int, *net.UDPAddr, error) { + return c.Conn.ReadFromUDP(bytes) +} + +func (c *udpConnPUDPConn) WriteToUDP(bytes []byte, ex *AddrEx) (int, error) { + return c.Conn.WriteToUDP(bytes, &net.UDPAddr{ + IP: ex.IPAddr.IP, + Port: ex.Port, + Zone: ex.IPAddr.Zone, + }) +} + +func (c *udpConnPUDPConn) Close() error { + return c.Conn.Close() +} + var DefaultServerTransport = &ServerTransport{ Dialer: &net.Dialer{ Timeout: 8 * time.Second, @@ -80,8 +121,8 @@ func (st *ServerTransport) quicPacketConn(proto string, laddr string, obfs obfs. } } -func (ct *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator) (quic.Listener, error) { - pktConn, err := ct.quicPacketConn(proto, listen, obfs) +func (st *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tls.Config, quicConfig *quic.Config, obfs obfs.Obfuscator) (quic.Listener, error) { + pktConn, err := st.quicPacketConn(proto, listen, obfs) if err != nil { return nil, err } @@ -93,19 +134,25 @@ func (ct *ServerTransport) QUICListen(proto string, listen string, tlsConfig *tl return l, nil } -func (ct *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, error) { - if ct.PrefEnabled { - return resolveIPAddrWithPreference(address, ct.PrefIPv6, ct.PrefExclusive) +func (st *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, bool, error) { + ip, zone := utils.ParseIPZone(address) + if ip != nil { + return &net.IPAddr{IP: ip, Zone: zone}, false, nil + } + if st.PrefEnabled { + ipAddr, err := resolveIPAddrWithPreference(address, st.PrefIPv6, st.PrefExclusive) + return ipAddr, true, err } else { - return net.ResolveIPAddr("ip", address) + ipAddr, err := net.ResolveIPAddr("ip", address) + return ipAddr, true, err } } -func (ct *ServerTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { - if ct.SOCKS5Client != nil { - return ct.SOCKS5Client.DialTCP(raddr) +func (st *ServerTransport) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { + if st.SOCKS5Client != nil { + return st.SOCKS5Client.DialTCP(raddr) } else { - conn, err := ct.Dialer.Dial("tcp", raddr.String()) + conn, err := st.Dialer.Dial("tcp", raddr.String()) if err != nil { return nil, err } @@ -113,10 +160,20 @@ func (ct *ServerTransport) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { } } -func (ct *ServerTransport) ListenUDP() (PUDPConn, error) { - if ct.SOCKS5Client != nil { - return ct.SOCKS5Client.ListenUDP() +func (st *ServerTransport) ListenUDP() (PUDPConn, error) { + if st.SOCKS5Client != nil { + return st.SOCKS5Client.ListenUDP() } else { - return net.ListenUDP("udp", nil) + conn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + return &udpConnPUDPConn{ + Conn: conn, + }, nil } } + +func (st *ServerTransport) SOCKS5Enabled() bool { + return st.SOCKS5Client != nil +} diff --git a/pkg/transport/socks5.go b/pkg/transport/socks5.go index b7261ea..50b4388 100644 --- a/pkg/transport/socks5.go +++ b/pkg/transport/socks5.go @@ -73,7 +73,7 @@ func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Re return reply, nil } -func (c *SOCKS5Client) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { +func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) if err != nil { return nil, err @@ -87,7 +87,7 @@ func (c *SOCKS5Client) DialTCP(raddr *net.TCPAddr) (*net.TCPConn, error) { _ = conn.Close() return nil, err } - atyp, addr, port, err := addrToSOCKS5Addr(raddr) + atyp, addr, port, err := addrExToSOCKS5Addr(raddr) if err != nil { _ = conn.Close() return nil, err @@ -191,8 +191,8 @@ func (c *socks5UDPConn) ReadFromUDP(b []byte) (int, *net.UDPAddr, error) { return n, addr, nil } -func (c *socks5UDPConn) WriteToUDP(b []byte, addr *net.UDPAddr) (int, error) { - atyp, dstAddr, dstPort, err := addrToSOCKS5Addr(addr) +func (c *socks5UDPConn) WriteToUDP(b []byte, addr *AddrEx) (int, error) { + atyp, dstAddr, dstPort, err := addrExToSOCKS5Addr(addr) if err != nil { return 0, err } @@ -247,30 +247,23 @@ func socks5AddrToUDPAddr(atyp byte, addr []byte, port []byte) (*net.UDPAddr, err } } -func addrToSOCKS5Addr(addr net.Addr) (byte, []byte, []byte, error) { - var addrIP net.IP - var addrPort int - if tcpAddr, ok := addr.(*net.TCPAddr); ok { - addrIP = tcpAddr.IP - addrPort = tcpAddr.Port - } else if udpAddr, ok := addr.(*net.UDPAddr); ok { - addrIP = udpAddr.IP - addrPort = udpAddr.Port +func addrExToSOCKS5Addr(addr *AddrEx) (byte, []byte, []byte, error) { + sport := make([]byte, 2) + binary.BigEndian.PutUint16(sport, uint16(addr.Port)) + if len(addr.Domain) > 0 { + return socks5.ATYPDomain, []byte(addr.Domain), sport, nil } else { - return 0, nil, nil, errors.New("unsupported address type") + var atyp byte + var saddr []byte + if ip4 := addr.IPAddr.IP.To4(); ip4 != nil { + atyp = socks5.ATYPIPv4 + saddr = ip4 + } else if ip6 := addr.IPAddr.IP.To16(); ip6 != nil { + atyp = socks5.ATYPIPv6 + saddr = ip6 + } else { + return 0, nil, nil, errors.New("unsupported address type") + } + return atyp, saddr, sport, nil } - var atyp byte - var saddr, sport []byte - if ip4 := addrIP.To4(); ip4 != nil { - atyp = socks5.ATYPIPv4 - saddr = ip4 - } else if ip6 := addrIP.To16(); ip6 != nil { - atyp = socks5.ATYPIPv6 - saddr = ip6 - } else { - return 0, nil, nil, errors.New("unsupported address type") - } - sport = make([]byte, 2) - binary.BigEndian.PutUint16(sport, uint16(addrPort)) - return atyp, saddr, sport, nil } diff --git a/pkg/utils/misc.go b/pkg/utils/misc.go index 13db0e7..29c7cf0 100644 --- a/pkg/utils/misc.go +++ b/pkg/utils/misc.go @@ -16,3 +16,27 @@ func SplitHostPort(hostport string) (string, uint16, error) { } return host, uint16(portUint), err } + +func ParseIPZone(s string) (net.IP, string) { + s, zone := splitHostZone(s) + return net.ParseIP(s), zone +} + +func splitHostZone(s string) (host, zone string) { + if i := last(s, '%'); i > 0 { + host, zone = s[:i], s[i+1:] + } else { + host = s + } + return +} + +func last(s string, b byte) int { + i := len(s) + for i--; i >= 0; i-- { + if s[i] == b { + break + } + } + return i +}