feat(wip): udp rework client side

This commit is contained in:
Toby
2023-07-24 16:32:25 -07:00
parent f142a24047
commit cbedb27f0f
7 changed files with 391 additions and 391 deletions

View File

@@ -3,18 +3,13 @@ package client
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"io"
"math/rand"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
coreErrs "github.com/apernet/hysteria/core/errors" coreErrs "github.com/apernet/hysteria/core/errors"
"github.com/apernet/hysteria/core/internal/congestion" "github.com/apernet/hysteria/core/internal/congestion"
"github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/protocol"
"github.com/apernet/hysteria/core/internal/utils" "github.com/apernet/hysteria/core/internal/utils"
@@ -23,8 +18,6 @@ import (
) )
const ( const (
udpMessageChanSize = 1024
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
) )
@@ -48,94 +41,25 @@ func NewClient(config *Config) (Client, error) {
c := &clientImpl{ c := &clientImpl{
config: config, config: config,
} }
c.conn = &autoReconnectConn{ if err := c.connect(); err != nil {
Connect: c.connect, return nil, err
} }
return c, nil return c, nil
} }
type clientImpl struct { type clientImpl struct {
config *Config config *Config
conn *autoReconnectConn
udpSM udpSessionManager pktConn net.PacketConn
conn quic.Connection
udpSM *udpSessionManager
} }
type udpSessionEntry struct { func (c *clientImpl) connect() error {
Ch chan *protocol.UDPMessage
D *frag.Defragger
Closed bool
}
type udpSessionManager struct {
mutex sync.RWMutex
m map[uint32]*udpSessionEntry
}
func (m *udpSessionManager) Init() {
m.mutex.Lock()
defer m.mutex.Unlock()
m.m = make(map[uint32]*udpSessionEntry)
}
// Add returns both a channel for receiving messages and a function to close the channel & delete the session.
func (m *udpSessionManager) Add(id uint32) (<-chan *protocol.UDPMessage, func()) {
m.mutex.Lock()
defer m.mutex.Unlock()
// Important: make sure we add and delete the channel in the same map,
// as the map may be replaced by Init() at any time.
currentM := m.m
entry := &udpSessionEntry{
Ch: make(chan *protocol.UDPMessage, udpMessageChanSize),
D: &frag.Defragger{},
Closed: false,
}
currentM[id] = entry
return entry.Ch, func() {
m.mutex.Lock()
defer m.mutex.Unlock()
if entry.Closed {
// Double close a channel will panic,
// so we need a flag to make sure we only close it once.
return
}
entry.Closed = true
close(entry.Ch)
delete(currentM, id)
}
}
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) {
m.mutex.RLock()
defer m.mutex.RUnlock()
entry, ok := m.m[msg.SessionID]
if !ok {
// No such session, drop the message
return
}
dfMsg := entry.D.Feed(msg)
if dfMsg == nil {
// Not a complete message yet
return
}
select {
case entry.Ch <- dfMsg:
// OK
default:
// Channel is full, drop the message
}
}
func (c *clientImpl) connect() (quic.Connection, func(), error) {
// Use a new packet conn for each connection,
// remember to close it after the QUIC connection is closed.
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr) pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
if err != nil { if err != nil {
return nil, nil, err return err
} }
// Convert config to TLS config & QUIC config // Convert config to TLS config & QUIC config
tlsConfig := &tls.Config{ tlsConfig := &tls.Config{
@@ -185,15 +109,15 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
_ = conn.CloseWithError(closeErrCodeProtocolError, "") _ = conn.CloseWithError(closeErrCodeProtocolError, "")
} }
_ = pktConn.Close() _ = pktConn.Close()
return nil, nil, &coreErrs.ConnectError{Err: err} return &coreErrs.ConnectError{Err: err}
} }
if resp.StatusCode != protocol.StatusAuthOK { if resp.StatusCode != protocol.StatusAuthOK {
_ = conn.CloseWithError(closeErrCodeProtocolError, "") _ = conn.CloseWithError(closeErrCodeProtocolError, "")
_ = pktConn.Close() _ = pktConn.Close()
return nil, nil, &coreErrs.AuthError{StatusCode: resp.StatusCode} return &coreErrs.AuthError{StatusCode: resp.StatusCode}
} }
// Auth OK // Auth OK
serverRx := protocol.AuthResponseDataFromHeader(resp.Header) udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header)
// actualTx = min(serverRx, clientTx) // actualTx = min(serverRx, clientTx)
actualTx := serverRx actualTx := serverRx
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx { if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
@@ -205,46 +129,20 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
} }
_ = resp.Body.Close() _ = resp.Body.Close()
c.udpSM.Init() c.pktConn = pktConn
go c.udpLoop(conn) c.conn = conn
if udpEnabled {
return conn, func() { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
_ = conn.CloseWithError(closeErrCodeOK, "") go func() {
_ = pktConn.Close() c.udpSM.Run()
}, nil // TODO: Mark connection as closed
} }()
func (c *clientImpl) udpLoop(conn quic.Connection) {
for {
msg, err := conn.ReceiveMessage()
if err != nil {
return
}
c.handleUDPMessage(msg)
} }
} return nil
// client <- remote direction
func (c *clientImpl) handleUDPMessage(msg []byte) {
udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
return
}
c.udpSM.Feed(udpMsg)
}
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Connection, quic.Stream, error) {
qc, stream, err := c.conn.OpenStream()
if err != nil {
return nil, nil, err
}
return qc, &utils.QStream{Stream: stream}, nil
} }
func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
qc, stream, err := c.openStream() stream, err := c.openStream()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -260,8 +158,8 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
// to the first Read() call. // to the first Read() call.
return &tcpConn{ return &tcpConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: qc.LocalAddr(), PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: qc.RemoteAddr(), PseudoRemoteAddr: c.conn.RemoteAddr(),
Established: false, Established: false,
}, nil }, nil
} }
@@ -277,49 +175,23 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
} }
return &tcpConn{ return &tcpConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: qc.LocalAddr(), PseudoLocalAddr: c.conn.LocalAddr(),
PseudoRemoteAddr: qc.RemoteAddr(), PseudoRemoteAddr: c.conn.RemoteAddr(),
Established: true, Established: true,
}, nil }, nil
} }
func (c *clientImpl) ListenUDP() (HyUDPConn, error) { func (c *clientImpl) ListenUDP() (HyUDPConn, error) {
qc, stream, err := c.openStream() if c.udpSM == nil {
if err != nil { return nil, coreErrs.DialError{Message: "UDP not enabled"}
return nil, err
} }
// Send request return c.udpSM.NewUDP()
err = protocol.WriteUDPRequest(stream)
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
ok, sessionID, msg, err := protocol.ReadUDPResponse(stream)
if err != nil {
_ = stream.Close()
return nil, err
}
if !ok {
_ = stream.Close()
return nil, coreErrs.DialError{Message: msg}
}
ch, closeFunc := c.udpSM.Add(sessionID)
uc := &udpConn{
QC: qc,
Stream: stream,
SessionID: sessionID,
Ch: ch,
CloseFunc: closeFunc,
SendBuf: make([]byte, protocol.MaxUDPSize),
}
go uc.Hold()
return uc, nil
} }
func (c *clientImpl) Close() error { func (c *clientImpl) Close() error {
return c.conn.Close() _ = c.conn.CloseWithError(closeErrCodeOK, "")
_ = c.pktConn.Close()
return nil
} }
type tcpConn struct { type tcpConn struct {
@@ -372,72 +244,40 @@ func (c *tcpConn) SetWriteDeadline(t time.Time) error {
return c.Orig.SetWriteDeadline(t) return c.Orig.SetWriteDeadline(t)
} }
type udpConn struct { type udpIOImpl struct {
QC quic.Connection Conn quic.Connection
Stream quic.Stream
SessionID uint32
Ch <-chan *protocol.UDPMessage
CloseFunc func()
SendBuf []byte
} }
func (c *udpConn) Hold() { func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
// Hold (drain) the stream until someone closes it. for {
// Closing the stream is the signal to stop the UDP session. msg, err := io.Conn.ReceiveMessage()
_, _ = io.Copy(io.Discard, c.Stream) if err != nil {
_ = c.Close() // Connection error, this will stop the session manager
} return nil, err
func (c *udpConn) Receive() ([]byte, string, error) {
msg := <-c.Ch
if msg == nil {
// Closed
return nil, "", io.EOF
}
return msg.Data, msg.Addr, nil
}
// Send is not thread-safe as it uses a shared send buffer for now.
func (c *udpConn) Send(data []byte, addr string) error {
// Try no frag first
msg := &protocol.UDPMessage{
SessionID: c.SessionID,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: addr,
Data: data,
}
n := msg.Serialize(c.SendBuf)
if n < 0 {
// Message even larger than MaxUDPSize, drop it
// Maybe we should return an error in the future?
return nil
}
sendErr := c.QC.SendMessage(c.SendBuf[:n])
if sendErr == nil {
// All good
return nil
}
var errTooLarge quic.ErrMessageTooLarge
if errors.As(sendErr, &errTooLarge) {
// Message too large, try fragmentation
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
for _, fMsg := range fMsgs {
n = fMsg.Serialize(c.SendBuf)
err := c.QC.SendMessage(c.SendBuf[:n])
if err != nil {
return err
}
} }
return nil udpMsg, err := protocol.ParseUDPMessage(msg)
if err != nil {
// Invalid message, this is fine - just wait for the next
continue
}
return udpMsg, nil
} }
// Other error
return sendErr
} }
func (c *udpConn) Close() error { func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
c.CloseFunc() msgN := msg.Serialize(buf)
return c.Stream.Close() if msgN < 0 {
// Message larger than buffer, silent drop
return nil
}
return io.Conn.SendMessage(buf[:msgN])
}
// openStream wraps the stream with QStream, which handles Close() properly
func (c *clientImpl) openStream() (quic.Stream, error) {
stream, err := c.conn.OpenStream()
if err != nil {
return nil, err
}
return &utils.QStream{Stream: stream}, nil
} }

