diff --git a/core/client/client.go b/core/client/client.go index 2f35a07..153ffbd 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -133,14 +133,19 @@ func (c *clientImpl) connect() error { c.conn = conn if udpEnabled { c.udpSM = newUDPSessionManager(&udpIOImpl{Conn: conn}) - go func() { - c.udpSM.Run() - // TODO: Mark connection as closed - }() } return nil } +// openStream wraps the stream with QStream, which handles Close() properly +func (c *clientImpl) openStream() (quic.Stream, error) { + stream, err := c.conn.OpenStream() + if err != nil { + return nil, err + } + return &utils.QStream{Stream: stream}, nil +} + func (c *clientImpl) DialTCP(addr string) (net.Conn, error) { stream, err := c.openStream() if err != nil { @@ -272,12 +277,3 @@ func (io *udpIOImpl) SendMessage(buf []byte, msg *protocol.UDPMessage) error { } return io.Conn.SendMessage(buf[:msgN]) } - -// openStream wraps the stream with QStream, which handles Close() properly -func (c *clientImpl) openStream() (quic.Stream, error) { - stream, err := c.conn.OpenStream() - if err != nil { - return nil, err - } - return &utils.QStream{Stream: stream}, nil -} diff --git a/core/client/udp.go b/core/client/udp.go index 6cca1af..762f947 100644 --- a/core/client/udp.go +++ b/core/client/udp.go @@ -6,6 +6,7 @@ import ( "math/rand" "sync" + coreErrs "github.com/apernet/hysteria/core/errors" "github.com/apernet/hysteria/core/internal/frag" "github.com/apernet/hysteria/core/internal/protocol" "github.com/quic-go/quic-go" @@ -86,21 +87,22 @@ type udpSessionManager struct { mutex sync.Mutex m map[uint32]*udpConn nextID uint32 + + closed bool } func newUDPSessionManager(io udpIO) *udpSessionManager { - return &udpSessionManager{ + m := &udpSessionManager{ io: io, m: make(map[uint32]*udpConn), nextID: 1, } + go m.run() + return m } -// Run runs the session manager main loop. -// Exit and returns error when the underlying io returns error (e.g. closed). -func (m *udpSessionManager) Run() error { - defer m.cleanup() - +func (m *udpSessionManager) run() error { + defer m.closeCleanup() for { msg, err := m.io.ReceiveMessage() if err != nil { @@ -110,13 +112,14 @@ func (m *udpSessionManager) Run() error { } } -func (m *udpSessionManager) cleanup() { +func (m *udpSessionManager) closeCleanup() { m.mutex.Lock() defer m.mutex.Unlock() for _, conn := range m.m { m.close(conn) } + m.closed = true } func (m *udpSessionManager) feed(msg *protocol.UDPMessage) { @@ -142,6 +145,10 @@ func (m *udpSessionManager) NewUDP() (HyUDPConn, error) { m.mutex.Lock() defer m.mutex.Unlock() + if m.closed { + return nil, coreErrs.ClosedError{} + } + id := m.nextID m.nextID++ diff --git a/core/errors/errors.go b/core/errors/errors.go index 094ce1f..7643ee4 100644 --- a/core/errors/errors.go +++ b/core/errors/errors.go @@ -47,6 +47,13 @@ func (c DialError) Error() string { return "dial error: " + c.Message } +// ClosedError is returned when the client attempts to use a closed connection. +type ClosedError struct{} + +func (c ClosedError) Error() string { + return "connection closed" +} + // ProtocolError is returned when the server/client runs into an unexpected // or malformed request/response/message. type ProtocolError struct {