diff --git a/core/client/client.go b/core/client/client.go index 161265f..3fe154d 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -400,7 +400,7 @@ func (c *udpConn) Receive() ([]byte, string, error) { // Send is not thread-safe as it uses a shared send buffer for now. func (c *udpConn) Send(data []byte, addr string) error { // Try no frag first - msg := protocol.UDPMessage{ + msg := &protocol.UDPMessage{ SessionID: c.SessionID, PacketID: 0, FragID: 0, diff --git a/core/go.mod b/core/go.mod index edb0ce0..e8c6c3b 100644 --- a/core/go.mod +++ b/core/go.mod @@ -4,6 +4,7 @@ go 1.20 require ( github.com/quic-go/quic-go v0.0.0-00010101000000-000000000000 + go.uber.org/goleak v1.2.1 golang.org/x/time v0.3.0 ) diff --git a/core/go.sum b/core/go.sum index b1aa0f1..14e5737 100644 --- a/core/go.sum +++ b/core/go.sum @@ -38,6 +38,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A= +go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.4.0 h1:UVQgzMY87xqpKNgb+kDsll2Igd33HszWHFLmpaRMq/8= diff --git a/core/internal/frag/frag.go b/core/internal/frag/frag.go index 3493519..75c8dac 100644 --- a/core/internal/frag/frag.go +++ b/core/internal/frag/frag.go @@ -4,9 +4,9 @@ import ( "github.com/apernet/hysteria/core/internal/protocol" ) -func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage { +func FragUDPMessage(m *protocol.UDPMessage, maxSize int) []protocol.UDPMessage { if m.Size() <= maxSize { - return []protocol.UDPMessage{m} + return []protocol.UDPMessage{*m} } fullPayload := m.Data maxPayloadSize := maxSize - m.HeaderSize() @@ -19,7 +19,7 @@ func FragUDPMessage(m protocol.UDPMessage, maxSize int) []protocol.UDPMessage { if payloadSize > maxPayloadSize { payloadSize = maxPayloadSize } - frag := m + frag := *m frag.FragID = fragID frag.FragCount = fragCount frag.Data = fullPayload[off : off+payloadSize] diff --git a/core/internal/frag/frag_test.go b/core/internal/frag/frag_test.go index 48eb004..aeb566d 100644 --- a/core/internal/frag/frag_test.go +++ b/core/internal/frag/frag_test.go @@ -124,7 +124,7 @@ func TestFragUDPMessage(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := FragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) { + if got := FragUDPMessage(&tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) { t.Errorf("FragUDPMessage() = %v, want %v", got, tt.want) } }) diff --git a/core/internal/integration_tests/smoke_test.go b/core/internal/integration_tests/smoke_test.go index b5cefb1..eac2b53 100644 --- a/core/internal/integration_tests/smoke_test.go +++ b/core/internal/integration_tests/smoke_test.go @@ -36,7 +36,7 @@ func TestClientNoServer(t *testing.T) { // Try UDP _, err = c.ListenUDP() if !errors.As(err, &cErr) { - t.Fatal("expected connect error from ListenUDP") + t.Fatal("expected connect error from DialUDP") } } @@ -86,7 +86,7 @@ func TestClientServerBadAuth(t *testing.T) { // Try UDP _, err = c.ListenUDP() if !errors.As(err, &aErr) { - t.Fatal("expected auth error from ListenUDP") + t.Fatal("expected auth error from DialUDP") } } diff --git a/core/internal/protocol/http.go b/core/internal/protocol/http.go index e327258..f4ccf82 100644 --- a/core/internal/protocol/http.go +++ b/core/internal/protocol/http.go @@ -9,31 +9,34 @@ const ( URLHost = "hysteria" URLPath = "/auth" - HeaderAuth = "Hysteria-Auth" - HeaderCCRX = "Hysteria-CC-RX" - HeaderPadding = "Hysteria-Padding" + RequestHeaderAuth = "Hysteria-Auth" + ResponseHeaderUDPEnabled = "Hysteria-UDP" + CommonHeaderCCRX = "Hysteria-CC-RX" + CommonHeaderPadding = "Hysteria-Padding" StatusAuthOK = 233 ) func AuthRequestDataFromHeader(h http.Header) (auth string, rx uint64) { - auth = h.Get(HeaderAuth) - rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64) + auth = h.Get(RequestHeaderAuth) + rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) return } func AuthRequestDataToHeader(h http.Header, auth string, rx uint64) { - h.Set(HeaderAuth, auth) - h.Set(HeaderCCRX, strconv.FormatUint(rx, 10)) - h.Set(HeaderPadding, authRequestPadding.String()) + h.Set(RequestHeaderAuth, auth) + h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) + h.Set(CommonHeaderPadding, authRequestPadding.String()) } -func AuthResponseDataFromHeader(h http.Header) (rx uint64) { - rx, _ = strconv.ParseUint(h.Get(HeaderCCRX), 10, 64) +func AuthResponseDataFromHeader(h http.Header) (udp bool, rx uint64) { + udp, _ = strconv.ParseBool(h.Get(ResponseHeaderUDPEnabled)) + rx, _ = strconv.ParseUint(h.Get(CommonHeaderCCRX), 10, 64) return } -func AuthResponseDataToHeader(h http.Header, rx uint64) { - h.Set(HeaderCCRX, strconv.FormatUint(rx, 10)) - h.Set(HeaderPadding, authResponsePadding.String()) +func AuthResponseDataToHeader(h http.Header, udp bool, rx uint64) { + h.Set(ResponseHeaderUDPEnabled, strconv.FormatBool(udp)) + h.Set(CommonHeaderCCRX, strconv.FormatUint(rx, 10)) + h.Set(CommonHeaderPadding, authResponsePadding.String()) } diff --git a/core/internal/protocol/padding.go b/core/internal/protocol/padding.go index d8d248f..9895cdc 100644 --- a/core/internal/protocol/padding.go +++ b/core/internal/protocol/padding.go @@ -28,6 +28,4 @@ var ( authResponsePadding = padding{Min: 256, Max: 2048} tcpRequestPadding = padding{Min: 64, Max: 512} tcpResponsePadding = padding{Min: 128, Max: 1024} - udpRequestPadding = padding{Min: 64, Max: 512} - udpResponsePadding = padding{Min: 128, Max: 1024} ) diff --git a/core/internal/protocol/proxy.go b/core/internal/protocol/proxy.go index 8e2ce7c..dc36229 100644 --- a/core/internal/protocol/proxy.go +++ b/core/internal/protocol/proxy.go @@ -13,7 +13,6 @@ import ( const ( FrameTypeTCPRequest = 0x401 - FrameTypeUDPRequest = 0x402 // Max length values are for preventing DoS attacks @@ -148,113 +147,6 @@ func WriteTCPResponse(w io.Writer, ok bool, msg string) error { return err } -// UDPRequest format: -// 0x402 (QUIC varint) -// Padding length (QUIC varint) -// Padding (bytes) - -func ReadUDPRequest(r io.Reader) error { - bReader := quicvarint.NewReader(r) - paddingLen, err := quicvarint.Read(bReader) - if err != nil { - return err - } - if paddingLen > MaxPaddingLength { - return errors.ProtocolError{Message: "invalid padding length"} - } - if paddingLen > 0 { - _, err = io.CopyN(io.Discard, r, int64(paddingLen)) - if err != nil { - return err - } - } - return nil -} - -func WriteUDPRequest(w io.Writer) error { - padding := udpRequestPadding.String() - paddingLen := len(padding) - sz := int(quicvarint.Len(FrameTypeUDPRequest)) + - int(quicvarint.Len(uint64(paddingLen))) + paddingLen - buf := make([]byte, sz) - i := varintPut(buf, FrameTypeUDPRequest) - i += varintPut(buf[i:], uint64(paddingLen)) - copy(buf[i:], padding) - _, err := w.Write(buf) - return err -} - -// UDPResponse format: -// Status (byte, 0=ok, 1=error) -// Session ID (uint32 BE) -// Message length (QUIC varint) -// Message (bytes) -// Padding length (QUIC varint) -// Padding (bytes) - -func ReadUDPResponse(r io.Reader) (bool, uint32, string, error) { - var status [1]byte - if _, err := io.ReadFull(r, status[:]); err != nil { - return false, 0, "", err - } - var sessionID uint32 - if err := binary.Read(r, binary.BigEndian, &sessionID); err != nil { - return false, 0, "", err - } - bReader := quicvarint.NewReader(r) - msgLen, err := quicvarint.Read(bReader) - if err != nil { - return false, 0, "", err - } - if msgLen > MaxMessageLength { - return false, 0, "", errors.ProtocolError{Message: "invalid message length"} - } - var msgBuf []byte - // No message is fine - if msgLen > 0 { - msgBuf = make([]byte, msgLen) - _, err = io.ReadFull(r, msgBuf) - if err != nil { - return false, 0, "", err - } - } - paddingLen, err := quicvarint.Read(bReader) - if err != nil { - return false, 0, "", err - } - if paddingLen > MaxPaddingLength { - return false, 0, "", errors.ProtocolError{Message: "invalid padding length"} - } - if paddingLen > 0 { - _, err = io.CopyN(io.Discard, r, int64(paddingLen)) - if err != nil { - return false, 0, "", err - } - } - return status[0] == 0, sessionID, string(msgBuf), nil -} - -func WriteUDPResponse(w io.Writer, ok bool, sessionID uint32, msg string) error { - padding := udpResponsePadding.String() - paddingLen := len(padding) - msgLen := len(msg) - sz := 1 + 4 + int(quicvarint.Len(uint64(msgLen))) + msgLen + - int(quicvarint.Len(uint64(paddingLen))) + paddingLen - buf := make([]byte, sz) - if ok { - buf[0] = 0 - } else { - buf[0] = 1 - } - binary.BigEndian.PutUint32(buf[1:], sessionID) - i := varintPut(buf[5:], uint64(msgLen)) - i += copy(buf[5+i:], msg) - i += varintPut(buf[5+i:], uint64(paddingLen)) - copy(buf[5+i:], padding) - _, err := w.Write(buf) - return err -} - // UDPMessage format: // Session ID (uint32 BE) // Packet ID (uint16 BE) diff --git a/core/internal/protocol/proxy_test.go b/core/internal/protocol/proxy_test.go index 5c5d674..111c615 100644 --- a/core/internal/protocol/proxy_test.go +++ b/core/internal/protocol/proxy_test.go @@ -315,179 +315,3 @@ func TestWriteTCPResponse(t *testing.T) { }) } } - -func TestReadUDPRequest(t *testing.T) { - tests := []struct { - name string - data []byte - wantErr bool - }{ - { - name: "normal no padding", - data: []byte("\x00\x00"), - wantErr: false, - }, - { - name: "normal with padding", - data: []byte("\x02gg"), - wantErr: false, - }, - { - name: "incomplete 1", - data: []byte("\x0bhoho"), - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := bytes.NewReader(tt.data) - if err := ReadUDPRequest(r); (err != nil) != tt.wantErr { - t.Errorf("ReadUDPRequest() error = %v, wantErr %v", err, tt.wantErr) - } - }) - } -} - -func TestWriteUDPRequest(t *testing.T) { - tests := []struct { - name string - wantW string // Just a prefix, we don't care about the padding - wantErr bool - }{ - { - name: "normal", - wantW: "\x44\x02", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := &bytes.Buffer{} - err := WriteUDPRequest(w) - if (err != nil) != tt.wantErr { - t.Errorf("WriteUDPRequest() error = %v, wantErr %v", err, tt.wantErr) - return - } - if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) { - t.Errorf("WriteUDPRequest() gotW = %v, want %v", gotW, tt.wantW) - } - }) - } -} - -func TestReadUDPResponse(t *testing.T) { - tests := []struct { - name string - data []byte - want bool - want1 uint32 - want2 string - wantErr bool - }{ - { - name: "normal ok no padding", - data: []byte("\x00\x00\x00\x00\x33\x0bhello world\x00"), - want: true, - want1: 51, - want2: "hello world", - wantErr: false, - }, - { - name: "normal error with padding", - data: []byte("\x01\x00\x00\x33\x33\x06stop!!\x05xxxxx"), - want: false, - want1: 13107, - want2: "stop!!", - wantErr: false, - }, - { - name: "normal ok no message with padding", - data: []byte("\x00\x00\x00\x00\x33\x00\x05xxxxx"), - want: true, - want1: 51, - want2: "", - wantErr: false, - }, - { - name: "incomplete 1", - data: []byte("\x00\x00\x06"), - want: false, - want1: 0, - want2: "", - wantErr: true, - }, - { - name: "incomplete 2", - data: []byte("\x01\x00\x01\x02\x03\x05jesus\x05x"), - want: false, - want1: 0, - want2: "", - wantErr: true, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - r := bytes.NewReader(tt.data) - got, got1, got2, err := ReadUDPResponse(r) - if (err != nil) != tt.wantErr { - t.Errorf("ReadUDPResponse() error = %v, wantErr %v", err, tt.wantErr) - return - } - if got != tt.want { - t.Errorf("ReadUDPResponse() got = %v, want %v", got, tt.want) - } - if got1 != tt.want1 { - t.Errorf("ReadUDPResponse() got1 = %v, want %v", got1, tt.want1) - } - if got2 != tt.want2 { - t.Errorf("ReadUDPResponse() got2 = %v, want %v", got2, tt.want2) - } - }) - } -} - -func TestWriteUDPResponse(t *testing.T) { - type args struct { - ok bool - sessionID uint32 - msg string - } - tests := []struct { - name string - args args - wantW string // Just a prefix, we don't care about the padding - wantErr bool - }{ - { - name: "normal ok", - args: args{ok: true, sessionID: 6, msg: "hello world"}, - wantW: "\x00\x00\x00\x00\x06\x0bhello world", - wantErr: false, - }, - { - name: "normal error", - args: args{ok: false, sessionID: 7, msg: "stop!!"}, - wantW: "\x01\x00\x00\x00\x07\x06stop!!", - wantErr: false, - }, - { - name: "empty", - args: args{ok: true, sessionID: 0, msg: ""}, - wantW: "\x00\x00\x00\x00\x00\x00", - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - w := &bytes.Buffer{} - err := WriteUDPResponse(w, tt.args.ok, tt.args.sessionID, tt.args.msg) - if (err != nil) != tt.wantErr { - t.Errorf("WriteUDPResponse() error = %v, wantErr %v", err, tt.wantErr) - return - } - if gotW := w.String(); !(strings.HasPrefix(gotW, tt.wantW) && len(gotW) > len(tt.wantW)) { - t.Errorf("WriteUDPResponse() gotW = %v, want %v", gotW, tt.wantW) - } - }) - } -} diff --git a/core/internal/utils/atomic.go b/core/internal/utils/atomic.go new file mode 100644 index 0000000..e3c3d97 --- /dev/null +++ b/core/internal/utils/atomic.go @@ -0,0 +1,24 @@ +package utils + +import ( + "sync/atomic" + "time" +) + +type AtomicTime struct { + v atomic.Value +} + +func NewAtomicTime(t time.Time) *AtomicTime { + a := &AtomicTime{} + a.Set(t) + return a +} + +func (t *AtomicTime) Set(new time.Time) { + t.v.Store(new) +} + +func (t *AtomicTime) Get() time.Time { + return t.v.Load().(time.Time) +} diff --git a/core/server/config.go b/core/server/config.go index 5152ef7..d7fb4bd 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -103,9 +103,13 @@ type QUICConfig struct { } // Outbound provides the implementation of how the server should connect to remote servers. +// Even though it's called DialUDP, outbound implementations do not necessarily have to +// return a "connected" UDP socket that can only send and receive from reqAddr. It's the +// address of the first packet to be sent. +// It's perfectly fine to have a "full-cone" implementation for UDP. type Outbound interface { DialTCP(reqAddr string) (net.Conn, error) - ListenUDP() (UDPConn, error) + DialUDP(reqAddr string) (UDPConn, error) } // UDPConn is like net.PacketConn, but uses string for addresses. @@ -125,7 +129,7 @@ func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) { return defaultOutboundDialer.Dial("tcp", reqAddr) } -func (o *defaultOutbound) ListenUDP() (UDPConn, error) { +func (o *defaultOutbound) DialUDP(reqAddr string) (UDPConn, error) { conn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err @@ -171,7 +175,7 @@ type EventLogger interface { Disconnect(addr net.Addr, id string, err error) TCPRequest(addr net.Addr, id, reqAddr string) TCPError(addr net.Addr, id, reqAddr string, err error) - UDPRequest(addr net.Addr, id string, sessionID uint32) + UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) UDPError(addr net.Addr, id string, sessionID uint32, err error) } diff --git a/core/server/server.go b/core/server/server.go index 2b0b1ca..b886e31 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -3,14 +3,11 @@ package server import ( "context" "crypto/tls" - "errors" - "io" - "math/rand" "net/http" "sync" + "time" "github.com/apernet/hysteria/core/internal/congestion" - "github.com/apernet/hysteria/core/internal/frag" "github.com/apernet/hysteria/core/internal/protocol" "github.com/apernet/hysteria/core/internal/utils" @@ -21,6 +18,8 @@ import ( const ( closeErrCodeOK = 0x100 // HTTP3 ErrCodeNoError closeErrCodeTrafficLimitReached = 0x107 // HTTP3 ErrCodeExcessiveLoad + + udpSessionIdleTimeout = 60 * time.Second ) type Server interface { @@ -101,90 +100,21 @@ type h3sHandler struct { authID string udpOnce sync.Once - udpSM udpSessionManager + udpSM *udpSessionManager // Only set after authentication } func newH3sHandler(config *Config, conn quic.Connection) *h3sHandler { return &h3sHandler{ config: config, conn: conn, - udpSM: udpSessionManager{ - listenFunc: config.Outbound.ListenUDP, - m: make(map[uint32]*udpSessionEntry), - }, } } -type udpSessionEntry struct { - Conn UDPConn - D *frag.Defragger - Closed bool -} - -type udpSessionManager struct { - listenFunc func() (UDPConn, error) - mutex sync.RWMutex - m map[uint32]*udpSessionEntry - nextID uint32 -} - -// Add returns the session ID, the UDP connection and a function to close the UDP connection & delete the session. -func (m *udpSessionManager) Add() (uint32, UDPConn, func(), error) { - conn, err := m.listenFunc() - if err != nil { - return 0, nil, nil, err - } - - m.mutex.Lock() - defer m.mutex.Unlock() - id := m.nextID - m.nextID++ - entry := &udpSessionEntry{ - Conn: conn, - D: &frag.Defragger{}, - Closed: false, - } - m.m[id] = entry - - return id, conn, func() { - m.mutex.Lock() - defer m.mutex.Unlock() - if entry.Closed { - // Already closed - return - } - entry.Closed = true - _ = conn.Close() - delete(m.m, id) - }, nil -} - -// Feed feeds a UDP message to the session manager. -// If the message itself is a complete message, or it's the last fragment of a message, -// it will be sent to the UDP connection. -// The function will then return the number of bytes sent and any error occurred. -func (m *udpSessionManager) Feed(msg *protocol.UDPMessage) (int, error) { - m.mutex.RLock() - defer m.mutex.RUnlock() - - entry, ok := m.m[msg.SessionID] - if !ok { - // No such session, drop the message - return 0, nil - } - dfMsg := entry.D.Feed(msg) - if dfMsg == nil { - // Not a complete message yet - return 0, nil - } - return entry.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) -} - func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method == http.MethodPost && r.Host == protocol.URLHost && r.URL.Path == protocol.URLPath { if h.authenticated { // Already authenticated - protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx) + protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) w.WriteHeader(protocol.StatusAuthOK) return } @@ -204,18 +134,23 @@ func (h *h3sHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { h.conn.SetCongestionControl(congestion.NewBrutalSender(actualTx)) } // Auth OK, send response - protocol.AuthResponseDataToHeader(w.Header(), h.config.BandwidthConfig.MaxRx) + protocol.AuthResponseDataToHeader(w.Header(), !h.config.DisableUDP, h.config.BandwidthConfig.MaxRx) w.WriteHeader(protocol.StatusAuthOK) // Call event logger if h.config.EventLogger != nil { h.config.EventLogger.Connect(h.conn.RemoteAddr(), id, actualTx) } - // Start UDP loop if UDP is not disabled + // Initialize UDP session manager (if UDP is enabled) // We use sync.Once to make sure that only one goroutine is started, // as ServeHTTP may be called by multiple goroutines simultaneously if !h.config.DisableUDP { h.udpOnce.Do(func() { - go h.udpLoop() + sm := newUDPSessionManager( + &udpsmIO{h.conn, id, h.config.TrafficLogger, h.config.Outbound}, + &udpsmEventLogger{h.conn, id, h.config.EventLogger}, + udpSessionIdleTimeout) + h.udpSM = sm + go sm.Run() }) } } else { @@ -240,9 +175,6 @@ func (h *h3sHandler) ProxyStreamHijacker(ft http3.FrameType, conn quic.Connectio case protocol.FrameTypeTCPRequest: go h.handleTCPRequest(stream) return true, nil - case protocol.FrameTypeUDPRequest: - go h.handleUDPRequest(stream) - return true, nil default: return false, nil } @@ -290,125 +222,6 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { } } -func (h *h3sHandler) handleUDPRequest(stream quic.Stream) { - if h.config.DisableUDP { - // UDP is disabled, send error message and close the stream - _ = protocol.WriteUDPResponse(stream, false, 0, "UDP is disabled on this server") - _ = stream.Close() - return - } - // Read request - err := protocol.ReadUDPRequest(stream) - if err != nil { - _ = stream.Close() - return - } - // Add to session manager - sessionID, conn, connCloseFunc, err := h.udpSM.Add() - if err != nil { - _ = protocol.WriteUDPResponse(stream, false, 0, err.Error()) - _ = stream.Close() - return - } - // Send response - _ = protocol.WriteUDPResponse(stream, true, sessionID, "") - // Call event logger - if h.config.EventLogger != nil { - h.config.EventLogger.UDPRequest(h.conn.RemoteAddr(), h.authID, sessionID) - } - - // client <- remote direction - go func() { - udpBuf := make([]byte, protocol.MaxUDPSize) - msgBuf := make([]byte, protocol.MaxUDPSize) - for { - udpN, rAddr, err := conn.ReadFrom(udpBuf) - if err != nil { - connCloseFunc() - _ = stream.Close() - return - } - if h.config.TrafficLogger != nil { - ok := h.config.TrafficLogger.Log(h.authID, 0, uint64(udpN)) - if !ok { - // TrafficLogger requested to disconnect the client - _ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "") - return - } - } - // Try no frag first - msg := protocol.UDPMessage{ - SessionID: sessionID, - PacketID: 0, - FragID: 0, - FragCount: 1, - Addr: rAddr, - Data: udpBuf[:udpN], - } - msgN := msg.Serialize(msgBuf) - if msgN < 0 { - // Message even larger than MaxUDPSize, drop it - continue - } - sendErr := h.conn.SendMessage(msgBuf[:msgN]) - var errTooLarge quic.ErrMessageTooLarge - if errors.As(sendErr, &errTooLarge) { - // Message too large, try fragmentation - msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 - fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) - for _, fMsg := range fMsgs { - msgN = fMsg.Serialize(msgBuf) - _ = h.conn.SendMessage(msgBuf[:msgN]) - } - } - } - }() - - // Hold (drain) the stream until the client closes it. - // Closing the stream is the signal to stop the UDP session. - _, err = io.Copy(io.Discard, stream) - // Call event logger - if h.config.EventLogger != nil { - h.config.EventLogger.UDPError(h.conn.RemoteAddr(), h.authID, sessionID, err) - } - - // Cleanup - connCloseFunc() - _ = stream.Close() -} - -func (h *h3sHandler) udpLoop() { - for { - msg, err := h.conn.ReceiveMessage() - if err != nil { - return - } - ok := h.handleUDPMessage(msg) - if !ok { - // TrafficLogger requested to disconnect the client - _ = h.conn.CloseWithError(closeErrCodeTrafficLimitReached, "") - return - } - } -} - -// client -> remote direction -// Returns a bool indicating whether the receiving loop should continue -func (h *h3sHandler) handleUDPMessage(msg []byte) (ok bool) { - udpMsg, err := protocol.ParseUDPMessage(msg) - if err != nil { - return true - } - if h.config.TrafficLogger != nil { - ok := h.config.TrafficLogger.Log(h.authID, uint64(len(udpMsg.Data)), 0) - if !ok { - return false - } - } - _, _ = h.udpSM.Feed(udpMsg) - return true -} - func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) { if h.config.MasqHandler != nil { h.config.MasqHandler.ServeHTTP(w, r) @@ -417,3 +230,74 @@ func (h *h3sHandler) masqHandler(w http.ResponseWriter, r *http.Request) { http.NotFound(w, r) } } + +// udpsmIO is the IO implementation for udpSessionManager with TrafficLogger support +type udpsmIO struct { + Conn quic.Connection + AuthID string + TrafficLogger TrafficLogger + Outbound Outbound +} + +func (io *udpsmIO) ReceiveMessage() (*protocol.UDPMessage, error) { + for { + msg, err := io.Conn.ReceiveMessage() + if err != nil { + // Connection error, this will stop the session manager + return nil, err + } + udpMsg, err := protocol.ParseUDPMessage(msg) + if err != nil { + // Invalid message, this is fine - just wait for the next + continue + } + if io.TrafficLogger != nil { + ok := io.TrafficLogger.Log(io.AuthID, uint64(len(udpMsg.Data)), 0) + if !ok { + // TrafficLogger requested to disconnect the client + _ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "") + return nil, errDisconnect + } + } + return udpMsg, nil + } +} + +func (io *udpsmIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error { + if io.TrafficLogger != nil { + ok := io.TrafficLogger.Log(io.AuthID, 0, uint64(len(msg.Data))) + if !ok { + // TrafficLogger requested to disconnect the client + _ = io.Conn.CloseWithError(closeErrCodeTrafficLimitReached, "") + return errDisconnect + } + } + msgN := msg.Serialize(buf) + if msgN < 0 { + // Message larger than buffer, silent drop + return nil + } + return io.Conn.SendMessage(buf[:msgN]) +} + +func (io *udpsmIO) DialUDP(reqAddr string) (UDPConn, error) { + return io.Outbound.DialUDP(reqAddr) +} + +type udpsmEventLogger struct { + Conn quic.Connection + AuthID string + EventLogger EventLogger +} + +func (l *udpsmEventLogger) New(sessionID uint32, reqAddr string) { + if l.EventLogger != nil { + l.EventLogger.UDPRequest(l.Conn.RemoteAddr(), l.AuthID, sessionID, reqAddr) + } +} + +func (l *udpsmEventLogger) Closed(sessionID uint32, err error) { + if l.EventLogger != nil { + l.EventLogger.UDPError(l.Conn.RemoteAddr(), l.AuthID, sessionID, err) + } +} diff --git a/core/server/udp.go b/core/server/udp.go new file mode 100644 index 0000000..d60feed --- /dev/null +++ b/core/server/udp.go @@ -0,0 +1,218 @@ +package server + +import ( + "errors" + "math/rand" + "sync" + "time" + + "github.com/quic-go/quic-go" + + "github.com/apernet/hysteria/core/internal/frag" + "github.com/apernet/hysteria/core/internal/protocol" + "github.com/apernet/hysteria/core/internal/utils" +) + +const ( + idleCleanupInterval = 1 * time.Second +) + +type udpSessionManagerIO interface { + ReceiveMessage() (*protocol.UDPMessage, error) + SendMessage([]byte, *protocol.UDPMessage) error + DialUDP(reqAddr string) (UDPConn, error) +} + +type udpSessionManagerEventLogger interface { + New(sessionID uint32, reqAddr string) + Closed(sessionID uint32, err error) +} + +type udpSessionEntry struct { + ID uint32 + Conn UDPConn + D *frag.Defragger + Last *utils.AtomicTime + Closed bool +} + +// Feed feeds a UDP message to the session. +// If the message itself is a complete message, or it completes a fragmented message, +// the message is written to the session's UDP connection, and the number of bytes +// written is returned. +// Otherwise, 0 and nil are returned. +func (e *udpSessionEntry) Feed(msg *protocol.UDPMessage) (int, error) { + e.Last.Set(time.Now()) + dfMsg := e.D.Feed(msg) + if dfMsg == nil { + return 0, nil + } + return e.Conn.WriteTo(dfMsg.Data, dfMsg.Addr) +} + +// ReceiveLoop receives incoming UDP packets, packs them into UDP messages, +// and sends using the provided io. +// Exit and returns error when either the underlying UDP connection returns +// error (e.g. closed), or the provided io returns error when sending. +func (e *udpSessionEntry) ReceiveLoop(io udpSessionManagerIO) error { + udpBuf := make([]byte, protocol.MaxUDPSize) + msgBuf := make([]byte, protocol.MaxUDPSize) + for { + udpN, rAddr, err := e.Conn.ReadFrom(udpBuf) + if err != nil { + return err + } + e.Last.Set(time.Now()) + + msg := &protocol.UDPMessage{ + SessionID: e.ID, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: rAddr, + Data: udpBuf[:udpN], + } + err = sendMessageAutoFrag(io, msgBuf, msg) + if err != nil { + return err + } + } +} + +// sendMessageAutoFrag tries to send a UDP message as a whole first, +// but if it fails due to quic.ErrMessageTooLarge, it tries again by +// fragmenting the message. +func sendMessageAutoFrag(io udpSessionManagerIO, buf []byte, msg *protocol.UDPMessage) error { + err := io.SendMessage(buf, msg) + var errTooLarge quic.ErrMessageTooLarge + if errors.As(err, &errTooLarge) { + // Message too large, try fragmentation + msg.PacketID = uint16(rand.Intn(0xFFFF)) + 1 + fMsgs := frag.FragUDPMessage(msg, int(errTooLarge)) + for _, fMsg := range fMsgs { + err := io.SendMessage(buf, &fMsg) + if err != nil { + return err + } + } + return nil + } else { + return err + } +} + +// udpSessionManager manages the lifecycle of UDP sessions. +// Each UDP session is identified by a SessionID, and corresponds to a UDP connection. +// A UDP session is created when a UDP message with a new SessionID is received. +// Similar to standard NAT, a UDP session is destroyed when no UDP message is received +// for a certain period of time (specified by idleTimeout). +type udpSessionManager struct { + io udpSessionManagerIO + eventLogger udpSessionManagerEventLogger + idleTimeout time.Duration + + mutex sync.Mutex + m map[uint32]*udpSessionEntry + nextID uint32 +} + +func newUDPSessionManager( + io udpSessionManagerIO, + eventLogger udpSessionManagerEventLogger, + idleTimeout time.Duration, +) *udpSessionManager { + return &udpSessionManager{ + io: io, + eventLogger: eventLogger, + idleTimeout: idleTimeout, + m: make(map[uint32]*udpSessionEntry), + } +} + +// Run runs the session manager main loop. +// Exit and returns error when the underlying io returns error (e.g. closed). +func (m *udpSessionManager) Run() error { + stopCh := make(chan struct{}) + go m.idleCleanupLoop(stopCh) + defer close(stopCh) + defer m.cleanup(false) + + for { + msg, err := m.io.ReceiveMessage() + if err != nil { + return err + } + m.feed(msg) + } +} + +func (m *udpSessionManager) idleCleanupLoop(stopCh <-chan struct{}) { + ticker := time.NewTicker(idleCleanupInterval) + defer ticker.Stop() + for { + select { + case <-ticker.C: + m.cleanup(true) + case <-stopCh: + return + } + } +} + +func (m *udpSessionManager) cleanup(idleOnly bool) { + m.mutex.Lock() + defer m.mutex.Unlock() + + now := time.Now() + for sessionID, entry := range m.m { + if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout { + entry.Closed = true + _ = entry.Conn.Close() + m.eventLogger.Closed(sessionID, nil) + delete(m.m, sessionID) + } + } +} + +func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { + m.mutex.Lock() + + entry := m.m[msg.SessionID] + if entry == nil { + // New session + m.eventLogger.New(msg.SessionID, msg.Addr) + conn, err := m.io.DialUDP(msg.Addr) + if err != nil { + m.mutex.Unlock() + m.eventLogger.Closed(msg.SessionID, err) + return + } + entry = &udpSessionEntry{ + ID: msg.SessionID, + Conn: conn, + D: &frag.Defragger{}, + Last: utils.NewAtomicTime(time.Now()), + } + // Start the receive loop for this session + go func() { + err := entry.ReceiveLoop(m.io) + // Receive loop stopped, remove the session + m.mutex.Lock() + if !entry.Closed { + entry.Closed = true + _ = entry.Conn.Close() + m.eventLogger.Closed(entry.ID, err) + delete(m.m, entry.ID) + } + m.mutex.Unlock() + }() + m.m[msg.SessionID] = entry + } + + m.mutex.Unlock() + + // Feed the message to the session + // Feed (send) errors are ignored for now, + // as some are temporary (e.g. invalid address) + _, _ = entry.Feed(msg) +} diff --git a/core/server/udp_test.go b/core/server/udp_test.go new file mode 100644 index 0000000..81844b1 --- /dev/null +++ b/core/server/udp_test.go @@ -0,0 +1,191 @@ +package server + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/apernet/hysteria/core/internal/protocol" + "go.uber.org/goleak" +) + +type echoUDPConnPkt struct { + Data []byte + Addr string + Close bool +} + +type echoUDPConn struct { + PktCh chan echoUDPConnPkt +} + +func (c *echoUDPConn) ReadFrom(b []byte) (int, string, error) { + pkt := <-c.PktCh + if pkt.Close { + return 0, "", errors.New("closed") + } + n := copy(b, pkt.Data) + return n, pkt.Addr, nil +} + +func (c *echoUDPConn) WriteTo(b []byte, addr string) (int, error) { + nb := make([]byte, len(b)) + copy(nb, b) + c.PktCh <- echoUDPConnPkt{ + Data: nb, + Addr: addr, + } + return len(b), nil +} + +func (c *echoUDPConn) Close() error { + c.PktCh <- echoUDPConnPkt{ + Close: true, + } + return nil +} + +type udpsmMockIO struct { + ReceiveCh <-chan *protocol.UDPMessage + SendCh chan<- *protocol.UDPMessage +} + +func (io *udpsmMockIO) ReceiveMessage() (*protocol.UDPMessage, error) { + m := <-io.ReceiveCh + if m == nil { + return nil, errors.New("closed") + } + return m, nil +} + +func (io *udpsmMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error { + nMsg := *msg + nMsg.Data = make([]byte, len(msg.Data)) + copy(nMsg.Data, msg.Data) + io.SendCh <- &nMsg + return nil +} + +func (io *udpsmMockIO) DialUDP(reqAddr string) (UDPConn, error) { + return &echoUDPConn{ + PktCh: make(chan echoUDPConnPkt, 10), + }, nil +} + +type udpsmMockEventNew struct { + SessionID uint32 + ReqAddr string +} + +type udpsmMockEventClosed struct { + SessionID uint32 + Err error +} + +type udpsmMockEventLogger struct { + NewCh chan<- udpsmMockEventNew + ClosedCh chan<- udpsmMockEventClosed +} + +func (l *udpsmMockEventLogger) New(sessionID uint32, reqAddr string) { + l.NewCh <- udpsmMockEventNew{sessionID, reqAddr} +} + +func (l *udpsmMockEventLogger) Closed(sessionID uint32, err error) { + l.ClosedCh <- udpsmMockEventClosed{sessionID, err} +} + +func TestUDPSessionManager(t *testing.T) { + msgReceiveCh := make(chan *protocol.UDPMessage, 10) + msgSendCh := make(chan *protocol.UDPMessage, 10) + io := &udpsmMockIO{ + ReceiveCh: msgReceiveCh, + SendCh: msgSendCh, + } + eventNewCh := make(chan udpsmMockEventNew, 10) + eventClosedCh := make(chan udpsmMockEventClosed, 10) + eventLogger := &udpsmMockEventLogger{ + NewCh: eventNewCh, + ClosedCh: eventClosedCh, + } + sm := newUDPSessionManager(io, eventLogger, 2*time.Second) + go sm.Run() + + ms := []*protocol.UDPMessage{ + { + SessionID: 1234, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:5353", + Data: []byte("hello"), + }, + { + SessionID: 5678, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:9999", + Data: []byte("goodbye"), + }, + { + SessionID: 1234, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:5353", + Data: []byte(" world"), + }, + { + SessionID: 5678, + PacketID: 0, + FragID: 0, + FragCount: 1, + Addr: "example.com:9999", + Data: []byte(" girl"), + }, + } + for _, m := range ms { + msgReceiveCh <- m + } + // New event order should be consistent + newEvent := <-eventNewCh + if newEvent.SessionID != 1234 || newEvent.ReqAddr != "example.com:5353" { + t.Error("unexpected new event value") + } + newEvent = <-eventNewCh + if newEvent.SessionID != 5678 || newEvent.ReqAddr != "example.com:9999" { + t.Error("unexpected new event value") + } + // Message order is not guaranteed + msgMap := make(map[string]bool) + for i := 0; i < 4; i++ { + msg := <-msgSendCh + msgMap[fmt.Sprintf("%d:%s:%s", msg.SessionID, msg.Addr, string(msg.Data))] = true + } + if !(msgMap["1234:example.com:5353:hello"] && + msgMap["5678:example.com:9999:goodbye"] && + msgMap["1234:example.com:5353: world"] && + msgMap["5678:example.com:9999: girl"]) { + t.Error("unexpected message value") + } + // Timeout check + startTime := time.Now() + closedMap := make(map[uint32]bool) + for i := 0; i < 2; i++ { + closedEvent := <-eventClosedCh + closedMap[closedEvent.SessionID] = true + } + if !(closedMap[1234] && closedMap[5678]) { + t.Error("unexpected closed event value", closedMap) + } + if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second { + t.Error("unexpected timeout duration") + } + + // Goroutine leak check + msgReceiveCh <- nil + time.Sleep(1 * time.Second) // Wait for internal routines to exit + goleak.VerifyNone(t) +} diff --git a/extras/outbounds/interface.go b/extras/outbounds/interface.go index 7d34d9d..53143d6 100644 --- a/extras/outbounds/interface.go +++ b/extras/outbounds/interface.go @@ -75,7 +75,7 @@ func (a *PluggableOutboundAdapter) DialTCP(reqAddr string) (net.Conn, error) { }) } -func (a *PluggableOutboundAdapter) ListenUDP() (server.UDPConn, error) { +func (a *PluggableOutboundAdapter) DialUDP() (server.UDPConn, error) { conn, err := a.PluggableOutbound.ListenUDP() if err != nil { return nil, err diff --git a/extras/outbounds/interface_test.go b/extras/outbounds/interface_test.go index d3b4b06..aa2aa12 100644 --- a/extras/outbounds/interface_test.go +++ b/extras/outbounds/interface_test.go @@ -68,10 +68,10 @@ func TestPluggableOutboundAdapter(t *testing.T) { if err != errWrongAddr { t.Fatal("DialTCP with wrong addr should fail, got", err) } - // ListenUDP - uConn, err := adapter.ListenUDP() + // DialUDP + uConn, err := adapter.DialUDP() if err != nil { - t.Fatal("ListenUDP failed", err) + t.Fatal("DialUDP failed", err) } // ReadFrom b := make([]byte, 10)