View File

@@ -1,68 +0,0 @@
package client
import (
"net"
"sync"
"github.com/quic-go/quic-go"
)
// autoReconnectConn is a wrapper of quic.Connection that automatically reconnects
// when a non-temporary error (usually a timeout) occurs.
type autoReconnectConn struct {
// Connect is called whenever a new QUIC connection is needed.
// It should return a new QUIC connection, a function to close the connection
// (and potentially other underlying resources), and an error if one occurred.
Connect func() (quic.Connection, func(), error)
conn quic.Connection
closeFunc func()
connMutex sync.RWMutex
}
func (c *autoReconnectConn) OpenStream() (quic.Connection, quic.Stream, error) {
c.connMutex.Lock()
defer c.connMutex.Unlock()
// First time?
if c.conn == nil {
conn, closeFunc, err := c.Connect()
if err != nil {
return nil, nil, err
}
c.conn = conn
c.closeFunc = closeFunc
}
stream, err := c.conn.OpenStream()
if err == nil {
// All is good
return c.conn, stream, nil
} else if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just pass the error to the caller
return nil, nil, err
} else {
// Permanent error
// Close the previous connection,
// reconnect and try again (only once)
c.closeFunc()
conn, closeFunc, err := c.Connect()
if err != nil {
return nil, nil, err
}
c.conn = conn
c.closeFunc = closeFunc
stream, err = c.conn.OpenStream()
return c.conn, stream, err
}
}
func (c *autoReconnectConn) Close() error {
c.connMutex.Lock()
defer c.connMutex.Unlock()
if c.conn == nil {
return nil
}
c.closeFunc()
c.conn = nil
c.closeFunc = nil
return nil
}

