From cbedb27f0fc08079286705b3bc7562e2b3f57578 Mon Sep 17 00:00:00 2001 From: Toby Date: Mon, 24 Jul 2023 16:32:25 -0700 Subject: [PATCH] feat(wip): udp rework client side --- core/client/client.go | 282 ++++-------------- core/client/reconnect.go | 68 ----- core/client/udp.go | 177 +++++++++++ core/internal/integration_tests/masq_test.go | 21 -- core/internal/integration_tests/utils_test.go | 2 +- core/server/udp.go | 11 +- core/server/udp_test.go | 221 +++++++++----- 7 files changed, 391 insertions(+), 391 deletions(-) delete mode 100644 core/client/reconnect.go create mode 100644 core/client/udp.go diff --git a/core/client/client.go b/core/client/client.go index 3fe154d..2f35a07 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -3,18 +3,13 @@ package client import ( "context" "crypto/tls" - "errors" - "io" - "math/rand" "net" "net/http" "net/url" - "sync" "time" coreErrs "github.com/apernet/hysteria/core/errors" "github.com/apernet/hysteria/core/internal/congestion" - "github.com/apernet/hysteria/core/internal/frag" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/utils" @@ -23,8 +18,6 @@ import ( ) const ( - udpMessageChanSize = 1024 - closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError ) @@ -48,94 +41,25 @@ func NewClient(config *Config) (Client, error) { c := &clientImpl{ config: config, } - c.conn = &autoReconnectConn{ - Connect: c.connect, + if err := c.connect(); err != nil { + return nil, err } return c, nil } type clientImpl struct { config *Config - conn *autoReconnectConn - udpSM udpSessionManager + pktConn net.PacketConn + conn quic.Connection + + udpSM *udpSessionManager } -type udpSessionEntry struct { - Ch chan *protocol.UDPMessage - D *frag.Defragger - Closed bool -} - -type udpSessionManager struct { - mutex sync.RWMutex - m map[uint32]*udpSessionEntry -} - -func (m *udpSessionManager) Init() { - m.mutex.Lock() - defer m.mutex.Unlock() - m.m = make(map[uint32]*udpSessionEntry) -} - -// Add returns both a channel for receiving messages and a function to close the channel & delete the session. -func (m *udpSessionManager) Add(id uint32) (<-chan *protocol.UDPMessage, func()) { - m.mutex.Lock() - defer m.mutex.Unlock() - - // Important: make sure we add and delete the channel in the same map, - // as the map may be replaced by Init() at any time. - currentM := m.m - - entry := &udpSessionEntry{ - Ch: make(chan *protocol.UDPMessage, udpMessageChanSize), - D: &frag.Defragger{}, - Closed: false, - } - currentM[id] = entry - - return entry.Ch, func() { - m.mutex.Lock() - defer m.mutex.Unlock() - if entry.Closed { - // Double close a channel will panic, - // so we need a flag to make sure we only close it once. - return - } - entry.Closed = true - close(entry.Ch) - delete(currentM, id) - } -} - -func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - entry, ok := m.m[msg.SessionID] - if !ok { - // No such session, drop the message - return - } - dfMsg := entry.D.Feed(msg) - if dfMsg == nil { - // Not a complete message yet - return - } - select { - case entry.Ch <- dfMsg: - // OK - default: - // Channel is full, drop the message - } -} - -func (c *clientImpl) connect() (quic.Connection, func(), error) { - // Use a new packet conn for each connection, - // remember to close it after the QUIC connection is closed. +func (c *clientImpl) connect() error { pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) if err != nil { - return nil, nil, err + return err } // Convert config to TLS config & QUIC config tlsConfig := &tls.Config{ @@ -185,15 +109,15 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) { _ = conn.CloseWithError(closeErrCodeProtocolError, "") } _ = pktConn.Close() - return nil, nil, &coreErrs.ConnectError{Err: err} + return &coreErrs.ConnectError{Err: err} } if resp.StatusCode != protocol.StatusAuthOK { _ = conn.CloseWithError(closeErrCodeProtocolError, "") _ = pktConn.Close() - return nil, nil, &coreErrs.AuthError{StatusCode: resp.StatusCode} + return &coreErrs.AuthError{StatusCode: resp.StatusCode} } // Auth OK - serverRx := protocol.AuthResponseDataFromHeader(resp.Header) + udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header) // actualTx = min(serverRx, clientTx) actualTx := serverRx if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { @@ -205,46 +129,20 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) { } _ = resp.Body.Close() - c.udpSM.Init() - go c.udpLoop(conn) - - return conn, func() { - _ = conn.CloseWithError(closeErrCodeOK, "") - _ = pktConn.Close() - }, nil -} - -func (c *clientImpl) udpLoop(conn quic.Connection) { - for { - msg, err := conn.ReceiveMessage() - if err != nil { - return - } - c.handleUDPMessage(msg) + c.pktConn = pktConn + c.conn = conn + if udpEnabled { + c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) + go func() { + c.udpSM.Run() + // TODO: Mark connection as closed + }() } -} - -// client <- remote direction -func (c *clientImpl) handleUDPMessage(msg []byte) { - udpMsg, err := protocol.ParseUDPMessage(msg) - if err != nil { - return - } - c.udpSM.Feed(udpMsg) -} - -// openStream wraps the stream with QStream, which handles Close() properly -func (c *clientImpl) openStream() (quic.Connection, quic.Stream, error) { - qc, stream, err := c.conn.OpenStream() - if err != nil { - return nil, nil, err - } - - return qc, &utils.QStream{Stream: stream}, nil + return nil } func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { - qc, stream, err := c.openStream() + stream, err := c.openStream() if err != nil { return nil, err } @@ -260,8 +158,8 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { // to the first Read() call. return &tcpConn{ Orig: stream, - PseudoLocalAddr: qc.LocalAddr(), - PseudoRemoteAddr: qc.RemoteAddr(), + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), Established: false, }, nil } @@ -277,49 +175,23 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { } return &tcpConn{ Orig: stream, - PseudoLocalAddr: qc.LocalAddr(), - PseudoRemoteAddr: qc.RemoteAddr(), + PseudoLocalAddr: c.conn.LocalAddr(), + PseudoRemoteAddr: c.conn.RemoteAddr(), Established: true, }, nil } func (c *clientImpl) ListenUDP() (HyUDPConn, error) { - qc, stream, err := c.openStream() - if err != nil { - return nil, err + if c.udpSM == nil { + return nil, coreErrs.DialError{Message: "UDP not enabled"} } - // Send request - err = protocol.WriteUDPRequest(stream) - if err != nil { - _ = stream.Close() - return nil, err - } - // Read response - ok, sessionID, msg, err := protocol.ReadUDPResponse(stream) - if err != nil { - _ = stream.Close() - return nil, err - } - if !ok { - _ = stream.Close() - return nil, coreErrs.DialError{Message: msg} - } - - ch, closeFunc := c.udpSM.Add(sessionID) - uc := &udpConn{ - QC: qc, - Stream: stream, - SessionID: sessionID, - Ch: ch, - CloseFunc: closeFunc, - SendBuf: make([]byte, protocol.MaxUDPSize), - } - go uc.Hold() - return uc, nil + return c.udpSM.NewUDP() } func (c *clientImpl) Close() error { - return c.conn.Close() + _ = c.conn.CloseWithError(closeErrCodeOK, "") + _ = c.pktConn.Close() + return nil } type tcpConn struct { @@ -372,72 +244,40 @@ func (c *tcpConn) SetWriteDeadline(t time.Time) error { return c.Orig.SetWriteDeadline(t) } -type udpConn struct { - QC quic.Connection - Stream quic.Stream - SessionID uint32 - Ch <-chan *protocol.UDPMessage - CloseFunc func() - SendBuf []byte +type udpIOImpl struct { + Conn quic.Connection } -func (c *udpConn) Hold() { - // Hold (drain) the stream until someone closes it. - // Closing the stream is the signal to stop the UDP session. - _, _ = io.Copy(io.Discard, c.Stream) - _ = c.Close() -} - -func (c *udpConn) Receive() ([]byte, string, error) { - msg := <-c.Ch - if msg == nil { - // Closed - return nil, "", io.EOF - } - return msg.Data, msg.Addr, nil -} - -// Send is not thread-safe as it uses a shared send buffer for now. -func (c *udpConn) Send(data []byte, addr string) error { - // Try no frag first - msg := &protocol.UDPMessage{ - SessionID: c.SessionID, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: addr, - Data: data, - } - n := msg.Serialize(c.SendBuf) - if n < 0 { - // Message even larger than MaxUDPSize, drop it - // Maybe we should return an error in the future? - return nil - } - sendErr := c.QC.SendMessage(c.SendBuf[:n]) - if sendErr == nil { - // All good - return nil - } - var errTooLarge quic.ErrMessageTooLarge - if errors.As(sendErr, &errTooLarge) { - // Message too large, try fragmentation - msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 - fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) - for _, fMsg := range fMsgs { - n = fMsg.Serialize(c.SendBuf) - err := c.QC.SendMessage(c.SendBuf[:n]) - if err != nil { - return err - } +func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) { + for { + msg, err := io.Conn.ReceiveMessage() + if err != nil { + // Connection error, this will stop the session manager + return nil, err } - return nil + udpMsg, err := protocol.ParseUDPMessage(msg) + if err != nil { + // Invalid message, this is fine - just wait for the next + continue + } + return udpMsg, nil } - // Other error - return sendErr } -func (c *udpConn) Close() error { - c.CloseFunc() - return c.Stream.Close() +func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { + msgN := msg.Serialize(buf) + if msgN < 0 { + // Message larger than buffer, silent drop + return nil + } + return io.Conn.SendMessage(buf[:msgN]) +} + +// openStream wraps the stream with QStream, which handles Close() properly +func (c *clientImpl) openStream() (quic.Stream, error) { + stream, err := c.conn.OpenStream() + if err != nil { + return nil, err + } + return &utils.QStream{Stream: stream}, nil } diff --git a/core/client/reconnect.go b/core/client/reconnect.go deleted file mode 100644 index 7ea6943..0000000 --- a/core/client/reconnect.go +++ /dev/null @@ -1,68 +0,0 @@ -package client - -import ( - "net" - "sync" - - "github.com/quic-go/quic-go" -) - -// autoReconnectConn is a wrapper of quic.Connection that automatically reconnects -// when a non-temporary error (usually a timeout) occurs. -type autoReconnectConn struct { - // Connect is called whenever a new QUIC connection is needed. - // It should return a new QUIC connection, a function to close the connection - // (and potentially other underlying resources), and an error if one occurred. - Connect func() (quic.Connection, func(), error) - - conn quic.Connection - closeFunc func() - connMutex sync.RWMutex -} - -func (c *autoReconnectConn) OpenStream() (quic.Connection, quic.Stream, error) { - c.connMutex.Lock() - defer c.connMutex.Unlock() - // First time? - if c.conn == nil { - conn, closeFunc, err := c.Connect() - if err != nil { - return nil, nil, err - } - c.conn = conn - c.closeFunc = closeFunc - } - stream, err := c.conn.OpenStream() - if err == nil { - // All is good - return c.conn, stream, nil - } else if nErr, ok := err.(net.Error); ok && nErr.Temporary() { - // Temporary error, just pass the error to the caller - return nil, nil, err - } else { - // Permanent error - // Close the previous connection, - // reconnect and try again (only once) - c.closeFunc() - conn, closeFunc, err := c.Connect() - if err != nil { - return nil, nil, err - } - c.conn = conn - c.closeFunc = closeFunc - stream, err = c.conn.OpenStream() - return c.conn, stream, err - } -} - -func (c *autoReconnectConn) Close() error { - c.connMutex.Lock() - defer c.connMutex.Unlock() - if c.conn == nil { - return nil - } - c.closeFunc() - c.conn = nil - c.closeFunc = nil - return nil -} diff --git a/core/client/udp.go b/core/client/udp.go new file mode 100644 index 0000000..6cca1af --- /dev/null +++ b/core/client/udp.go @@ -0,0 +1,177 @@ +package client + +import ( + "errors" + "io" + "math/rand" + "sync" + + "github.com/apernet/hysteria/core/internal/frag" + "github.com/apernet/hysteria/core/internal/protocol" + "github.com/quic-go/quic-go" +) + +const ( + udpMessageChanSize = 1024 +) + +type udpIO interface { + ReceiveMessage() (*protocol.UDPMessage, error) + SendMessage([]byte, *protocol.UDPMessage) error +} + +type udpConn struct { + ID uint32 + D *frag.Defragger + ReceiveCh chan *protocol.UDPMessage + SendBuf []byte + SendFunc func([]byte, *protocol.UDPMessage) error + CloseFunc func() + Closed bool +} + +func (u *udpConn) Receive() ([]byte, string, error) { + for { + msg := <-u.ReceiveCh + if msg == nil { + // Closed + return nil, "", io.EOF + } + dfMsg := u.D.Feed(msg) + if dfMsg == nil { + // Incomplete message, wait for more + continue + } + return dfMsg.Data, dfMsg.Addr, nil + } +} + +// Send is not thread-safe, as it uses a shared SendBuf. +func (u *udpConn) Send(data []byte, addr string) error { + // Try no frag first + msg := &protocol.UDPMessage{ + SessionID: u.ID, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: addr, + Data: data, + } + err := u.SendFunc(u.SendBuf, msg) + var errTooLarge quic.ErrMessageTooLarge + if errors.As(err, &errTooLarge) { + // Message too large, try fragmentation + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) + for _, fMsg := range fMsgs { + err := u.SendFunc(u.SendBuf, &fMsg) + if err != nil { + return err + } + } + return nil + } else { + return err + } +} + +func (u *udpConn) Close() error { + u.CloseFunc() + return nil +} + +type udpSessionManager struct { + io udpIO + + mutex sync.Mutex + m map[uint32]*udpConn + nextID uint32 +} + +func newUDPSessionManager(io udpIO) *udpSessionManager { + return &udpSessionManager{ + io: io, + m: make(map[uint32]*udpConn), + nextID: 1, + } +} + +// Run runs the session manager main loop. +// Exit and returns error when the underlying io returns error (e.g. closed). +func (m *udpSessionManager) Run() error { + defer m.cleanup() + + for { + msg, err := m.io.ReceiveMessage() + if err != nil { + return err + } + m.feed(msg) + } +} + +func (m *udpSessionManager) cleanup() { + m.mutex.Lock() + defer m.mutex.Unlock() + + for _, conn := range m.m { + m.close(conn) + } +} + +func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { + m.mutex.Lock() + defer m.mutex.Unlock() + + conn, ok := m.m[msg.SessionID] + if !ok { + // Ignore message from unknown session + return + } + + select { + case conn.ReceiveCh <- msg: + // OK + default: + // Channel full, drop the message + } +} + +// NewUDP creates a new UDP session. +func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { + m.mutex.Lock() + defer m.mutex.Unlock() + + id := m.nextID + m.nextID++ + + conn := &udpConn{ + ID: id, + D: &frag.Defragger{}, + ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize), + SendBuf: make([]byte, protocol.MaxUDPSize), + SendFunc: m.io.SendMessage, + } + conn.CloseFunc = func() { + m.mutex.Lock() + defer m.mutex.Unlock() + if !conn.Closed { + m.close(conn) + } + } + m.m[id] = conn + + return conn, nil +} + +func (m *udpSessionManager) close(conn *udpConn) { + conn.Closed = true + close(conn.ReceiveCh) + delete(m.m, conn.ID) +} + +func (m *udpSessionManager) Count() int { + m.mutex.Lock() + defer m.mutex.Unlock() + return len(m.m) +} diff --git a/core/internal/integration_tests/masq_test.go b/core/internal/integration_tests/masq_test.go index b665414..333be94 100644 --- a/core/internal/integration_tests/masq_test.go +++ b/core/internal/integration_tests/masq_test.go @@ -106,25 +106,4 @@ func TestServerMasquerade(t *testing.T) { if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() { t.Fatal("expected timeout, got", err) } - - // Try UDP request - udpStream, err := conn.OpenStream() - if err != nil { - t.Fatal("error opening stream:", err) - } - defer udpStream.Close() - err = protocol.WriteUDPRequest(udpStream) - if err != nil { - t.Fatal("error sending request:", err) - } - - // We should receive nothing - _ = udpStream.SetReadDeadline(time.Now().Add(2 * time.Second)) - n, err = udpStream.Read(buf) - if n != 0 { - t.Fatal("expected no response, got", n) - } - if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() { - t.Fatal("expected timeout, got", err) - } } diff --git a/core/internal/integration_tests/utils_test.go b/core/internal/integration_tests/utils_test.go index ae7e9b4..f87dc24 100644 --- a/core/internal/integration_tests/utils_test.go +++ b/core/internal/integration_tests/utils_test.go @@ -287,7 +287,7 @@ func (l *channelEventLogger) TCPError(addr net.Addr, id, reqAddr string, err err } } -func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) { +func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) { if l.UDPRequestEventCh != nil { l.UDPRequestEventCh <- udpRequestEvent{ Addr: addr, diff --git a/core/server/udp.go b/core/server/udp.go index ae98b80..892dd99 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -111,9 +111,8 @@ type udpSessionManager struct { eventLogger udpEventLogger idleTimeout time.Duration - mutex sync.Mutex - m map[uint32]*udpSessionEntry - nextID uint32 + mutex sync.Mutex + m map[uint32]*udpSessionEntry } func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager { @@ -212,3 +211,9 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { // as some are temporary (e.g. invalid address) _, _ = entry.Feed(msg) } + +func (m *udpSessionManager) Count() int { + m.mutex.Lock() + defer m.mutex.Unlock() + return len(m.m) +} diff --git a/core/server/udp_test.go b/core/server/udp_test.go index 9d46eba..880bbad 100644 --- a/core/server/udp_test.go +++ b/core/server/udp_test.go @@ -10,6 +10,11 @@ import ( "go.uber.org/goleak" ) +var ( + errUDPBlocked = errors.New("blocked") + errUDPClosed = errors.New("closed") +) + type echoUDPConnPkt struct { Data []byte Addr string @@ -23,7 +28,7 @@ type echoUDPConn struct { func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) { pkt := <-c.PktCh if pkt.Close { - return 0, "", errors.New("closed") + return 0, "", errUDPClosed } n := copy(b, pkt.Data) return n, pkt.Addr, nil @@ -49,12 +54,14 @@ func (c *echoUDPConn) Close() error { type udpMockIO struct { ReceiveCh <-chan *protocol.UDPMessage SendCh chan<- *protocol.UDPMessage + UDPClose bool // ReadFrom() returns error immediately + BlockUDP bool // Block UDP connection creation } func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) { m := <-io.ReceiveCh if m == nil { - return nil, errors.New("closed") + return nil, errUDPClosed } return m, nil } @@ -68,9 +75,18 @@ func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error { } func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) { - return &echoUDPConn{ + if io.BlockUDP { + return nil, errUDPBlocked + } + conn := &echoUDPConn{ PktCh: make(chan echoUDPConnPkt, 10), - }, nil + } + if io.UDPClose { + conn.PktCh <- echoUDPConnPkt{ + Close: true, + } + } + return conn, nil } type udpMockEventNew struct { @@ -112,80 +128,131 @@ func TestUDPSessionManager(t *testing.T) { sm := newUDPSessionManager(io, eventLogger, 2*time.Second) go sm.Run() - ms := []*protocol.UDPMessage{ - { - SessionID: 1234, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: "example.com:5353", - Data: []byte("hello"), - }, - { - SessionID: 5678, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: "example.com:9999", - Data: []byte("goodbye"), - }, - { - SessionID: 1234, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: "example.com:5353", - Data: []byte(" world"), - }, - { - SessionID: 5678, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: "example.com:9999", - Data: []byte(" girl"), - }, - } - for _, m := range ms { - msgReceiveCh <- m - } - // New event order should be consistent - newEvent := <-eventNewCh - if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" { - t.Error("unexpected new event value") - } - newEvent = <-eventNewCh - if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" { - t.Error("unexpected new event value") - } - // Message order is not guaranteed - msgMap := make(map[string]bool) - for i := 0; i < 4; i++ { - msg := <-msgSendCh - msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true - } - if !(msgMap["1234:example.com:5353:hello"] && - msgMap["5678:example.com:9999:goodbye"] && - msgMap["1234:example.com:5353: world"] && - msgMap["5678:example.com:9999: girl"]) { - t.Error("unexpected message value") - } - // Timeout check - startTime := time.Now() - closeMap := make(map[uint32]bool) - for i := 0; i < 2; i++ { - closeEvent := <-eventCloseCh - closeMap[closeEvent.SessionID] = true - } - if !(closeMap[1234] && closeMap[5678]) { - t.Error("unexpected close event value") - } - if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second { - t.Error("unexpected timeout duration") - } + t.Run("session creation & timeout", func(t *testing.T) { + ms := []*protocol.UDPMessage{ + { + SessionID: 1234, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:5353", + Data: []byte("hello"), + }, + { + SessionID: 5678, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:9999", + Data: []byte("goodbye"), + }, + { + SessionID: 1234, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:5353", + Data: []byte(" world"), + }, + { + SessionID: 5678, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:9999", + Data: []byte(" girl"), + }, + } + for _, m := range ms { + msgReceiveCh <- m + } + // New event order should be consistent + newEvent := <-eventNewCh + if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" { + t.Error("unexpected new event value") + } + newEvent = <-eventNewCh + if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" { + t.Error("unexpected new event value") + } + // Message order is not guaranteed + msgMap := make(map[string]bool) + for i := 0; i < 4; i++ { + msg := <-msgSendCh + msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true + } + if !(msgMap["1234:example.com:5353:hello"] && + msgMap["5678:example.com:9999:goodbye"] && + msgMap["1234:example.com:5353: world"] && + msgMap["5678:example.com:9999: girl"]) { + t.Error("unexpected message value") + } + // Timeout check + startTime := time.Now() + closeMap := make(map[uint32]bool) + for i := 0; i < 2; i++ { + closeEvent := <-eventCloseCh + closeMap[closeEvent.SessionID] = true + } + if !(closeMap[1234] && closeMap[5678]) { + t.Error("unexpected close event value") + } + if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second { + t.Error("unexpected timeout duration") + } + }) - // Goroutine leak check + t.Run("UDP connection close", func(t *testing.T) { + // Close UDP connection immediately after creation + io.UDPClose = true + + msgReceiveCh <- &protocol.UDPMessage{ + SessionID: 8888, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "mygod.org:1514", + Data: []byte("goodnight"), + } + // Should have both new and close events immediately + newEvent := <-eventNewCh + if newEvent.SessionID != 8888 || newEvent.ReqAddr != "mygod.org:1514" { + t.Error("unexpected new event value") + } + closeEvent := <-eventCloseCh + if closeEvent.SessionID != 8888 || closeEvent.Err != errUDPClosed { + t.Error("unexpected close event value") + } + }) + + t.Run("UDP IO failure", func(t *testing.T) { + // Block UDP connection creation + io.BlockUDP = true + + msgReceiveCh <- &protocol.UDPMessage{ + SessionID: 9999, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "xxx.net:12450", + Data: []byte("nope"), + } + // Should have both new and close events immediately + newEvent := <-eventNewCh + if newEvent.SessionID != 9999 || newEvent.ReqAddr != "xxx.net:12450" { + t.Error("unexpected new event value") + } + closeEvent := <-eventCloseCh + if closeEvent.SessionID != 9999 || closeEvent.Err != errUDPBlocked { + t.Error("unexpected close event value") + } + }) + + // Leak checks msgReceiveCh <- nil - time.Sleep(1 * time.Second) // Wait for internal routines to exit + time.Sleep(1 * time.Second) // Give some time for the goroutines to exit + if sm.Count() != 0 { + t.Error("session count should be 0") + } goleak.VerifyNone(t) }