From 4bb59829607c08ed350f164ce9d8f4d9cc77193e Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 27 Mar 2021 16:51:15 -0700 Subject: [PATCH] Implemented UDP for both server & client --- cmd/client.go | 1 + cmd/server.go | 26 +++- pkg/core/client.go | 154 +++++++++++++++++++-- pkg/core/client_udp_test.go | 58 ++++++++ pkg/core/protocol.go | 8 ++ pkg/core/server.go | 112 ++------------- pkg/core/server_client.go | 269 ++++++++++++++++++++++++++++++++++++ pkg/http/server.go | 8 +- 8 files changed, 524 insertions(+), 112 deletions(-) create mode 100644 pkg/core/client_udp_test.go create mode 100644 pkg/core/server_client.go diff --git a/cmd/client.go b/cmd/client.go index cf52a65..d52b1dd 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -50,6 +50,7 @@ func client(config *clientConfig) { MaxStreamReceiveWindow: config.ReceiveWindowConn, MaxConnectionReceiveWindow: config.ReceiveWindow, KeepAlive: true, + EnableDatagrams: true, } if quicConfig.MaxStreamReceiveWindow == 0 { quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow diff --git a/cmd/server.go b/cmd/server.go index d302a50..fe18332 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -36,6 +36,7 @@ func server(config *serverConfig) { MaxConnectionReceiveWindow: config.ReceiveWindowClient, MaxIncomingStreams: int64(config.MaxConnClient), KeepAlive: true, + EnableDatagrams: true, } if quicConfig.MaxStreamReceiveWindow == 0 { quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow @@ -96,7 +97,8 @@ func server(config *serverConfig) { uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) - }, aclEngine, obfuscator, authFunc, tcpRequestFunc, tcpErrorFunc) + }, config.DisableUDP, aclEngine, obfuscator, authFunc, + tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc) if err != nil { logrus.WithField("error", err).Fatal("Failed to initialize server") } @@ -130,6 +132,28 @@ func tcpErrorFunc(addr net.Addr, auth []byte, reqAddr string, err error) { } } +func udpRequestFunc(addr net.Addr, auth []byte, sessionID uint32) { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "session": sessionID, + }).Debug("UDP request") +} + +func udpErrorFunc(addr net.Addr, auth []byte, sessionID uint32, err error) { + if err != io.EOF { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "session": sessionID, + "error": err, + }).Info("UDP error") + } else { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + "session": sessionID, + }).Debug("UDP EOF") + } +} + func actionToString(action acl.Action, arg string) string { switch action { case acl.ActionDirect: diff --git a/pkg/core/client.go b/pkg/core/client.go index adee661..8f9fbbd 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -1,6 +1,7 @@ package core import ( + "bytes" "context" "crypto/tls" "errors" @@ -14,7 +15,7 @@ import ( ) var ( - ErrClosed = errors.New("client closed") + ErrClosed = errors.New("closed") ) type CongestionFactory func(refBPS uint64) congestion.CongestionControl @@ -32,6 +33,9 @@ type Client struct { quicSession quic.Session reconnectMutex sync.Mutex closed bool + + udpSessionMutex sync.RWMutex + udpSessionMap map[uint32]chan *udpMessage } func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, @@ -90,6 +94,8 @@ func (c *Client) connectToServer() error { return fmt.Errorf("auth error: %s", msg) } // All good + c.udpSessionMap = make(map[uint32]chan *udpMessage) + go c.handleMessage(qs) c.quicSession = qs return nil } @@ -119,34 +125,59 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, return true, sh.Message, nil } -func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) { +func (c *Client) handleMessage(qs quic.Session) { + for { + msg, err := qs.ReceiveMessage() + if err != nil { + break + } + var udpMsg udpMessage + err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg) + if err != nil { + continue + } + c.udpSessionMutex.RLock() + ch, ok := c.udpSessionMap[udpMsg.SessionID] + if ok { + select { + case ch <- &udpMsg: + // OK + default: + // Silently drop the message when the channel is full + } + } + c.udpSessionMutex.RUnlock() + } +} + +func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() if c.closed { - return nil, nil, nil, ErrClosed + return nil, nil, ErrClosed } stream, err := c.quicSession.OpenStream() if err == nil { // All good - return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), nil + return c.quicSession, stream, nil } // Something is wrong if nErr, ok := err.(net.Error); ok && nErr.Temporary() { // Temporary error, just return - return nil, nil, nil, err + return nil, nil, err } // Permanent error, need to reconnect if err := c.connectToServer(); err != nil { // Still error, oops - return nil, nil, nil, err + return nil, nil, err } // We are not going to try again even if it still fails the second time stream, err = c.quicSession.OpenStream() - return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err + return c.quicSession, stream, nil } func (c *Client) DialTCP(addr string) (net.Conn, error) { - stream, localAddr, remoteAddr, err := c.openStreamWithReconnect() + session, stream, err := c.openStreamWithReconnect() if err != nil { return nil, err } @@ -172,11 +203,64 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) { } return &quicConn{ Orig: stream, - PseudoLocalAddr: localAddr, - PseudoRemoteAddr: remoteAddr, + PseudoLocalAddr: session.LocalAddr(), + PseudoRemoteAddr: session.RemoteAddr(), }, nil } +func (c *Client) DialUDP() (UDPConn, error) { + session, stream, err := c.openStreamWithReconnect() + if err != nil { + return nil, err + } + // Send request + err = struc.Pack(stream, &clientRequest{ + UDP: true, + }) + if err != nil { + _ = stream.Close() + return nil, err + } + // Read response + var sr serverResponse + err = struc.Unpack(stream, &sr) + if err != nil { + _ = stream.Close() + return nil, err + } + if !sr.OK { + _ = stream.Close() + return nil, fmt.Errorf("connection rejected: %s", sr.Message) + } + + // Create a session in the map + c.udpSessionMutex.Lock() + nCh := make(chan *udpMessage, 1024) + // Store the current session map for CloseFunc below + // to ensures that we are adding and removing sessions on the same map, + // as reconnecting will reassign the map + sessionMap := c.udpSessionMap + sessionMap[sr.UDPSessionID] = nCh + c.udpSessionMutex.Unlock() + + pktConn := &quicPktConn{ + Session: session, + Stream: stream, + CloseFunc: func() { + c.udpSessionMutex.Lock() + if ch, ok := sessionMap[sr.UDPSessionID]; ok { + close(ch) + delete(sessionMap, sr.UDPSessionID) + } + c.udpSessionMutex.Unlock() + }, + UDPSessionID: sr.UDPSessionID, + MsgCh: nCh, + } + go pktConn.Hold() + return pktConn, nil +} + func (c *Client) Close() error { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() @@ -222,3 +306,53 @@ func (w *quicConn) SetReadDeadline(t time.Time) error { func (w *quicConn) SetWriteDeadline(t time.Time) error { return w.Orig.SetWriteDeadline(t) } + +type UDPConn interface { + ReadFrom() ([]byte, string, error) + WriteTo([]byte, string) error + Close() error +} + +type quicPktConn struct { + Session quic.Session + Stream quic.Stream + CloseFunc func() + UDPSessionID uint32 + MsgCh <-chan *udpMessage +} + +func (c *quicPktConn) Hold() { + // Hold the stream until it's closed + buf := make([]byte, 1024) + for { + _, err := c.Stream.Read(buf) + if err != nil { + break + } + } + _ = c.Close() +} + +func (c *quicPktConn) ReadFrom() ([]byte, string, error) { + msg := <-c.MsgCh + if msg == nil { + // Closed + return nil, "", ErrClosed + } + return msg.Data, msg.Address, nil +} + +func (c *quicPktConn) WriteTo(p []byte, addr string) error { + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &udpMessage{ + SessionID: c.UDPSessionID, + Address: addr, + Data: p, + }) + return c.Session.SendMessage(msgBuf.Bytes()) +} + +func (c *quicPktConn) Close() error { + c.CloseFunc() + return c.Stream.Close() +} diff --git a/pkg/core/client_udp_test.go b/pkg/core/client_udp_test.go new file mode 100644 index 0000000..ba6ffee --- /dev/null +++ b/pkg/core/client_udp_test.go @@ -0,0 +1,58 @@ +package core + +import ( + "crypto/tls" + "github.com/lucas-clemente/quic-go" + "testing" +) + +func TestClientUDP(t *testing.T) { + client, err := NewClient("toby.moe:36713", nil, &tls.Config{ + NextProtos: []string{"hysteria"}, + MinVersion: tls.VersionTLS13, + }, &quic.Config{ + EnableDatagrams: true, + }, 125000, 125000, nil, nil) + if err != nil { + t.Fatal(err) + } + conn, err := client.DialUDP() + if err != nil { + t.Fatal("conn DialUDP", err) + } + t.Run("8.8.8.8", func(t *testing.T) { + dnsReq := []byte{0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01} + err := conn.WriteTo(dnsReq, "8.8.8.8:53") + if err != nil { + t.Error("WriteTo", err) + } + buf, _, err := conn.ReadFrom() + if err != nil { + t.Error("ReadFrom", err) + } + if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] { + t.Error("invalid response") + } + }) + t.Run("1.1.1.1", func(t *testing.T) { + dnsReq := []byte{0x66, 0x77, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01} + err := conn.WriteTo(dnsReq, "1.1.1.1:53") + if err != nil { + t.Error("WriteTo", err) + } + buf, _, err := conn.ReadFrom() + if err != nil { + t.Error("ReadFrom", err) + } + if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] { + t.Error("invalid response") + } + }) + t.Run("Close", func(t *testing.T) { + _ = conn.Close() + _, _, err := conn.ReadFrom() + if err != ErrClosed { + t.Error("closed conn not returning the correct error") + } + }) +} diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 83e9b8c..9299044 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -43,3 +43,11 @@ type serverResponse struct { MessageLen uint16 `struc:"sizeof=Message"` Message string } + +type udpMessage struct { + SessionID uint32 + AddressLen uint16 `struc:"sizeof=Address"` + Address string + DataLen uint16 `struc:"sizeof=Data"` + Data []byte +} diff --git a/pkg/core/server.go b/pkg/core/server.go index 49ed004..8651d1b 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -7,7 +7,6 @@ import ( "github.com/lucas-clemente/quic-go" "github.com/lunixbochs/struc" "github.com/tobyxdd/hysteria/pkg/acl" - "github.com/tobyxdd/hysteria/pkg/utils" "net" "time" ) @@ -17,22 +16,28 @@ const dialTimeout = 10 * time.Second type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) +type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32) +type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error) type Server struct { sendBPS, recvBPS uint64 congestionFactory CongestionFactory + disableUDP bool aclEngine *acl.Engine authFunc AuthFunc tcpRequestFunc TCPRequestFunc tcpErrorFunc TCPErrorFunc + udpRequestFunc UDPRequestFunc + udpErrorFunc UDPErrorFunc listener quic.Listener } func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, - sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine, - obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc) (*Server, error) { + sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, + obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, + udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc) (*Server, error) { packetConn, err := net.ListenPacket("udp", addr) if err != nil { return nil, err @@ -53,10 +58,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, sendBPS: sendBPS, recvBPS: recvBPS, congestionFactory: congestionFactory, + disableUDP: disableUDP, aclEngine: aclEngine, authFunc: authFunc, tcpRequestFunc: tcpRequestFunc, tcpErrorFunc: tcpErrorFunc, + udpRequestFunc: udpRequestFunc, + udpErrorFunc: udpErrorFunc, } return s, nil } @@ -94,14 +102,10 @@ func (s *Server) handleClient(cs quic.Session) { _ = cs.CloseWithError(closeErrorCodeAuth, "auth error") return } - // Start accepting streams - for { - stream, err := cs.AcceptStream(context.Background()) - if err != nil { - break - } - go s.handleStream(cs.RemoteAddr(), auth, stream) - } + // Start accepting streams and messages + sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine, + s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc) + sc.Run() _ = cs.CloseWithError(closeErrorCodeGeneric, "") } @@ -143,89 +147,3 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt } return ch.Auth, ok, nil } - -func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) { - defer stream.Close() - // Read request - var req clientRequest - err := struc.Unpack(stream, &req) - if err != nil { - return - } - if !req.UDP { - // TCP connection - s.handleTCP(remoteAddr, auth, stream, req.Address) - } else { - // UDP connection - // TODO - } -} - -func (s *Server) handleTCP(remoteAddr net.Addr, auth []byte, stream quic.Stream, reqAddr string) { - host, port, err := net.SplitHostPort(reqAddr) - if err != nil { - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: "invalid address", - }) - s.tcpErrorFunc(remoteAddr, auth, reqAddr, err) - return - } - ip := net.ParseIP(host) - if ip != nil { - // IP request, clear host for ACL engine - host = "" - } - action, arg := acl.ActionDirect, "" - if s.aclEngine != nil { - action, arg = s.aclEngine.Lookup(host, ip) - } - s.tcpRequestFunc(remoteAddr, auth, reqAddr, action, arg) - - var conn net.Conn // Connection to be piped - switch action { - case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout) - if err != nil { - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: err.Error(), - }) - s.tcpErrorFunc(remoteAddr, auth, reqAddr, err) - return - } - case acl.ActionBlock: - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: "blocked by ACL", - }) - return - case acl.ActionHijack: - hijackAddr := net.JoinHostPort(arg, port) - conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout) - if err != nil { - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: err.Error(), - }) - s.tcpErrorFunc(remoteAddr, auth, reqAddr, err) - return - } - default: - _ = struc.Pack(stream, &serverResponse{ - OK: false, - Message: "ACL error", - }) - return - } - // So far so good if we reach here - defer conn.Close() - err = struc.Pack(stream, &serverResponse{ - OK: true, - }) - if err != nil { - return - } - err = utils.Pipe2Way(stream, conn) - s.tcpErrorFunc(remoteAddr, auth, reqAddr, err) -} diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go new file mode 100644 index 0000000..0663099 --- /dev/null +++ b/pkg/core/server_client.go @@ -0,0 +1,269 @@ +package core + +import ( + "bytes" + "context" + "github.com/lucas-clemente/quic-go" + "github.com/lunixbochs/struc" + "github.com/tobyxdd/hysteria/pkg/acl" + "github.com/tobyxdd/hysteria/pkg/utils" + "net" + "sync" +) + +const udpBufferSize = 65535 + +type serverClient struct { + CS quic.Session + Auth []byte + ClientAddr net.Addr + DisableUDP bool + ACLEngine *acl.Engine + CTCPRequestFunc TCPRequestFunc + CTCPErrorFunc TCPErrorFunc + CUDPRequestFunc UDPRequestFunc + CUDPErrorFunc UDPErrorFunc + + udpSessionMutex sync.RWMutex + udpSessionMap map[uint32]*net.UDPConn + nextUDPSessionID uint32 +} + +func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine, + CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, + CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc) *serverClient { + return &serverClient{ + CS: cs, + Auth: auth, + ClientAddr: cs.RemoteAddr(), + DisableUDP: disableUDP, + ACLEngine: ACLEngine, + CTCPRequestFunc: CTCPRequestFunc, + CTCPErrorFunc: CTCPErrorFunc, + CUDPRequestFunc: CUDPRequestFunc, + CUDPErrorFunc: CUDPErrorFunc, + udpSessionMap: make(map[uint32]*net.UDPConn), + } +} + +func (c *serverClient) Run() { + if !c.DisableUDP { + go func() { + for { + msg, err := c.CS.ReceiveMessage() + if err != nil { + break + } + c.handleMessage(msg) + } + }() + } + for { + stream, err := c.CS.AcceptStream(context.Background()) + if err != nil { + break + } + go c.handleStream(stream) + } +} + +func (c *serverClient) handleStream(stream quic.Stream) { + defer stream.Close() + // Read request + var req clientRequest + err := struc.Unpack(stream, &req) + if err != nil { + return + } + if !req.UDP { + // TCP connection + c.handleTCP(stream, req.Address) + } else if !c.DisableUDP { + // UDP connection + c.handleUDP(stream) + } else { + // UDP disabled + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "UDP disabled", + }) + } +} + +func (c *serverClient) handleMessage(msg []byte) { + var udpMsg udpMessage + err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) + if err != nil { + return + } + c.udpSessionMutex.RLock() + conn, ok := c.udpSessionMap[udpMsg.SessionID] + c.udpSessionMutex.RUnlock() + if ok { + // Session found, send the message + host, port, err := net.SplitHostPort(udpMsg.Address) + if err != nil { + return + } + action, arg := acl.ActionDirect, "" + if c.ACLEngine != nil { + ip := net.ParseIP(host) + if ip != nil { + // IP request, clear host for ACL engine + host = "" + } + action, arg = c.ACLEngine.Lookup(host, ip) + } + switch action { + case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side + addr, err := net.ResolveUDPAddr("udp", udpMsg.Address) + if err == nil { + _, _ = conn.WriteToUDP(udpMsg.Data, addr) + } + case acl.ActionBlock: + // Do nothing + case acl.ActionHijack: + hijackAddr := net.JoinHostPort(arg, port) + addr, err := net.ResolveUDPAddr("udp", hijackAddr) + if err == nil { + _, _ = conn.WriteToUDP(udpMsg.Data, addr) + } + default: + // Do nothing + } + } +} + +func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) { + host, port, err := net.SplitHostPort(reqAddr) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "invalid address", + }) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + return + } + action, arg := acl.ActionDirect, "" + if c.ACLEngine != nil { + ip := net.ParseIP(host) + if ip != nil { + // IP request, clear host for ACL engine + host = "" + } + action, arg = c.ACLEngine.Lookup(host, ip) + } + c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg) + + var conn net.Conn // Connection to be piped + switch action { + case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side + conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: err.Error(), + }) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + return + } + case acl.ActionBlock: + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "blocked by ACL", + }) + return + case acl.ActionHijack: + hijackAddr := net.JoinHostPort(arg, port) + conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: err.Error(), + }) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) + return + } + default: + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "ACL error", + }) + return + } + // So far so good if we reach here + defer conn.Close() + err = struc.Pack(stream, &serverResponse{ + OK: true, + }) + if err != nil { + return + } + err = utils.Pipe2Way(stream, conn) + c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err) +} + +func (c *serverClient) handleUDP(stream quic.Stream) { + // Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it + conn, err := net.ListenUDP("udp", nil) + if err != nil { + _ = struc.Pack(stream, &serverResponse{ + OK: false, + Message: "UDP initialization failed", + }) + c.CUDPErrorFunc(c.ClientAddr, c.Auth, 0, err) + return + } + defer conn.Close() + + var id uint32 + c.udpSessionMutex.Lock() + id = c.nextUDPSessionID + c.udpSessionMap[id] = conn + c.nextUDPSessionID += 1 + c.udpSessionMutex.Unlock() + + err = struc.Pack(stream, &serverResponse{ + OK: true, + UDPSessionID: id, + }) + if err != nil { + return + } + c.CUDPRequestFunc(c.ClientAddr, c.Auth, id) + + // Receive UDP packets, send them to the client + go func() { + buf := make([]byte, udpBufferSize) + for { + n, rAddr, err := conn.ReadFromUDP(buf) + if n > 0 { + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &udpMessage{ + SessionID: id, + Address: rAddr.String(), + Data: buf[:n], + }) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } + if err != nil { + break + } + } + }() + + // Hold the stream until it's closed by the client + buf := make([]byte, 1024) + for { + _, err = stream.Read(buf) + if err != nil { + break + } + } + c.CUDPErrorFunc(c.ClientAddr, c.Auth, id, err) + + // Remove the session + c.udpSessionMutex.Lock() + delete(c.udpSessionMap, id) + c.udpSessionMutex.Unlock() +} diff --git a/pkg/http/server.go b/pkg/http/server.go index ae38639..f0733f5 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -27,13 +27,13 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng if err != nil { return nil, err } - ip := net.ParseIP(host) - if ip != nil { - host = "" - } // ACL action, arg := acl.ActionProxy, "" if aclEngine != nil { + ip := net.ParseIP(host) + if ip != nil { + host = "" + } action, arg = aclEngine.Lookup(host, ip) } newDialFunc(addr, action, arg)