From 4da73888f490d72c75a1f0a048707661bb5c17c2 Mon Sep 17 00:00:00 2001 From: Toby Date: Mon, 26 Apr 2021 17:34:08 -0700 Subject: [PATCH] Transport WIP --- pkg/acl/engine.go | 7 +++++-- pkg/core/client.go | 8 +++++--- pkg/core/server.go | 10 ++++++---- pkg/core/server_client.go | 16 +++++++++------- pkg/core/transport.go | 14 ++++++++++++++ pkg/http/server.go | 6 +++--- 6 files changed, 42 insertions(+), 19 deletions(-) create mode 100644 pkg/core/transport.go diff --git a/pkg/acl/engine.go b/pkg/acl/engine.go index d3cef8c..b62bd58 100644 --- a/pkg/acl/engine.go +++ b/pkg/acl/engine.go @@ -3,6 +3,7 @@ package acl import ( "bufio" lru "github.com/hashicorp/golang-lru" + "github.com/tobyxdd/hysteria/pkg/core" "net" "os" "strings" @@ -14,6 +15,7 @@ type Engine struct { DefaultAction Action Entries []Entry Cache *lru.ARCCache + Transport core.Transport } type cacheEntry struct { @@ -21,7 +23,7 @@ type cacheEntry struct { Arg string } -func LoadFromFile(filename string) (*Engine, error) { +func LoadFromFile(filename string, transport core.Transport) (*Engine, error) { f, err := os.Open(filename) if err != nil { return nil, err @@ -49,6 +51,7 @@ func LoadFromFile(filename string) (*Engine, error) { DefaultAction: ActionProxy, Entries: entries, Cache: cache, + Transport: transport, }, nil } @@ -56,7 +59,7 @@ func (e *Engine) ResolveAndMatch(host string) (Action, string, *net.IPAddr, erro ip, zone := parseIPZone(host) if ip == nil { // Domain - ipAddr, err := net.ResolveIPAddr("ip", host) + ipAddr, err := e.Transport.OutResolveIPAddr(host) if v, ok := e.Cache.Get(host); ok { // Cache hit ce := v.(cacheEntry) diff --git a/pkg/core/client.go b/pkg/core/client.go index d5ae791..2c60c68 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -23,6 +23,7 @@ var ( type CongestionFactory func(refBPS uint64) congestion.CongestionControl type Client struct { + transport Transport serverAddr string sendBPS, recvBPS uint64 auth []byte @@ -40,9 +41,10 @@ type Client struct { udpSessionMap map[uint32]chan *udpMessage } -func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, +func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, transport Transport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) { c := &Client{ + transport: transport, serverAddr: serverAddr, sendBPS: sendBPS, recvBPS: recvBPS, @@ -59,11 +61,11 @@ func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig } func (c *Client) connectToServer() error { - serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) + serverUDPAddr, err := c.transport.QUICResolveUDPAddr(c.serverAddr) if err != nil { return err } - udpConn, err := net.ListenUDP("udp", nil) + udpConn, err := c.transport.QUICListenUDP(nil) if err != nil { return err } diff --git a/pkg/core/server.go b/pkg/core/server.go index e61a556..9e2364d 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -20,6 +20,7 @@ type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32) type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error) type Server struct { + transport Transport sendBPS, recvBPS uint64 congestionFactory CongestionFactory disableUDP bool @@ -36,15 +37,15 @@ type Server struct { listener quic.Listener } -func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, +func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, transport Transport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry) (*Server, error) { - udpAddr, err := net.ResolveUDPAddr("udp", addr) + udpAddr, err := transport.QUICResolveUDPAddr(addr) if err != nil { return nil, err } - udpConn, err := net.ListenUDP("udp", udpAddr) + udpConn, err := transport.QUICListenUDP(udpAddr) if err != nil { return nil, err } @@ -72,6 +73,7 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, } s := &Server{ listener: listener, + transport: transport, sendBPS: sendBPS, recvBPS: recvBPS, congestionFactory: congestionFactory, @@ -129,7 +131,7 @@ func (s *Server) handleClient(cs quic.Session) { return } // Start accepting streams and messages - sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine, + sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec) sc.Run() _ = cs.CloseWithError(closeErrorCodeGeneric, "") diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 7e991f9..094d5f7 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -18,6 +18,7 @@ const udpBufferSize = 65535 type serverClient struct { CS quic.Session + Transport Transport Auth []byte ClientAddr net.Addr DisableUDP bool @@ -34,12 +35,13 @@ type serverClient struct { nextUDPSessionID uint32 } -func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine, +func newServerClient(cs quic.Session, transport Transport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc, UpCounterVec, DownCounterVec *prometheus.CounterVec) *serverClient { sc := &serverClient{ CS: cs, + Transport: transport, Auth: auth, ClientAddr: cs.RemoteAddr(), DisableUDP: disableUDP, @@ -118,7 +120,7 @@ func (c *serverClient) handleMessage(msg []byte) { if c.ACLEngine != nil { action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) } else { - ipAddr, err = net.ResolveIPAddr("ip", udpMsg.Host) + ipAddr, err = c.Transport.OutResolveIPAddr(udpMsg.Host) } if err != nil { return @@ -137,7 +139,7 @@ func (c *serverClient) handleMessage(msg []byte) { // Do nothing case acl.ActionHijack: hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(udpMsg.Port))) - addr, err := net.ResolveUDPAddr("udp", hijackAddr) + addr, err := c.Transport.OutResolveUDPAddr(hijackAddr) if err == nil { _, _ = conn.WriteToUDP(udpMsg.Data, addr) if c.UpCounter != nil { @@ -158,7 +160,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { if c.ACLEngine != nil { action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(host) } else { - ipAddr, err = net.ResolveIPAddr("ip", host) + ipAddr, err = c.Transport.OutResolveIPAddr(host) } if err != nil { _ = struc.Pack(stream, &serverResponse{ @@ -173,7 +175,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { var conn net.Conn // Connection to be piped switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - conn, err = net.DialTCP("tcp", nil, &net.TCPAddr{ + conn, err = c.Transport.OutDialTCP(nil, &net.TCPAddr{ IP: ipAddr.IP, Port: int(port), Zone: ipAddr.Zone, @@ -194,7 +196,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { return case acl.ActionHijack: hijackAddr := net.JoinHostPort(arg, strconv.Itoa(int(port))) - conn, err = net.Dial("tcp", hijackAddr) + conn, err = c.Transport.OutDial("tcp", hijackAddr) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, @@ -234,7 +236,7 @@ func (c *serverClient) handleTCP(stream quic.Stream, host string, port uint16) { func (c *serverClient) handleUDP(stream quic.Stream) { // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it - conn, err := net.ListenUDP("udp", nil) + conn, err := c.Transport.OutListenUDP(nil) if err != nil { _ = struc.Pack(stream, &serverResponse{ OK: false, diff --git a/pkg/core/transport.go b/pkg/core/transport.go new file mode 100644 index 0000000..f1c242c --- /dev/null +++ b/pkg/core/transport.go @@ -0,0 +1,14 @@ +package core + +import "net" + +type Transport interface { + QUICResolveUDPAddr(address string) (*net.UDPAddr, error) + QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) + + OutResolveIPAddr(address string) (*net.IPAddr, error) + OutResolveUDPAddr(address string) (*net.UDPAddr, error) + OutDial(network, address string) (net.Conn, error) + OutDialTCP(laddr, raddr *net.TCPAddr) (*net.TCPConn, error) + OutListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) +} diff --git a/pkg/http/server.go b/pkg/http/server.go index 81299ce..6c68085 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -16,7 +16,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/core" ) -func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEngine *acl.Engine, +func NewProxyHTTPServer(hyClient *core.Client, transport core.Transport, idleTimeout time.Duration, aclEngine *acl.Engine, newDialFunc func(reqAddr string, action acl.Action, arg string), basicAuthFunc func(user, password string) bool) (*goproxy.ProxyHttpServer, error) { proxy := goproxy.NewProxyHttpServer() @@ -44,7 +44,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng if resErr != nil { return nil, resErr } - return net.DialTCP(network, nil, &net.TCPAddr{ + return transport.OutDialTCP(nil, &net.TCPAddr{ IP: ipAddr.IP, Port: int(port), Zone: ipAddr.Zone, @@ -54,7 +54,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng case acl.ActionBlock: return nil, errors.New("blocked by ACL") case acl.ActionHijack: - return net.Dial(network, net.JoinHostPort(arg, strconv.Itoa(int(port)))) + return transport.OutDial(network, net.JoinHostPort(arg, strconv.Itoa(int(port)))) default: return nil, fmt.Errorf("unknown action %d", action) }