diff --git a/core/server/config.go b/core/server/config.go index d7fb4bd..3eb183d 100644 --- a/core/server/config.go +++ b/core/server/config.go @@ -103,13 +103,12 @@ type QUICConfig struct { } // Outbound provides the implementation of how the server should connect to remote servers. -// Even though it's called DialUDP, outbound implementations do not necessarily have to -// return a "connected" UDP socket that can only send and receive from reqAddr. It's the -// address of the first packet to be sent. -// It's perfectly fine to have a "full-cone" implementation for UDP. +// Although UDP includes a reqAddr, the implementation does not necessarily have to use it +// to make a "connected" UDP connection that does not accept packets from other addresses. +// In fact, the default implementation simply uses net.ListenUDP for a "full-cone" behavior. type Outbound interface { - DialTCP(reqAddr string) (net.Conn, error) - DialUDP(reqAddr string) (UDPConn, error) + TCP(reqAddr string) (net.Conn, error) + UDP(reqAddr string) (UDPConn, error) } // UDPConn is like net.PacketConn, but uses string for addresses. @@ -125,11 +124,11 @@ var defaultOutboundDialer = net.Dialer{ Timeout: 10 * time.Second, } -func (o *defaultOutbound) DialTCP(reqAddr string) (net.Conn, error) { +func (o *defaultOutbound) TCP(reqAddr string) (net.Conn, error) { return defaultOutboundDialer.Dial("tcp", reqAddr) } -func (o *defaultOutbound) DialUDP(reqAddr string) (UDPConn, error) { +func (o *defaultOutbound) UDP(reqAddr string) (UDPConn, error) { conn, err := net.ListenUDP("udp", nil) if err != nil { return nil, err diff --git a/core/server/server.go b/core/server/server.go index f4dfede..f780540 100644 --- a/core/server/server.go +++ b/core/server/server.go @@ -192,7 +192,7 @@ func (h *h3sHandler) handleTCPRequest(stream quic.Stream) { h.config.EventLogger.TCPRequest(h.conn.RemoteAddr(), h.authID, reqAddr) } // Dial target - tConn, err := h.config.Outbound.DialTCP(reqAddr) + tConn, err := h.config.Outbound.TCP(reqAddr) if err != nil { _ = protocol.WriteTCPResponse(stream, false, err.Error()) _ = stream.Close() @@ -280,8 +280,8 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { return io.Conn.SendMessage(buf[:msgN]) } -func (io *udpIOImpl) DialUDP(reqAddr string) (UDPConn, error) { - return io.Outbound.DialUDP(reqAddr) +func (io *udpIOImpl) UDP(reqAddr string) (UDPConn, error) { + return io.Outbound.UDP(reqAddr) } type udpEventLoggerImpl struct { @@ -296,7 +296,7 @@ func (l *udpEventLoggerImpl) New(sessionID uint32, reqAddr string) { } } -func (l *udpEventLoggerImpl) Closed(sessionID uint32, err error) { +func (l *udpEventLoggerImpl) Close(sessionID uint32, err error) { if l.EventLogger != nil { l.EventLogger.UDPError(l.Conn.RemoteAddr(), l.AuthID, sessionID, err) } diff --git a/core/server/udp.go b/core/server/udp.go index b0b15b5..ae98b80 100644 --- a/core/server/udp.go +++ b/core/server/udp.go @@ -20,12 +20,12 @@ const ( type udpIO interface { ReceiveMessage() (*protocol.UDPMessage, error) SendMessage([]byte, *protocol.UDPMessage) error - DialUDP(reqAddr string) (UDPConn, error) + UDP(reqAddr string) (UDPConn, error) } type udpEventLogger interface { New(sessionID uint32, reqAddr string) - Closed(sessionID uint32, err error) + Close(sessionID uint32, err error) } type udpSessionEntry struct { @@ -164,7 +164,7 @@ func (m *udpSessionManager) cleanup(idleOnly bool) { if !idleOnly || now.Sub(entry.Last.Get()) > m.idleTimeout { entry.Closed = true _ = entry.Conn.Close() - m.eventLogger.Closed(sessionID, nil) + m.eventLogger.Close(sessionID, nil) delete(m.m, sessionID) } } @@ -177,10 +177,10 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { if entry == nil { // New session m.eventLogger.New(msg.SessionID, msg.Addr) - conn, err := m.io.DialUDP(msg.Addr) + conn, err := m.io.UDP(msg.Addr) if err != nil { m.mutex.Unlock() - m.eventLogger.Closed(msg.SessionID, err) + m.eventLogger.Close(msg.SessionID, err) return } entry = &udpSessionEntry{ @@ -197,7 +197,7 @@ func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { if !entry.Closed { entry.Closed = true _ = entry.Conn.Close() - m.eventLogger.Closed(entry.ID, err) + m.eventLogger.Close(entry.ID, err) delete(m.m, entry.ID) } m.mutex.Unlock() diff --git a/core/server/udp_test.go b/core/server/udp_test.go index e7d7463..9d46eba 100644 --- a/core/server/udp_test.go +++ b/core/server/udp_test.go @@ -67,7 +67,7 @@ func (io *udpMockIO) SendMessage(buf []byte, msg *protocol.UDPMessage) error { return nil } -func (io *udpMockIO) DialUDP(reqAddr string) (UDPConn, error) { +func (io *udpMockIO) UDP(reqAddr string) (UDPConn, error) { return &echoUDPConn{ PktCh: make(chan echoUDPConnPkt, 10), }, nil @@ -78,22 +78,22 @@ type udpMockEventNew struct { ReqAddr string } -type udpMockEventClosed struct { +type udpMockEventClose struct { SessionID uint32 Err error } type udpMockEventLogger struct { - NewCh chan<- udpMockEventNew - ClosedCh chan<- udpMockEventClosed + NewCh chan<- udpMockEventNew + CloseCh chan<- udpMockEventClose } func (l *udpMockEventLogger) New(sessionID uint32, reqAddr string) { l.NewCh <- udpMockEventNew{sessionID, reqAddr} } -func (l *udpMockEventLogger) Closed(sessionID uint32, err error) { - l.ClosedCh <- udpMockEventClosed{sessionID, err} +func (l *udpMockEventLogger) Close(sessionID uint32, err error) { + l.CloseCh <- udpMockEventClose{sessionID, err} } func TestUDPSessionManager(t *testing.T) { @@ -104,10 +104,10 @@ func TestUDPSessionManager(t *testing.T) { SendCh: msgSendCh, } eventNewCh := make(chan udpMockEventNew, 10) - eventClosedCh := make(chan udpMockEventClosed, 10) + eventCloseCh := make(chan udpMockEventClose, 10) eventLogger := &udpMockEventLogger{ - NewCh: eventNewCh, - ClosedCh: eventClosedCh, + NewCh: eventNewCh, + CloseCh: eventCloseCh, } sm := newUDPSessionManager(io, eventLogger, 2*time.Second) go sm.Run() @@ -172,13 +172,13 @@ func TestUDPSessionManager(t *testing.T) { } // Timeout check startTime := time.Now() - closedMap := make(map[uint32]bool) + closeMap := make(map[uint32]bool) for i := 0; i < 2; i++ { - closedEvent := <-eventClosedCh - closedMap[closedEvent.SessionID] = true + closeEvent := <-eventCloseCh + closeMap[closeEvent.SessionID] = true } - if !(closedMap[1234] && closedMap[5678]) { - t.Error("unexpected closed event value", closedMap) + if !(closeMap[1234] && closeMap[5678]) { + t.Error("unexpected close event value") } if time.Since(startTime) < 2*time.Second || time.Since(startTime) > 4*time.Second { t.Error("unexpected timeout duration")