diff --git a/core/client/client.go b/core/client/client.go index 69a9a68..11fe2f6 100644 --- a/core/client/client.go +++ b/core/client/client.go @@ -149,7 +149,7 @@ func (c *clientImpl) openStream() (quic.Stream, error) { func (c *clientImpl) TCP(addr string) (net.Conn, error) { stream, err := c.openStream() if err != nil { - if netErr, ok := err.(net.Error); ok && !netErr.Temporary() { + if isQUICClosedError(err) { // Connection is dead return nil, coreErrs.ClosedError{} } @@ -203,6 +203,17 @@ func (c *clientImpl) Close() error { return nil } +// isQUICClosedError checks if the error returned by OpenStream +// indicates that the QUIC connection is permanently closed. +func isQUICClosedError(err error) bool { + netErr, ok := err.(net.Error) + if !ok { + return true + } else { + return !netErr.Temporary() + } +} + type tcpConn struct { Orig quic.Stream PseudoLocalAddr net.Addr diff --git a/core/internal/integration_tests/close_test.go b/core/internal/integration_tests/close_test.go index 9216ede..9bca65b 100644 --- a/core/internal/integration_tests/close_test.go +++ b/core/internal/integration_tests/close_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/mock" "github.com/apernet/hysteria/core/client" + "github.com/apernet/hysteria/core/errors" "github.com/apernet/hysteria/core/internal/integration_tests/mocks" "github.com/apernet/hysteria/core/server" ) @@ -172,5 +173,74 @@ func TestClientServerUDPIdleTimeout(t *testing.T) { }) eventLogger.EXPECT().UDPError(mock.Anything, mock.Anything, uint32(1), nil).Once() time.Sleep(3 * time.Second) - mock.AssertExpectationsForObjects(t, sobConn, serverOb, eventLogger) +} + +// TestClientServerClientShutdown tests whether the server can handle the client's shutdown correctly. +func TestClientServerClientShutdown(t *testing.T) { + // Create server + udpConn, udpAddr, err := serverConn() + assert.NoError(t, err) + auth := mocks.NewMockAuthenticator(t) + auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") + eventLogger := mocks.NewMockEventLogger(t) + eventLogger.EXPECT().Connect(mock.Anything, "nobody", mock.Anything).Once() + s, err := server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Authenticator: auth, + EventLogger: eventLogger, + }) + assert.NoError(t, err) + defer s.Close() + go s.Serve() + + // Create client + c, err := client.NewClient(&client.Config{ + ServerAddr: udpAddr, + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + }) + assert.NoError(t, err) + + // Close the client - expect disconnect event on the server side. + // Since client.Close() sends HTTP3 ErrCodeNoError, the error should be nil. + eventLogger.EXPECT().Disconnect(mock.Anything, "nobody", nil).Once() + _ = c.Close() + time.Sleep(1 * time.Second) +} + +// TestClientServerServerShutdown tests whether the client can handle the server's shutdown correctly. +func TestClientServerServerShutdown(t *testing.T) { + // Create server + udpConn, udpAddr, err := serverConn() + assert.NoError(t, err) + auth := mocks.NewMockAuthenticator(t) + auth.EXPECT().Authenticate(mock.Anything, mock.Anything, mock.Anything).Return(true, "nobody") + s, err := server.NewServer(&server.Config{ + TLSConfig: serverTLSConfig(), + Conn: udpConn, + Authenticator: auth, + }) + assert.NoError(t, err) + go s.Serve() + + // Create client + c, err := client.NewClient(&client.Config{ + ServerAddr: udpAddr, + TLSConfig: client.TLSConfig{InsecureSkipVerify: true}, + }) + assert.NoError(t, err) + + // Close the server - expect the client to return ClosedError for both TCP & UDP calls. + _ = s.Close() + time.Sleep(1 * time.Second) + + _, err = c.TCP("whatever") + _, ok := err.(errors.ClosedError) + assert.True(t, ok) + + _, err = c.UDP() + _, ok = err.(errors.ClosedError) + assert.True(t, ok) + + assert.NoError(t, c.Close()) }