177
core/client/udp.go Normal file
View File

@@ -0,0 +1,177 @@
package client
import (
"errors"
"io"
"math/rand"
"sync"
"github.com/apernet/hysteria/core/internal/frag"
"github.com/apernet/hysteria/core/internal/protocol"
"github.com/quic-go/quic-go"
)
const (
udpMessageChanSize = 1024
)
type udpIO interface {
ReceiveMessage() (*protocol.UDPMessage, error)
SendMessage([]byte, *protocol.UDPMessage) error
}
type udpConn struct {
ID uint32
D *frag.Defragger
ReceiveCh chan *protocol.UDPMessage
SendBuf []byte
SendFunc func([]byte, *protocol.UDPMessage) error
CloseFunc func()
Closed bool
}
func (u *udpConn) Receive() ([]byte, string, error) {
for {
msg := <-u.ReceiveCh
if msg == nil {
// Closed
return nil, "", io.EOF
}
dfMsg := u.D.Feed(msg)
if dfMsg == nil {
// Incomplete message, wait for more
continue
}
return dfMsg.Data, dfMsg.Addr, nil
}
}
// Send is not thread-safe, as it uses a shared SendBuf.
func (u *udpConn) Send(data []byte, addr string) error {
// Try no frag first
msg := &protocol.UDPMessage{
SessionID: u.ID,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: addr,
Data: data,
}
err := u.SendFunc(u.SendBuf, msg)
var errTooLarge quic.ErrMessageTooLarge
if errors.As(err, &errTooLarge) {
// Message too large, try fragmentation
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
for _, fMsg := range fMsgs {
err := u.SendFunc(u.SendBuf, &fMsg)
if err != nil {
return err
}
}
return nil
} else {
return err
}
}
func (u *udpConn) Close() error {
u.CloseFunc()
return nil
}
type udpSessionManager struct {
io udpIO
mutex sync.Mutex
m map[uint32]*udpConn
nextID uint32
}
func newUDPSessionManager(io udpIO) *udpSessionManager {
return &udpSessionManager{
io: io,
m: make(map[uint32]*udpConn),
nextID: 1,
}
}
// Run runs the session manager main loop.
// Exit and returns error when the underlying io returns error (e.g. closed).
func (m *udpSessionManager) Run() error {
defer m.cleanup()
for {
msg, err := m.io.ReceiveMessage()
if err != nil {
return err
}
m.feed(msg)
}
}
func (m *udpSessionManager) cleanup() {
m.mutex.Lock()
defer m.mutex.Unlock()
for _, conn := range m.m {
m.close(conn)
}
}
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
m.mutex.Lock()
defer m.mutex.Unlock()
conn, ok := m.m[msg.SessionID]
if !ok {
// Ignore message from unknown session
return
}
select {
case conn.ReceiveCh <- msg:
// OK
default:
// Channel full, drop the message
}
}
// NewUDP creates a new UDP session.
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
m.mutex.Lock()
defer m.mutex.Unlock()
id := m.nextID
m.nextID++
conn := &udpConn{
ID: id,
D: &frag.Defragger{},
ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
SendBuf: make([]byte, protocol.MaxUDPSize),
SendFunc: m.io.SendMessage,
}
conn.CloseFunc = func() {
m.mutex.Lock()
defer m.mutex.Unlock()
if !conn.Closed {
m.close(conn)
}
}
m.m[id] = conn
return conn, nil
}
func (m *udpSessionManager) close(conn *udpConn) {
conn.Closed = true
close(conn.ReceiveCh)
delete(m.m, conn.ID)
}
func (m *udpSessionManager) Count() int {
m.mutex.Lock()
defer m.mutex.Unlock()
return len(m.m)
}

