From 97c83dc935ca3940f62c3ebcc267b3bba71b72ed Mon Sep 17 00:00:00 2001 From: Toby Date: Mon, 2 Aug 2021 22:42:25 -0700 Subject: [PATCH] wrappedQUICStream to handle stream close properly --- pkg/core/client.go | 4 +-- pkg/core/server_client.go | 1 + pkg/core/stream.go | 54 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 57 insertions(+), 2 deletions(-) create mode 100644 pkg/core/stream.go diff --git a/pkg/core/client.go b/pkg/core/client.go index 07d0b8d..6e235b4 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -176,7 +176,7 @@ func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) { stream, err := c.quicSession.OpenStream() if err == nil { // All good - return c.quicSession, stream, nil + return c.quicSession, &wrappedQUICStream{stream}, nil } // Something is wrong if nErr, ok := err.(net.Error); ok && nErr.Temporary() { @@ -190,7 +190,7 @@ func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) { } // We are not going to try again even if it still fails the second time stream, err = c.quicSession.OpenStream() - return c.quicSession, stream, nil + return c.quicSession, &wrappedQUICStream{stream}, err } func (c *Client) DialTCP(addr string) (net.Conn, error) { diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 31616cb..50f49c4 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -85,6 +85,7 @@ func (c *serverClient) Run() { c.ConnGauge.Inc() } go func() { + stream := &wrappedQUICStream{stream} c.handleStream(stream) _ = stream.Close() if c.ConnGauge != nil { diff --git a/pkg/core/stream.go b/pkg/core/stream.go new file mode 100644 index 0000000..8ace4a1 --- /dev/null +++ b/pkg/core/stream.go @@ -0,0 +1,54 @@ +package core + +import ( + "context" + "github.com/lucas-clemente/quic-go" + "time" +) + +// Handle stream close properly +// Ref: https://github.com/libp2p/go-libp2p-quic-transport/blob/master/stream.go +type wrappedQUICStream struct { + Stream quic.Stream +} + +func (s *wrappedQUICStream) StreamID() quic.StreamID { + return s.Stream.StreamID() +} + +func (s *wrappedQUICStream) Read(p []byte) (n int, err error) { + return s.Stream.Read(p) +} + +func (s *wrappedQUICStream) CancelRead(code quic.StreamErrorCode) { + s.Stream.CancelRead(code) +} + +func (s *wrappedQUICStream) SetReadDeadline(t time.Time) error { + return s.Stream.SetReadDeadline(t) +} + +func (s *wrappedQUICStream) Write(p []byte) (n int, err error) { + return s.Stream.Write(p) +} + +func (s *wrappedQUICStream) Close() error { + s.Stream.CancelRead(0) + return s.Stream.Close() +} + +func (s *wrappedQUICStream) CancelWrite(code quic.StreamErrorCode) { + s.Stream.CancelWrite(code) +} + +func (s *wrappedQUICStream) Context() context.Context { + return s.Stream.Context() +} + +func (s *wrappedQUICStream) SetWriteDeadline(t time.Time) error { + return s.Stream.SetWriteDeadline(t) +} + +func (s *wrappedQUICStream) SetDeadline(t time.Time) error { + return s.Stream.SetDeadline(t) +}