mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-07-29 05:34:28 +00:00
feat(wip): udp rework client side
This commit is contained in:
@@ -3,18 +3,13 @@ package client
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
|
||||||
"io"
|
|
||||||
"math/rand"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
coreErrs "github.com/apernet/hysteria/core/errors"
|
coreErrs "github.com/apernet/hysteria/core/errors"
|
||||||
"github.com/apernet/hysteria/core/internal/congestion"
|
"github.com/apernet/hysteria/core/internal/congestion"
|
||||||
"github.com/apernet/hysteria/core/internal/frag"
|
|
||||||
"github.com/apernet/hysteria/core/internal/protocol"
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
"github.com/apernet/hysteria/core/internal/utils"
|
"github.com/apernet/hysteria/core/internal/utils"
|
||||||
|
|
||||||
@@ -23,8 +18,6 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
udpMessageChanSize = 1024
|
|
||||||
|
|
||||||
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError
|
||||||
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
|
closeErrCodeProtocolError = 0x101 // HTTP3 ErrCodeGeneralProtocolError
|
||||||
)
|
)
|
||||||
@@ -48,94 +41,25 @@ func NewClient(config *Config) (Client, error) {
|
|||||||
c := &clientImpl{
|
c := &clientImpl{
|
||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
c.conn = &autoReconnectConn{
|
if err := c.connect(); err != nil {
|
||||||
Connect: c.connect,
|
return nil, err
|
||||||
}
|
}
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientImpl struct {
|
type clientImpl struct {
|
||||||
config *Config
|
config *Config
|
||||||
conn *autoReconnectConn
|
|
||||||
|
|
||||||
udpSM udpSessionManager
|
pktConn net.PacketConn
|
||||||
|
conn quic.Connection
|
||||||
|
|
||||||
|
udpSM *udpSessionManager
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpSessionEntry struct {
|
func (c *clientImpl) connect() error {
|
||||||
Ch chan *protocol.UDPMessage
|
|
||||||
D *frag.Defragger
|
|
||||||
Closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type udpSessionManager struct {
|
|
||||||
mutex sync.RWMutex
|
|
||||||
m map[uint32]*udpSessionEntry
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpSessionManager) Init() {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
m.m = make(map[uint32]*udpSessionEntry)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add returns both a channel for receiving messages and a function to close the channel & delete the session.
|
|
||||||
func (m *udpSessionManager) Add(id uint32) (<-chan *protocol.UDPMessage, func()) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
// Important: make sure we add and delete the channel in the same map,
|
|
||||||
// as the map may be replaced by Init() at any time.
|
|
||||||
currentM := m.m
|
|
||||||
|
|
||||||
entry := &udpSessionEntry{
|
|
||||||
Ch: make(chan *protocol.UDPMessage, udpMessageChanSize),
|
|
||||||
D: &frag.Defragger{},
|
|
||||||
Closed: false,
|
|
||||||
}
|
|
||||||
currentM[id] = entry
|
|
||||||
|
|
||||||
return entry.Ch, func() {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
if entry.Closed {
|
|
||||||
// Double close a channel will panic,
|
|
||||||
// so we need a flag to make sure we only close it once.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
entry.Closed = true
|
|
||||||
close(entry.Ch)
|
|
||||||
delete(currentM, id)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
entry, ok := m.m[msg.SessionID]
|
|
||||||
if !ok {
|
|
||||||
// No such session, drop the message
|
|
||||||
return
|
|
||||||
}
|
|
||||||
dfMsg := entry.D.Feed(msg)
|
|
||||||
if dfMsg == nil {
|
|
||||||
// Not a complete message yet
|
|
||||||
return
|
|
||||||
}
|
|
||||||
select {
|
|
||||||
case entry.Ch <- dfMsg:
|
|
||||||
// OK
|
|
||||||
default:
|
|
||||||
// Channel is full, drop the message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
|
||||||
// Use a new packet conn for each connection,
|
|
||||||
// remember to close it after the QUIC connection is closed.
|
|
||||||
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
|
pktConn, err := c.config.ConnFactory.New(c.config.ServerAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return err
|
||||||
}
|
}
|
||||||
// Convert config to TLS config & QUIC config
|
// Convert config to TLS config & QUIC config
|
||||||
tlsConfig := &tls.Config{
|
tlsConfig := &tls.Config{
|
||||||
@@ -185,15 +109,15 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
|||||||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||||
}
|
}
|
||||||
_ = pktConn.Close()
|
_ = pktConn.Close()
|
||||||
return nil, nil, &coreErrs.ConnectError{Err: err}
|
return &coreErrs.ConnectError{Err: err}
|
||||||
}
|
}
|
||||||
if resp.StatusCode != protocol.StatusAuthOK {
|
if resp.StatusCode != protocol.StatusAuthOK {
|
||||||
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
_ = conn.CloseWithError(closeErrCodeProtocolError, "")
|
||||||
_ = pktConn.Close()
|
_ = pktConn.Close()
|
||||||
return nil, nil, &coreErrs.AuthError{StatusCode: resp.StatusCode}
|
return &coreErrs.AuthError{StatusCode: resp.StatusCode}
|
||||||
}
|
}
|
||||||
// Auth OK
|
// Auth OK
|
||||||
serverRx := protocol.AuthResponseDataFromHeader(resp.Header)
|
udpEnabled, serverRx := protocol.AuthResponseDataFromHeader(resp.Header)
|
||||||
// actualTx = min(serverRx, clientTx)
|
// actualTx = min(serverRx, clientTx)
|
||||||
actualTx := serverRx
|
actualTx := serverRx
|
||||||
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
|
if actualTx == 0 || actualTx > c.config.BandwidthConfig.MaxTx {
|
||||||
@@ -205,46 +129,20 @@ func (c *clientImpl) connect() (quic.Connection, func(), error) {
|
|||||||
}
|
}
|
||||||
_ = resp.Body.Close()
|
_ = resp.Body.Close()
|
||||||
|
|
||||||
c.udpSM.Init()
|
c.pktConn = pktConn
|
||||||
go c.udpLoop(conn)
|
c.conn = conn
|
||||||
|
if udpEnabled {
|
||||||
return conn, func() {
|
c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn})
|
||||||
_ = conn.CloseWithError(closeErrCodeOK, "")
|
go func() {
|
||||||
_ = pktConn.Close()
|
c.udpSM.Run()
|
||||||
}, nil
|
// TODO: Mark connection as closed
|
||||||
}
|
}()
|
||||||
|
|
||||||
func (c *clientImpl) udpLoop(conn quic.Connection) {
|
|
||||||
for {
|
|
||||||
msg, err := conn.ReceiveMessage()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.handleUDPMessage(msg)
|
|
||||||
}
|
}
|
||||||
}
|
return nil
|
||||||
|
|
||||||
// client <- remote direction
|
|
||||||
func (c *clientImpl) handleUDPMessage(msg []byte) {
|
|
||||||
udpMsg, err := protocol.ParseUDPMessage(msg)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
c.udpSM.Feed(udpMsg)
|
|
||||||
}
|
|
||||||
|
|
||||||
// openStream wraps the stream with QStream, which handles Close() properly
|
|
||||||
func (c *clientImpl) openStream() (quic.Connection, quic.Stream, error) {
|
|
||||||
qc, stream, err := c.conn.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return qc, &utils.QStream{Stream: stream}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
||||||
qc, stream, err := c.openStream()
|
stream, err := c.openStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -260,8 +158,8 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
|||||||
// to the first Read() call.
|
// to the first Read() call.
|
||||||
return &tcpConn{
|
return &tcpConn{
|
||||||
Orig: stream,
|
Orig: stream,
|
||||||
PseudoLocalAddr: qc.LocalAddr(),
|
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||||
PseudoRemoteAddr: qc.RemoteAddr(),
|
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||||
Established: false,
|
Established: false,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -277,49 +175,23 @@ func (c *clientImpl) DialTCP(addr string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
return &tcpConn{
|
return &tcpConn{
|
||||||
Orig: stream,
|
Orig: stream,
|
||||||
PseudoLocalAddr: qc.LocalAddr(),
|
PseudoLocalAddr: c.conn.LocalAddr(),
|
||||||
PseudoRemoteAddr: qc.RemoteAddr(),
|
PseudoRemoteAddr: c.conn.RemoteAddr(),
|
||||||
Established: true,
|
Established: true,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientImpl) ListenUDP() (HyUDPConn, error) {
|
func (c *clientImpl) ListenUDP() (HyUDPConn, error) {
|
||||||
qc, stream, err := c.openStream()
|
if c.udpSM == nil {
|
||||||
if err != nil {
|
return nil, coreErrs.DialError{Message: "UDP not enabled"}
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
// Send request
|
return c.udpSM.NewUDP()
|
||||||
err = protocol.WriteUDPRequest(stream)
|
|
||||||
if err != nil {
|
|
||||||
_ = stream.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Read response
|
|
||||||
ok, sessionID, msg, err := protocol.ReadUDPResponse(stream)
|
|
||||||
if err != nil {
|
|
||||||
_ = stream.Close()
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if !ok {
|
|
||||||
_ = stream.Close()
|
|
||||||
return nil, coreErrs.DialError{Message: msg}
|
|
||||||
}
|
|
||||||
|
|
||||||
ch, closeFunc := c.udpSM.Add(sessionID)
|
|
||||||
uc := &udpConn{
|
|
||||||
QC: qc,
|
|
||||||
Stream: stream,
|
|
||||||
SessionID: sessionID,
|
|
||||||
Ch: ch,
|
|
||||||
CloseFunc: closeFunc,
|
|
||||||
SendBuf: make([]byte, protocol.MaxUDPSize),
|
|
||||||
}
|
|
||||||
go uc.Hold()
|
|
||||||
return uc, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *clientImpl) Close() error {
|
func (c *clientImpl) Close() error {
|
||||||
return c.conn.Close()
|
_ = c.conn.CloseWithError(closeErrCodeOK, "")
|
||||||
|
_ = c.pktConn.Close()
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type tcpConn struct {
|
type tcpConn struct {
|
||||||
@@ -372,72 +244,40 @@ func (c *tcpConn) SetWriteDeadline(t time.Time) error {
|
|||||||
return c.Orig.SetWriteDeadline(t)
|
return c.Orig.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpConn struct {
|
type udpIOImpl struct {
|
||||||
QC quic.Connection
|
Conn quic.Connection
|
||||||
Stream quic.Stream
|
|
||||||
SessionID uint32
|
|
||||||
Ch <-chan *protocol.UDPMessage
|
|
||||||
CloseFunc func()
|
|
||||||
SendBuf []byte
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpConn) Hold() {
|
func (io *udpIOImpl) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||||
// Hold (drain) the stream until someone closes it.
|
for {
|
||||||
// Closing the stream is the signal to stop the UDP session.
|
msg, err := io.Conn.ReceiveMessage()
|
||||||
_, _ = io.Copy(io.Discard, c.Stream)
|
if err != nil {
|
||||||
_ = c.Close()
|
// Connection error, this will stop the session manager
|
||||||
}
|
return nil, err
|
||||||
|
|
||||||
func (c *udpConn) Receive() ([]byte, string, error) {
|
|
||||||
msg := <-c.Ch
|
|
||||||
if msg == nil {
|
|
||||||
// Closed
|
|
||||||
return nil, "", io.EOF
|
|
||||||
}
|
|
||||||
return msg.Data, msg.Addr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send is not thread-safe as it uses a shared send buffer for now.
|
|
||||||
func (c *udpConn) Send(data []byte, addr string) error {
|
|
||||||
// Try no frag first
|
|
||||||
msg := &protocol.UDPMessage{
|
|
||||||
SessionID: c.SessionID,
|
|
||||||
PacketID: 0,
|
|
||||||
FragID: 0,
|
|
||||||
FragCount: 1,
|
|
||||||
Addr: addr,
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
n := msg.Serialize(c.SendBuf)
|
|
||||||
if n < 0 {
|
|
||||||
// Message even larger than MaxUDPSize, drop it
|
|
||||||
// Maybe we should return an error in the future?
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sendErr := c.QC.SendMessage(c.SendBuf[:n])
|
|
||||||
if sendErr == nil {
|
|
||||||
// All good
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
var errTooLarge quic.ErrMessageTooLarge
|
|
||||||
if errors.As(sendErr, &errTooLarge) {
|
|
||||||
// Message too large, try fragmentation
|
|
||||||
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
|
||||||
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
|
||||||
for _, fMsg := range fMsgs {
|
|
||||||
n = fMsg.Serialize(c.SendBuf)
|
|
||||||
err := c.QC.SendMessage(c.SendBuf[:n])
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return nil
|
udpMsg, err := protocol.ParseUDPMessage(msg)
|
||||||
|
if err != nil {
|
||||||
|
// Invalid message, this is fine - just wait for the next
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return udpMsg, nil
|
||||||
}
|
}
|
||||||
// Other error
|
|
||||||
return sendErr
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpConn) Close() error {
|
func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
||||||
c.CloseFunc()
|
msgN := msg.Serialize(buf)
|
||||||
return c.Stream.Close()
|
if msgN < 0 {
|
||||||
|
// Message larger than buffer, silent drop
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return io.Conn.SendMessage(buf[:msgN])
|
||||||
|
}
|
||||||
|
|
||||||
|
// openStream wraps the stream with QStream, which handles Close() properly
|
||||||
|
func (c *clientImpl) openStream() (quic.Stream, error) {
|
||||||
|
stream, err := c.conn.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &utils.QStream{Stream: stream}, nil
|
||||||
}
|
}
|
||||||
|
@@ -1,68 +0,0 @@
|
|||||||
package client
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/quic-go/quic-go"
|
|
||||||
)
|
|
||||||
|
|
||||||
// autoReconnectConn is a wrapper of quic.Connection that automatically reconnects
|
|
||||||
// when a non-temporary error (usually a timeout) occurs.
|
|
||||||
type autoReconnectConn struct {
|
|
||||||
// Connect is called whenever a new QUIC connection is needed.
|
|
||||||
// It should return a new QUIC connection, a function to close the connection
|
|
||||||
// (and potentially other underlying resources), and an error if one occurred.
|
|
||||||
Connect func() (quic.Connection, func(), error)
|
|
||||||
|
|
||||||
conn quic.Connection
|
|
||||||
closeFunc func()
|
|
||||||
connMutex sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *autoReconnectConn) OpenStream() (quic.Connection, quic.Stream, error) {
|
|
||||||
c.connMutex.Lock()
|
|
||||||
defer c.connMutex.Unlock()
|
|
||||||
// First time?
|
|
||||||
if c.conn == nil {
|
|
||||||
conn, closeFunc, err := c.Connect()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
c.closeFunc = closeFunc
|
|
||||||
}
|
|
||||||
stream, err := c.conn.OpenStream()
|
|
||||||
if err == nil {
|
|
||||||
// All is good
|
|
||||||
return c.conn, stream, nil
|
|
||||||
} else if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
|
||||||
// Temporary error, just pass the error to the caller
|
|
||||||
return nil, nil, err
|
|
||||||
} else {
|
|
||||||
// Permanent error
|
|
||||||
// Close the previous connection,
|
|
||||||
// reconnect and try again (only once)
|
|
||||||
c.closeFunc()
|
|
||||||
conn, closeFunc, err := c.Connect()
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
c.conn = conn
|
|
||||||
c.closeFunc = closeFunc
|
|
||||||
stream, err = c.conn.OpenStream()
|
|
||||||
return c.conn, stream, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *autoReconnectConn) Close() error {
|
|
||||||
c.connMutex.Lock()
|
|
||||||
defer c.connMutex.Unlock()
|
|
||||||
if c.conn == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
c.closeFunc()
|
|
||||||
c.conn = nil
|
|
||||||
c.closeFunc = nil
|
|
||||||
return nil
|
|
||||||
}
|
|
177
core/client/udp.go
Normal file
177
core/client/udp.go
Normal file
@@ -0,0 +1,177 @@
|
|||||||
|
package client
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"math/rand"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/apernet/hysteria/core/internal/frag"
|
||||||
|
"github.com/apernet/hysteria/core/internal/protocol"
|
||||||
|
"github.com/quic-go/quic-go"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpMessageChanSize = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpIO interface {
|
||||||
|
ReceiveMessage() (*protocol.UDPMessage, error)
|
||||||
|
SendMessage([]byte, *protocol.UDPMessage) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpConn struct {
|
||||||
|
ID uint32
|
||||||
|
D *frag.Defragger
|
||||||
|
ReceiveCh chan *protocol.UDPMessage
|
||||||
|
SendBuf []byte
|
||||||
|
SendFunc func([]byte, *protocol.UDPMessage) error
|
||||||
|
CloseFunc func()
|
||||||
|
Closed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Receive() ([]byte, string, error) {
|
||||||
|
for {
|
||||||
|
msg := <-u.ReceiveCh
|
||||||
|
if msg == nil {
|
||||||
|
// Closed
|
||||||
|
return nil, "", io.EOF
|
||||||
|
}
|
||||||
|
dfMsg := u.D.Feed(msg)
|
||||||
|
if dfMsg == nil {
|
||||||
|
// Incomplete message, wait for more
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return dfMsg.Data, dfMsg.Addr, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send is not thread-safe, as it uses a shared SendBuf.
|
||||||
|
func (u *udpConn) Send(data []byte, addr string) error {
|
||||||
|
// Try no frag first
|
||||||
|
msg := &protocol.UDPMessage{
|
||||||
|
SessionID: u.ID,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: addr,
|
||||||
|
Data: data,
|
||||||
|
}
|
||||||
|
err := u.SendFunc(u.SendBuf, msg)
|
||||||
|
var errTooLarge quic.ErrMessageTooLarge
|
||||||
|
if errors.As(err, &errTooLarge) {
|
||||||
|
// Message too large, try fragmentation
|
||||||
|
msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1
|
||||||
|
fMsgs := frag.FragUDPMessage(msg, int(errTooLarge))
|
||||||
|
for _, fMsg := range fMsgs {
|
||||||
|
err := u.SendFunc(u.SendBuf, &fMsg)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpConn) Close() error {
|
||||||
|
u.CloseFunc()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpSessionManager struct {
|
||||||
|
io udpIO
|
||||||
|
|
||||||
|
mutex sync.Mutex
|
||||||
|
m map[uint32]*udpConn
|
||||||
|
nextID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPSessionManager(io udpIO) *udpSessionManager {
|
||||||
|
return &udpSessionManager{
|
||||||
|
io: io,
|
||||||
|
m: make(map[uint32]*udpConn),
|
||||||
|
nextID: 1,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run runs the session manager main loop.
|
||||||
|
// Exit and returns error when the underlying io returns error (e.g. closed).
|
||||||
|
func (m *udpSessionManager) Run() error {
|
||||||
|
defer m.cleanup()
|
||||||
|
|
||||||
|
for {
|
||||||
|
msg, err := m.io.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
m.feed(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) cleanup() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
for _, conn := range m.m {
|
||||||
|
m.close(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
conn, ok := m.m[msg.SessionID]
|
||||||
|
if !ok {
|
||||||
|
// Ignore message from unknown session
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case conn.ReceiveCh <- msg:
|
||||||
|
// OK
|
||||||
|
default:
|
||||||
|
// Channel full, drop the message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDP creates a new UDP session.
|
||||||
|
func (m *udpSessionManager) NewUDP() (HyUDPConn, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
id := m.nextID
|
||||||
|
m.nextID++
|
||||||
|
|
||||||
|
conn := &udpConn{
|
||||||
|
ID: id,
|
||||||
|
D: &frag.Defragger{},
|
||||||
|
ReceiveCh: make(chan *protocol.UDPMessage, udpMessageChanSize),
|
||||||
|
SendBuf: make([]byte, protocol.MaxUDPSize),
|
||||||
|
SendFunc: m.io.SendMessage,
|
||||||
|
}
|
||||||
|
conn.CloseFunc = func() {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
if !conn.Closed {
|
||||||
|
m.close(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.m[id] = conn
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) close(conn *udpConn) {
|
||||||
|
conn.Closed = true
|
||||||
|
close(conn.ReceiveCh)
|
||||||
|
delete(m.m, conn.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) Count() int {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
return len(m.m)
|
||||||
|
}
|
@@ -106,25 +106,4 @@ func TestServerMasquerade(t *testing.T) {
|
|||||||
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
|
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
|
||||||
t.Fatal("expected timeout, got", err)
|
t.Fatal("expected timeout, got", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try UDP request
|
|
||||||
udpStream, err := conn.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("error opening stream:", err)
|
|
||||||
}
|
|
||||||
defer udpStream.Close()
|
|
||||||
err = protocol.WriteUDPRequest(udpStream)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("error sending request:", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// We should receive nothing
|
|
||||||
_ = udpStream.SetReadDeadline(time.Now().Add(2 * time.Second))
|
|
||||||
n, err = udpStream.Read(buf)
|
|
||||||
if n != 0 {
|
|
||||||
t.Fatal("expected no response, got", n)
|
|
||||||
}
|
|
||||||
if nErr, ok := err.(net.Error); !ok || !nErr.Timeout() {
|
|
||||||
t.Fatal("expected timeout, got", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
@@ -287,7 +287,7 @@ func (l *channelEventLogger) TCPError(addr net.Addr, id, reqAddr string, err err
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) {
|
func (l *channelEventLogger) UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) {
|
||||||
if l.UDPRequestEventCh != nil {
|
if l.UDPRequestEventCh != nil {
|
||||||
l.UDPRequestEventCh <- udpRequestEvent{
|
l.UDPRequestEventCh <- udpRequestEvent{
|
||||||
Addr: addr,
|
Addr: addr,
|
||||||
|
@@ -111,9 +111,8 @@ type udpSessionManager struct {
|
|||||||
eventLogger udpEventLogger
|
eventLogger udpEventLogger
|
||||||
idleTimeout time.Duration
|
idleTimeout time.Duration
|
||||||
|
|
||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
m map[uint32]*udpSessionEntry
|
m map[uint32]*udpSessionEntry
|
||||||
nextID uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager {
|
func newUDPSessionManager(io udpIO, eventLogger udpEventLogger, idleTimeout time.Duration) *udpSessionManager {
|
||||||
@@ -212,3 +211,9 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) {
|
|||||||
// as some are temporary (e.g. invalid address)
|
// as some are temporary (e.g. invalid address)
|
||||||
_, _ = entry.Feed(msg)
|
_, _ = entry.Feed(msg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *udpSessionManager) Count() int {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
return len(m.m)
|
||||||
|
}
|
||||||
|
@@ -10,6 +10,11 @@ import (
|
|||||||
"go.uber.org/goleak"
|
"go.uber.org/goleak"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errUDPBlocked = errors.New("blocked")
|
||||||
|
errUDPClosed = errors.New("closed")
|
||||||
|
)
|
||||||
|
|
||||||
type echoUDPConnPkt struct {
|
type echoUDPConnPkt struct {
|
||||||
Data []byte
|
Data []byte
|
||||||
Addr string
|
Addr string
|
||||||
@@ -23,7 +28,7 @@ type echoUDPConn struct {
|
|||||||
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
|
func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) {
|
||||||
pkt := <-c.PktCh
|
pkt := <-c.PktCh
|
||||||
if pkt.Close {
|
if pkt.Close {
|
||||||
return 0, "", errors.New("closed")
|
return 0, "", errUDPClosed
|
||||||
}
|
}
|
||||||
n := copy(b, pkt.Data)
|
n := copy(b, pkt.Data)
|
||||||
return n, pkt.Addr, nil
|
return n, pkt.Addr, nil
|
||||||
@@ -49,12 +54,14 @@ func (c *echoUDPConn) Close() error {
|
|||||||
type udpMockIO struct {
|
type udpMockIO struct {
|
||||||
ReceiveCh <-chan *protocol.UDPMessage
|
ReceiveCh <-chan *protocol.UDPMessage
|
||||||
SendCh chan<- *protocol.UDPMessage
|
SendCh chan<- *protocol.UDPMessage
|
||||||
|
UDPClose bool // ReadFrom() returns error immediately
|
||||||
|
BlockUDP bool // Block UDP connection creation
|
||||||
}
|
}
|
||||||
|
|
||||||
func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
func (io *udpMockIO) ReceiveMessage() (*protocol.UDPMessage, error) {
|
||||||
m := <-io.ReceiveCh
|
m := <-io.ReceiveCh
|
||||||
if m == nil {
|
if m == nil {
|
||||||
return nil, errors.New("closed")
|
return nil, errUDPClosed
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
@@ -68,9 +75,18 @@ func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) {
|
func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) {
|
||||||
return &echoUDPConn{
|
if io.BlockUDP {
|
||||||
|
return nil, errUDPBlocked
|
||||||
|
}
|
||||||
|
conn := &echoUDPConn{
|
||||||
PktCh: make(chan echoUDPConnPkt, 10),
|
PktCh: make(chan echoUDPConnPkt, 10),
|
||||||
}, nil
|
}
|
||||||
|
if io.UDPClose {
|
||||||
|
conn.PktCh <- echoUDPConnPkt{
|
||||||
|
Close: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpMockEventNew struct {
|
type udpMockEventNew struct {
|
||||||
@@ -112,80 +128,131 @@ func TestUDPSessionManager(t *testing.T) {
|
|||||||
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
|
sm := newUDPSessionManager(io, eventLogger, 2*time.Second)
|
||||||
go sm.Run()
|
go sm.Run()
|
||||||
|
|
||||||
ms := []*protocol.UDPMessage{
|
t.Run("session creation & timeout", func(t *testing.T) {
|
||||||
{
|
ms := []*protocol.UDPMessage{
|
||||||
SessionID: 1234,
|
{
|
||||||
PacketID: 0,
|
SessionID: 1234,
|
||||||
FragID: 0,
|
PacketID: 0,
|
||||||
FragCount: 1,
|
FragID: 0,
|
||||||
Addr: "example.com:5353",
|
FragCount: 1,
|
||||||
Data: []byte("hello"),
|
Addr: "example.com:5353",
|
||||||
},
|
Data: []byte("hello"),
|
||||||
{
|
},
|
||||||
SessionID: 5678,
|
{
|
||||||
PacketID: 0,
|
SessionID: 5678,
|
||||||
FragID: 0,
|
PacketID: 0,
|
||||||
FragCount: 1,
|
FragID: 0,
|
||||||
Addr: "example.com:9999",
|
FragCount: 1,
|
||||||
Data: []byte("goodbye"),
|
Addr: "example.com:9999",
|
||||||
},
|
Data: []byte("goodbye"),
|
||||||
{
|
},
|
||||||
SessionID: 1234,
|
{
|
||||||
PacketID: 0,
|
SessionID: 1234,
|
||||||
FragID: 0,
|
PacketID: 0,
|
||||||
FragCount: 1,
|
FragID: 0,
|
||||||
Addr: "example.com:5353",
|
FragCount: 1,
|
||||||
Data: []byte(" world"),
|
Addr: "example.com:5353",
|
||||||
},
|
Data: []byte(" world"),
|
||||||
{
|
},
|
||||||
SessionID: 5678,
|
{
|
||||||
PacketID: 0,
|
SessionID: 5678,
|
||||||
FragID: 0,
|
PacketID: 0,
|
||||||
FragCount: 1,
|
FragID: 0,
|
||||||
Addr: "example.com:9999",
|
FragCount: 1,
|
||||||
Data: []byte(" girl"),
|
Addr: "example.com:9999",
|
||||||
},
|
Data: []byte(" girl"),
|
||||||
}
|
},
|
||||||
for _, m := range ms {
|
}
|
||||||
msgReceiveCh <- m
|
for _, m := range ms {
|
||||||
}
|
msgReceiveCh <- m
|
||||||
// New event order should be consistent
|
}
|
||||||
newEvent := <-eventNewCh
|
// New event order should be consistent
|
||||||
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
|
newEvent := <-eventNewCh
|
||||||
t.Error("unexpected new event value")
|
if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" {
|
||||||
}
|
t.Error("unexpected new event value")
|
||||||
newEvent = <-eventNewCh
|
}
|
||||||
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
|
newEvent = <-eventNewCh
|
||||||
t.Error("unexpected new event value")
|
if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" {
|
||||||
}
|
t.Error("unexpected new event value")
|
||||||
// Message order is not guaranteed
|
}
|
||||||
msgMap := make(map[string]bool)
|
// Message order is not guaranteed
|
||||||
for i := 0; i < 4; i++ {
|
msgMap := make(map[string]bool)
|
||||||
msg := <-msgSendCh
|
for i := 0; i < 4; i++ {
|
||||||
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
|
msg := <-msgSendCh
|
||||||
}
|
msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true
|
||||||
if !(msgMap["1234:example.com:5353:hello"] &&
|
}
|
||||||
msgMap["5678:example.com:9999:goodbye"] &&
|
if !(msgMap["1234:example.com:5353:hello"] &&
|
||||||
msgMap["1234:example.com:5353: world"] &&
|
msgMap["5678:example.com:9999:goodbye"] &&
|
||||||
msgMap["5678:example.com:9999: girl"]) {
|
msgMap["1234:example.com:5353: world"] &&
|
||||||
t.Error("unexpected message value")
|
msgMap["5678:example.com:9999: girl"]) {
|
||||||
}
|
t.Error("unexpected message value")
|
||||||
// Timeout check
|
}
|
||||||
startTime := time.Now()
|
// Timeout check
|
||||||
closeMap := make(map[uint32]bool)
|
startTime := time.Now()
|
||||||
for i := 0; i < 2; i++ {
|
closeMap := make(map[uint32]bool)
|
||||||
closeEvent := <-eventCloseCh
|
for i := 0; i < 2; i++ {
|
||||||
closeMap[closeEvent.SessionID] = true
|
closeEvent := <-eventCloseCh
|
||||||
}
|
closeMap[closeEvent.SessionID] = true
|
||||||
if !(closeMap[1234] && closeMap[5678]) {
|
}
|
||||||
t.Error("unexpected close event value")
|
if !(closeMap[1234] && closeMap[5678]) {
|
||||||
}
|
t.Error("unexpected close event value")
|
||||||
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
|
}
|
||||||
t.Error("unexpected timeout duration")
|
if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second {
|
||||||
}
|
t.Error("unexpected timeout duration")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
// Goroutine leak check
|
t.Run("UDP connection close", func(t *testing.T) {
|
||||||
|
// Close UDP connection immediately after creation
|
||||||
|
io.UDPClose = true
|
||||||
|
|
||||||
|
msgReceiveCh <- &protocol.UDPMessage{
|
||||||
|
SessionID: 8888,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "mygod.org:1514",
|
||||||
|
Data: []byte("goodnight"),
|
||||||
|
}
|
||||||
|
// Should have both new and close events immediately
|
||||||
|
newEvent := <-eventNewCh
|
||||||
|
if newEvent.SessionID != 8888 || newEvent.ReqAddr != "mygod.org:1514" {
|
||||||
|
t.Error("unexpected new event value")
|
||||||
|
}
|
||||||
|
closeEvent := <-eventCloseCh
|
||||||
|
if closeEvent.SessionID != 8888 || closeEvent.Err != errUDPClosed {
|
||||||
|
t.Error("unexpected close event value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("UDP IO failure", func(t *testing.T) {
|
||||||
|
// Block UDP connection creation
|
||||||
|
io.BlockUDP = true
|
||||||
|
|
||||||
|
msgReceiveCh <- &protocol.UDPMessage{
|
||||||
|
SessionID: 9999,
|
||||||
|
PacketID: 0,
|
||||||
|
FragID: 0,
|
||||||
|
FragCount: 1,
|
||||||
|
Addr: "xxx.net:12450",
|
||||||
|
Data: []byte("nope"),
|
||||||
|
}
|
||||||
|
// Should have both new and close events immediately
|
||||||
|
newEvent := <-eventNewCh
|
||||||
|
if newEvent.SessionID != 9999 || newEvent.ReqAddr != "xxx.net:12450" {
|
||||||
|
t.Error("unexpected new event value")
|
||||||
|
}
|
||||||
|
closeEvent := <-eventCloseCh
|
||||||
|
if closeEvent.SessionID != 9999 || closeEvent.Err != errUDPBlocked {
|
||||||
|
t.Error("unexpected close event value")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Leak checks
|
||||||
msgReceiveCh <- nil
|
msgReceiveCh <- nil
|
||||||
time.Sleep(1 * time.Second) // Wait for internal routines to exit
|
time.Sleep(1 * time.Second) // Give some time for the goroutines to exit
|
||||||
|
if sm.Count() != 0 {
|
||||||
|
t.Error("session count should be 0")
|
||||||
|
}
|
||||||
goleak.VerifyNone(t)
|
goleak.VerifyNone(t)
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user