View File

@@ -106,25 +106,4 @@ func TestServerMasquerade(t *testing.T) {
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() { if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
t.Fatal("expected timeout, got", err) t.Fatal("expected timeout, got", err)
} }
// Try UDP request
udpStream, err := conn.OpenStream()
if err != nil {
t.Fatal("error opening stream:", err)
}
defer udpStream.Close()
err = protocol.WriteUDPRequest(udpStream)
if err != nil {
t.Fatal("error sending request:", err)
}
// We should receive nothing
_ = udpStream.SetReadDeadline(time.Now().Add(2 * time.Second))
n, err = udpStream.Read(buf)
if n != 0 {
t.Fatal("expected no response, got", n)
}
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
t.Fatal("expected timeout, got", err)
}
} }

View File

@@ -287,7 +287,7 @@ func (l *channelEventLogger) TCPError(addr net.Addr, id, reqAddr string, err err
} }
} }
func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) { func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) {
if l.UDPRequestEventCh != nil { if l.UDPRequestEventCh != nil {
l.UDPRequestEventCh <- udpRequestEvent{ l.UDPRequestEventCh <- udpRequestEvent{
Addr: addr, Addr: addr,

View File

@@ -111,9 +111,8 @@ type udpSessionManager struct {
eventLogger udpEventLogger eventLogger udpEventLogger
idleTimeout time.Duration idleTimeout time.Duration
mutex sync.Mutex mutex sync.Mutex
m map[uint32]*udpSessionEntry m map[uint32]*udpSessionEntry
nextID uint32
} }
func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager { func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager {
@@ -212,3 +211,9 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
// as some are temporary (e.g. invalid address) // as some are temporary (e.g. invalid address)
_, _ = entry.Feed(msg) _, _ = entry.Feed(msg)
} }
func (m *udpSessionManager) Count() int {
m.mutex.Lock()
defer m.mutex.Unlock()
return len(m.m)
}

View File

@@ -10,6 +10,11 @@ import (
"go.uber.org/goleak" "go.uber.org/goleak"
) )
var (
errUDPBlocked = errors.New("blocked")
errUDPClosed = errors.New("closed")
)
type echoUDPConnPkt struct { type echoUDPConnPkt struct {
Data []byte Data []byte
Addr string Addr string
@@ -23,7 +28,7 @@ type echoUDPConn struct {
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) { func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
pkt := <-c.PktCh pkt := <-c.PktCh
if pkt.Close { if pkt.Close {
return 0, "", errors.New("closed") return 0, "", errUDPClosed
} }
n := copy(b, pkt.Data) n := copy(b, pkt.Data)
return n, pkt.Addr, nil return n, pkt.Addr, nil
@@ -49,12 +54,14 @@ func (c *echoUDPConn) Close() error {
type udpMockIO struct { type udpMockIO struct {
ReceiveCh <-chan *protocol.UDPMessage ReceiveCh <-chan *protocol.UDPMessage
SendCh chan<- *protocol.UDPMessage SendCh chan<- *protocol.UDPMessage
UDPClose bool // ReadFrom() returns error immediately
BlockUDP bool // Block UDP connection creation
} }
func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) { func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
m := <-io.ReceiveCh m := <-io.ReceiveCh
if m == nil { if m == nil {
return nil, errors.New("closed") return nil, errUDPClosed
} }
return m, nil return m, nil
} }
@@ -68,9 +75,18 @@ func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
} }
func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) { func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) {
return &echoUDPConn{ if io.BlockUDP {
return nil, errUDPBlocked
}
conn := &echoUDPConn{
PktCh: make(chan echoUDPConnPkt, 10), PktCh: make(chan echoUDPConnPkt, 10),
}, nil }
if io.UDPClose {
conn.PktCh <- echoUDPConnPkt{
Close: true,
}
}
return conn, nil
} }
type udpMockEventNew struct { type udpMockEventNew struct {
@@ -112,80 +128,131 @@ func TestUDPSessionManager(t *testing.T) {
sm := newUDPSessionManager(io, eventLogger, 2*time.Second) sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
go sm.Run() go sm.Run()
ms := []*protocol.UDPMessage{ t.Run("session creation & timeout", func(t *testing.T) {
{ ms := []*protocol.UDPMessage{
SessionID: 1234, {
PacketID: 0, SessionID: 1234,
FragID: 0, PacketID: 0,
FragCount: 1, FragID: 0,
Addr: "example.com:5353", FragCount: 1,
Data: []byte("hello"), Addr: "example.com:5353",
}, Data: []byte("hello"),
{ },
SessionID: 5678, {
PacketID: 0, SessionID: 5678,
FragID: 0, PacketID: 0,
FragCount: 1, FragID: 0,
Addr: "example.com:9999", FragCount: 1,
Data: []byte("goodbye"), Addr: "example.com:9999",
}, Data: []byte("goodbye"),
{ },
SessionID: 1234, {
PacketID: 0, SessionID: 1234,
FragID: 0, PacketID: 0,
FragCount: 1, FragID: 0,
Addr: "example.com:5353", FragCount: 1,
Data: []byte(" world"), Addr: "example.com:5353",
}, Data: []byte(" world"),
{ },
SessionID: 5678, {
PacketID: 0, SessionID: 5678,
FragID: 0, PacketID: 0,
FragCount: 1, FragID: 0,
Addr: "example.com:9999", FragCount: 1,
Data: []byte(" girl"), Addr: "example.com:9999",
}, Data: []byte(" girl"),
} },
for _, m := range ms { }
msgReceiveCh <- m for _, m := range ms {
} msgReceiveCh <- m
// New event order should be consistent }
newEvent := <-eventNewCh // New event order should be consistent
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" { newEvent := <-eventNewCh
t.Error("unexpected new event value") if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
} t.Error("unexpected new event value")
newEvent = <-eventNewCh }
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" { newEvent = <-eventNewCh
t.Error("unexpected new event value") if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
} t.Error("unexpected new event value")
// Message order is not guaranteed }
msgMap := make(map[string]bool) // Message order is not guaranteed
for i := 0; i < 4; i++ { msgMap := make(map[string]bool)
msg := <-msgSendCh for i := 0; i < 4; i++ {
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true msg := <-msgSendCh
} msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
if !(msgMap["1234:example.com:5353:hello"] && }
msgMap["5678:example.com:9999:goodbye"] && if !(msgMap["1234:example.com:5353:hello"] &&
msgMap["1234:example.com:5353: world"] && msgMap["5678:example.com:9999:goodbye"] &&
msgMap["5678:example.com:9999: girl"]) { msgMap["1234:example.com:5353: world"] &&
t.Error("unexpected message value") msgMap["5678:example.com:9999: girl"]) {
} t.Error("unexpected message value")
// Timeout check }
startTime := time.Now() // Timeout check
closeMap := make(map[uint32]bool) startTime := time.Now()
for i := 0; i < 2; i++ { closeMap := make(map[uint32]bool)
closeEvent := <-eventCloseCh for i := 0; i < 2; i++ {
closeMap[closeEvent.SessionID] = true closeEvent := <-eventCloseCh
} closeMap[closeEvent.SessionID] = true
if !(closeMap[1234] && closeMap[5678]) { }
t.Error("unexpected close event value") if !(closeMap[1234] && closeMap[5678]) {
} t.Error("unexpected close event value")
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second { }
t.Error("unexpected timeout duration") if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
} t.Error("unexpected timeout duration")
}
})
// Goroutine leak check t.Run("UDP connection close", func(t *testing.T) {
// Close UDP connection immediately after creation
io.UDPClose = true
msgReceiveCh <- &protocol.UDPMessage{
SessionID: 8888,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "mygod.org:1514",
Data: []byte("goodnight"),
}
// Should have both new and close events immediately
newEvent := <-eventNewCh
if newEvent.SessionID != 8888 || newEvent.ReqAddr != "mygod.org:1514" {
t.Error("unexpected new event value")
}
closeEvent := <-eventCloseCh
if closeEvent.SessionID != 8888 || closeEvent.Err != errUDPClosed {
t.Error("unexpected close event value")
}
})
t.Run("UDP IO failure", func(t *testing.T) {
// Block UDP connection creation
io.BlockUDP = true
msgReceiveCh <- &protocol.UDPMessage{
SessionID: 9999,
PacketID: 0,
FragID: 0,
FragCount: 1,
Addr: "xxx.net:12450",
Data: []byte("nope"),
}
// Should have both new and close events immediately
newEvent := <-eventNewCh
if newEvent.SessionID != 9999 || newEvent.ReqAddr != "xxx.net:12450" {
t.Error("unexpected new event value")
}
closeEvent := <-eventCloseCh
if closeEvent.SessionID != 9999 || closeEvent.Err != errUDPBlocked {
t.Error("unexpected close event value")
}
})
// Leak checks
msgReceiveCh <- nil msgReceiveCh <- nil
time.Sleep(1 * time.Second) // Wait for internal routines to exit time.Sleep(1 * time.Second) // Give some time for the goroutines to exit
if sm.Count() != 0 {
t.Error("session count should be 0")
}
goleak.VerifyNone(t) goleak.VerifyNone(t)
} }