diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index 47e3290..d3cef8c 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -52,39 +52,47 @@ func LoadFromFile(filename string) (*Engine, error) { }, nil } -func (e *Engine) Lookup(domain string, ip net.IP) (Action, string) { - if len(domain) > 0 { +func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, error) { + ip, zone := parseIPZone(host) + if ip == nil { // Domain - if v, ok := e.Cache.Get(domain); ok { + ipAddr, err := net.ResolveIPAddr("ip", host) + if v, ok := e.Cache.Get(host); ok { // Cache hit ce := v.(cacheEntry) - return ce.Action, ce.Arg + return ce.Action, ce.Arg, ipAddr, err } - ips, _ := net.LookupIP(domain) for _, entry := range e.Entries { - if entry.MatchDomain(domain) || (len(ips) > 0 && entry.MatchIPs(ips)) { - e.Cache.Add(domain, cacheEntry{entry.Action, entry.ActionArg}) - return entry.Action, entry.ActionArg + if entry.MatchDomain(host) || (ipAddr != nil && entry.MatchIP(ipAddr.IP)) { + e.Cache.Add(host, cacheEntry{entry.Action, entry.ActionArg}) + return entry.Action, entry.ActionArg, ipAddr, err } } - e.Cache.Add(domain, cacheEntry{e.DefaultAction, ""}) - return e.DefaultAction, "" - } else if ip != nil { + e.Cache.Add(host, cacheEntry{e.DefaultAction, ""}) + return e.DefaultAction, "", ipAddr, err + } else { // IP if v, ok := e.Cache.Get(ip.String()); ok { // Cache hit ce := v.(cacheEntry) - return ce.Action, ce.Arg + return ce.Action, ce.Arg, &net.IPAddr{ + IP: ip, + Zone: zone, + }, nil } for _, entry := range e.Entries { if entry.MatchIP(ip) { e.Cache.Add(ip.String(), cacheEntry{entry.Action, entry.ActionArg}) - return entry.Action, entry.ActionArg + return entry.Action, entry.ActionArg, &net.IPAddr{ + IP: ip, + Zone: zone, + }, nil } } e.Cache.Add(ip.String(), cacheEntry{e.DefaultAction, ""}) - return e.DefaultAction, "" - } else { - return e.DefaultAction, "" + return e.DefaultAction, "", &net.IPAddr{ + IP: ip, + Zone: zone, + }, nil } } diff --git a/pkg/acl/engine_test.go b/pkg/acl/engine_test.go index b382416..f7b8587 100644 --- a/pkg/acl/engine_test.go +++ b/pkg/acl/engine_test.go @@ -6,7 +6,7 @@ import ( "testing" ) -func TestEngine_Lookup(t *testing.T) { +func TestEngine_ResolveAndMatch(t *testing.T) { cache, _ := lru.NewARC(4) e := &Engine{ DefaultAction: ActionDirect, @@ -49,61 +49,65 @@ func TestEngine_Lookup(t *testing.T) { }, Cache: cache, } - type args struct { - domain string - ip net.IP - } tests := []struct { - name string - args args - want Action - want1 string + name string + addr string + want Action + want1 string + wantErr bool }{ { name: "domain direct", - args: args{"google.com", nil}, + addr: "google.com", want: ActionProxy, want1: "", }, { - name: "domain suffix 1", - args: args{"evil.corp", nil}, - want: ActionHijack, - want1: "good.org", + name: "domain suffix 1", + addr: "evil.corp", + want: ActionHijack, + want1: "good.org", + wantErr: true, }, { - name: "domain suffix 2", - args: args{"notevil.corp", nil}, - want: ActionBlock, - want1: "", + name: "domain suffix 2", + addr: "notevil.corp", + want: ActionBlock, + want1: "", + wantErr: true, }, { - name: "domain suffix 3", - args: args{"im.real.evil.corp", nil}, - want: ActionHijack, - want1: "good.org", + name: "domain suffix 3", + addr: "im.real.evil.corp", + want: ActionHijack, + want1: "good.org", + wantErr: true, }, { name: "ip match", - args: args{"", net.ParseIP("10.2.3.4")}, + addr: "10.2.3.4", want: ActionProxy, want1: "", }, { name: "ip mismatch", - args: args{"", net.ParseIP("100.5.6.0")}, + addr: "100.5.6.0", want: ActionBlock, want1: "", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, got1 := e.Lookup(tt.args.domain, tt.args.ip) + got, got1, _, err := e.ResolveAndMatch(tt.addr) + if (err != nil) != tt.wantErr { + t.Errorf("ResolveAndMatch() error = %v, wantErr %v", err, tt.wantErr) + return + } if got != tt.want { - t.Errorf("Lookup() got = %v, want %v", got, tt.want) + t.Errorf("ResolveAndMatch() got = %v, want %v", got, tt.want) } if got1 != tt.want1 { - t.Errorf("Lookup() got1 = %v, want %v", got1, tt.want1) + t.Errorf("ResolveAndMatch() got1 = %v, want %v", got1, tt.want1) } }) } diff --git a/pkg/acl/entry.go b/pkg/acl/entry.go index 18f7275..d6fb731 100644 --- a/pkg/acl/entry.go +++ b/pkg/acl/entry.go @@ -50,20 +50,6 @@ func (e Entry) MatchIP(ip net.IP) bool { return false } -func (e Entry) MatchIPs(ips []net.IP) bool { - if e.All { - return true - } - if e.Net != nil && len(ips) > 0 { - for _, ip := range ips { - if e.Net.Contains(ip) { - return true - } - } - } - return false -} - // Format: action cond_type cond arg // Examples: // proxy domain-suffix google.com diff --git a/pkg/core/client.go b/pkg/core/client.go index b222dda..604583d 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -10,6 +10,7 @@ import ( "github.com/lucas-clemente/quic-go/congestion" "github.com/lunixbochs/struc" "net" + "strconv" "sync" "time" ) @@ -187,14 +188,19 @@ func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) { } func (c *Client) DialTCP(addr string) (net.Conn, error) { + host, port, err := splitHostPort(addr) + if err != nil { + return nil, err + } session, stream, err := c.openStreamWithReconnect() if err != nil { return nil, err } // Send request err = struc.Pack(stream, &clientRequest{ - UDP: false, - Address: addr, + UDP: false, + Host: host, + Port: port, }) if err != nil { _ = stream.Close() @@ -349,14 +355,19 @@ func (c *quicPktConn) ReadFrom() ([]byte, string, error) { // Closed return nil, "", ErrClosed } - return msg.Data, msg.Address, nil + return msg.Data, net.JoinHostPort(msg.Host, strconv.Itoa(int(msg.Port))), nil } func (c *quicPktConn) WriteTo(p []byte, addr string) error { + host, port, err := splitHostPort(addr) + if err != nil { + return err + } var msgBuf bytes.Buffer _ = struc.Pack(&msgBuf, &udpMessage{ SessionID: c.UDPSessionID, - Address: addr, + Host: host, + Port: port, Data: p, }) return c.Session.SendMessage(msgBuf.Bytes()) @@ -366,3 +377,15 @@ func (c *quicPktConn) Close() error { c.CloseFunc() return c.Stream.Close() } + +func splitHostPort(hostport string) (string, uint16, error) { + host, port, err := net.SplitHostPort(hostport) + if err != nil { + return "", 0, err + } + portUint, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return "", 0, err + } + return host, uint16(portUint), err +} diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 0a71ba3..d7ccc6b 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -32,9 +32,10 @@ type serverHello struct { } type clientRequest struct { - UDP bool - AddressLen uint16 `struc:"sizeof=Address"` - Address string + UDP bool + HostLen uint16 `struc:"sizeof=Host"` + Host string + Port uint16 } type serverResponse struct { @@ -45,9 +46,10 @@ type serverResponse struct { } type udpMessage struct { - SessionID uint32 - AddressLen uint16 `struc:"sizeof=Address"` - Address string - DataLen uint16 `struc:"sizeof=Data"` - Data []byte + SessionID uint32 + HostLen uint16 `struc:"sizeof=Host"` + Host string + Port uint16 + DataLen uint16 `struc:"sizeof=Data"` + Data []byte } diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 5dd0497..7e991f9 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -10,6 +10,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/utils" "net" + "strconv" "sync" ) @@ -88,7 +89,7 @@ func (c *serverClient) handleStream(stream quic.Stream) { } if !req.UDP { // TCP connection - c.handleTCP(stream, req.Address) + c.handleTCP(stream, req.Host, req.Port) } else if !c.DisableUDP { // UDP connection c.handleUDP(stream) @@ -112,32 +113,30 @@ func (c *serverClient) handleMessage(msg []byte) { c.udpSessionMutex.RUnlock() if ok { // Session found, send the message - host, port, err := net.SplitHostPort(udpMsg.Address) + action, arg := acl.ActionDirect, "" + var ipAddr *net.IPAddr + if c.ACLEngine != nil { + action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) + } else { + ipAddr, err = net.ResolveIPAddr("ip", udpMsg.Host) + } if err != nil { return } - action, arg := acl.ActionDirect, "" - if c.ACLEngine != nil { - ip := net.ParseIP(host) - if ip != nil { - // IP request, clear host for ACL engine - host = "" - } - action, arg = c.ACLEngine.Lookup(host, ip) - } switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - addr, err := net.ResolveUDPAddr("udp", udpMsg.Address) - if err == nil { - _, _ = conn.WriteToUDP(udpMsg.Data, addr) - if c.UpCounter != nil { - c.UpCounter.Add(float64(len(udpMsg.Data))) - } + _, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ + IP: ipAddr.IP, + Port: int(udpMsg.Port), + Zone: ipAddr.Zone, + }) + if c.UpCounter != nil { + c.UpCounter.Add(float64(len(udpMsg.Data))) } case acl.ActionBlock: // Do nothing case acl.ActionHijack: - hijackAddr := net.JoinHostPort(arg, port) + hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port))) addr, err := net.ResolveUDPAddr("udp", hijackAddr) if err == nil { _, _ = conn.WriteToUDP(udpMsg.Data, addr) @@ -151,37 +150,40 @@ func (c *serverClient) handleMessage(msg []byte) { } } -func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) { - host, port, err := net.SplitHostPort(reqAddr) +func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { + addrStr := net.JoinHostPort(host, strconv.Itoa(int(port))) + action, arg := acl.ActionDirect, "" + var ipAddr *net.IPAddr + var err error + if c.ACLEngine != nil { + action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) + } else { + ipAddr, err = net.ResolveIPAddr("ip", host) + } if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, - Message: "invalid address", + Message: "host resolution failure", }) - c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) return } - action, arg := acl.ActionDirect, "" - if c.ACLEngine != nil { - ip := net.ParseIP(host) - if ip != nil { - // IP request, clear host for ACL engine - host = "" - } - action, arg = c.ACLEngine.Lookup(host, ip) - } - c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg) + c.CTCPRequestFunc(c.ClientAddr, c.Auth, addrStr, action, arg) var conn net.Conn // Connection to be piped switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout) + conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: ipAddr.IP, + Port: int(port), + Zone: ipAddr.Zone, + }) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, Message: err.Error(), }) - c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) return } case acl.ActionBlock: @@ -191,14 +193,14 @@ func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) { }) return case acl.ActionHijack: - hijackAddr := net.JoinHostPort(arg, port) - conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout) + hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) + conn, err = net.Dial("tcp", hijackAddr) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, Message: err.Error(), }) - c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) return } default: @@ -227,7 +229,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) { } else { err = utils.Pipe2Way(stream, conn, nil) } - c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, addrStr, err) } func (c *serverClient) handleUDP(stream quic.Stream) { @@ -268,7 +270,8 @@ func (c *serverClient) handleUDP(stream quic.Stream) { var msgBuf bytes.Buffer _ = struc.Pack(&msgBuf, &udpMessage{ SessionID: id, - Address: rAddr.String(), + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), Data: buf[:n], }) _ = c.CS.SendMessage(msgBuf.Bytes()) diff --git a/pkg/http/server.go b/pkg/http/server.go index f0733f5..f9c8dbe 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "net/http" + "strconv" "time" "github.com/elazarl/goproxy/ext/auth" @@ -27,20 +28,30 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng if err != nil { return nil, err } + portUint, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return nil, err + } // ACL action, arg := acl.ActionProxy, "" + var ipAddr *net.IPAddr + var resErr error if aclEngine != nil { - ip := net.ParseIP(host) - if ip != nil { - host = "" - } - action, arg = aclEngine.Lookup(host, ip) + action, arg, ipAddr, resErr = aclEngine.ResolveAndMatch(host) + // Doesn't always matter if the resolution fails, as we may send it through HyClient } newDialFunc(addr, action, arg) // Handle according to the action switch action { case acl.ActionDirect: - return net.Dial(network, addr) + if resErr != nil { + return nil, resErr + } + return net.DialTCP(network, nil, &net.TCPAddr{ + IP: ipAddr.IP, + Port: int(portUint), + Zone: ipAddr.Zone, + }) case acl.ActionProxy: return hyClient.DialTCP(addr) case acl.ActionBlock: diff --git a/pkg/socks5/server.go b/pkg/socks5/server.go index 7be34c2..99771a4 100644 --- a/pkg/socks5/server.go +++ b/pkg/socks5/server.go @@ -162,10 +162,13 @@ func (s *Server) handle(c *net.TCPConn, r *socks5.Request) error { } func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { - domain, ip, port, addr := parseRequestAddress(r) + host, port, addr := parseRequestAddress(r) action, arg := acl.ActionProxy, "" + var ipAddr *net.IPAddr + var resErr error if s.ACLEngine != nil { - action, arg = s.ACLEngine.Lookup(domain, ip) + action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + // Doesn't always matter if the resolution fails, as we may send it through HyClient } s.TCPRequestFunc(c.RemoteAddr(), addr, action, arg) var closeErr error @@ -175,7 +178,16 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { // Handle according to the action switch action { case acl.ActionDirect: - rc, err := net.Dial("tcp", addr) + if resErr != nil { + _ = sendReply(c, socks5.RepHostUnreachable) + closeErr = resErr + return resErr + } + rc, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: ipAddr.IP, + Port: int(port), + Zone: ipAddr.Zone, + }) if err != nil { _ = sendReply(c, socks5.RepHostUnreachable) closeErr = err @@ -201,7 +213,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { closeErr = errors.New("blocked in ACL") return nil case acl.ActionHijack: - rc, err := net.Dial("tcp", net.JoinHostPort(arg, port)) + rc, err := net.Dial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port)))) if err != nil { _ = sendReply(c, socks5.RepHostUnreachable) closeErr = err @@ -299,13 +311,15 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, // Start remote to local go func() { for { - bs, _, err := hyUDP.ReadFrom() + bs, from, err := hyUDP.ReadFrom() if err != nil { break } - // RFC 1928 is very ambiguous on how to properly use DST.ADDR and DST.PORT in reply packets - // So we just fill in zeros for now. Works fine for all the SOCKS5 clients I tested - d := socks5.NewDatagram(socks5.ATYPIPv4, []byte{0x00, 0x00, 0x00, 0x00}, []byte{0x00, 0x00}, bs) + atyp, addr, port, err := socks5.ParseAddress(from) + if err != nil { + continue + } + d := socks5.NewDatagram(atyp, addr, port, bs) _, _ = clientConn.WriteToUDP(d.Bytes(), clientAddr) } }() @@ -329,24 +343,31 @@ func (s *Server) udpServer(clientConn *net.UDPConn, localRelayConn *net.UDPConn, // Not our client, bye continue } - domain, ip, port, addr := parseDatagramRequestAddress(d) + host, port, addr := parseDatagramRequestAddress(d) action, arg := acl.ActionProxy, "" + var ipAddr *net.IPAddr + var resErr error if s.ACLEngine != nil && localRelayConn != nil { - action, arg = s.ACLEngine.Lookup(domain, ip) + action, arg, ipAddr, resErr = s.ACLEngine.ResolveAndMatch(host) + // Doesn't always matter if the resolution fails, as we may send it through HyClient } // Handle according to the action switch action { case acl.ActionDirect: - rAddr, err := net.ResolveUDPAddr("udp", addr) - if err == nil { - _, _ = localRelayConn.WriteToUDP(d.Data, rAddr) + if resErr != nil { + return } + _, _ = localRelayConn.WriteToUDP(d.Data, &net.UDPAddr{ + IP: ipAddr.IP, + Port: int(port), + Zone: ipAddr.Zone, + }) case acl.ActionProxy: _ = hyUDP.WriteTo(d.Data, addr) case acl.ActionBlock: // Do nothing case acl.ActionHijack: - hijackAddr := net.JoinHostPort(arg, port) + hijackAddr := net.JoinHostPort(arg, net.JoinHostPort(arg, strconv.Itoa(int(port)))) rAddr, err := net.ResolveUDPAddr("udp", hijackAddr) if err == nil { _, _ = localRelayConn.WriteToUDP(d.Data, rAddr) @@ -363,22 +384,24 @@ func sendReply(conn *net.TCPConn, rep byte) error { return err } -func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port string, addr string) { - p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort))) +func parseRequestAddress(r *socks5.Request) (host string, port uint16, addr string) { + p := binary.BigEndian.Uint16(r.DstPort) if r.Atyp == socks5.ATYPDomain { d := string(r.DstAddr[1:]) - return d, nil, p, net.JoinHostPort(d, p) + return d, p, net.JoinHostPort(d, strconv.Itoa(int(p))) } else { - return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p) + ipStr := net.IP(r.DstAddr).String() + return ipStr, p, net.JoinHostPort(ipStr, strconv.Itoa(int(p))) } } -func parseDatagramRequestAddress(r *socks5.Datagram) (domain string, ip net.IP, port string, addr string) { - p := strconv.Itoa(int(binary.BigEndian.Uint16(r.DstPort))) +func parseDatagramRequestAddress(r *socks5.Datagram) (host string, port uint16, addr string) { + p := binary.BigEndian.Uint16(r.DstPort) if r.Atyp == socks5.ATYPDomain { d := string(r.DstAddr[1:]) - return d, nil, p, net.JoinHostPort(d, p) + return d, p, net.JoinHostPort(d, strconv.Itoa(int(p))) } else { - return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p) + ipStr := net.IP(r.DstAddr).String() + return ipStr, p, net.JoinHostPort(ipStr, strconv.Itoa(int(p))) } }