package forwarder import ( "context" "crypto/tls" "errors" "github.com/lucas-clemente/quic-go" "github.com/tobyxdd/hysteria/internal/utils" "net" "sync" "sync/atomic" ) type QUICClient struct { inboundBytes, outboundBytes uint64 // atomic reconnectMutex sync.Mutex quicSession quic.Session listener net.Listener remoteAddr string name string tlsConfig *tls.Config sendBPS, recvBPS uint64 recvWindowConn, recvWindow uint64 closed bool newCongestion CongestionFactory onServerConnected ServerConnectedCallback onServerError ServerErrorCallback onNewTCPConnection NewTCPConnectionCallback onTCPConnectionClosed TCPConnectionClosedCallback } func NewQUICClient(addr string, remoteAddr string, name string, tlsConfig *tls.Config, sendBPS uint64, recvBPS uint64, recvWindowConn uint64, recvWindow uint64, newCongestion CongestionFactory, onServerConnected ServerConnectedCallback, onServerError ServerErrorCallback, onNewTCPConnection NewTCPConnectionCallback, onTCPConnectionClosed TCPConnectionClosedCallback) (*QUICClient, error) { // Local TCP listener listener, err := net.Listen("tcp", addr) if err != nil { return nil, err } c := &QUICClient{ listener: listener, remoteAddr: remoteAddr, name: name, tlsConfig: tlsConfig, sendBPS: sendBPS, recvBPS: recvBPS, recvWindowConn: recvWindowConn, recvWindow: recvWindow, newCongestion: newCongestion, onServerConnected: onServerConnected, onServerError: onServerError, onNewTCPConnection: onNewTCPConnection, onTCPConnectionClosed: onTCPConnectionClosed, } if err := c.connectToServer(); err != nil { _ = c.listener.Close() return nil, err } go c.acceptLoop() return c, nil } func (c *QUICClient) Close() error { err1 := c.listener.Close() c.reconnectMutex.Lock() err2 := c.quicSession.CloseWithError(closeErrorCodeGeneric, "generic") c.closed = true c.reconnectMutex.Unlock() if err1 != nil { return err1 } return err2 } func (c *QUICClient) Stats() (string, uint64, uint64) { return c.remoteAddr, atomic.LoadUint64(&c.inboundBytes), atomic.LoadUint64(&c.outboundBytes) } func (c *QUICClient) acceptLoop() { for { conn, err := c.listener.Accept() if err != nil { break } go c.handleConn(conn) } } func (c *QUICClient) connectToServer() error { qs, err := quic.DialAddr(c.remoteAddr, c.tlsConfig, &quic.Config{ MaxReceiveStreamFlowControlWindow: c.recvWindowConn, MaxReceiveConnectionFlowControlWindow: c.recvWindow, KeepAlive: true, }) if err != nil { c.onServerError(err) return err } // Control stream ctx, ctxCancel := context.WithTimeout(context.Background(), controlStreamTimeout) ctlStream, err := qs.OpenStreamSync(ctx) ctxCancel() if err != nil { _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream error") c.onServerError(err) return err } banner, cSendBPS, cRecvBPS, err := handleControlStream(qs, ctlStream, c.name, c.sendBPS, c.recvBPS, c.newCongestion) if err != nil { _ = qs.CloseWithError(closeErrorCodeProtocolFailure, "control stream handling error") c.onServerError(err) return err } // All good c.quicSession = qs c.onServerConnected(qs.RemoteAddr(), banner, cSendBPS, cRecvBPS) return nil } func (c *QUICClient) openStreamWithReconnect() (quic.Stream, error) { c.reconnectMutex.Lock() defer c.reconnectMutex.Unlock() if c.closed { return nil, errors.New("client closed") } stream, err := c.quicSession.OpenStream() if err == nil { // All good return stream, nil } // Something is wrong c.onServerError(err) if nErr, ok := err.(net.Error); ok && nErr.Temporary() { // Temporary error, just return return nil, err } // Permanent error, need to reconnect if err := c.connectToServer(); err != nil { // Still error, oops return nil, err } // We are not going to try again even if it still fails the second time stream, err = c.quicSession.OpenStream() if err != nil { c.onServerError(err) } return stream, err } // Negotiate speed, return banner, send & receive speed func handleControlStream(qs quic.Session, stream quic.Stream, name string, sendBPS uint64, recvBPS uint64, newCongestion CongestionFactory) (string, uint64, uint64, error) { err := writeClientSpeedRequest(stream, &ClientSpeedRequest{ Name: name, Speed: &Speed{ SendBps: sendBPS, ReceiveBps: recvBPS, }, }) if err != nil { return "", 0, 0, err } // Response resp, err := readServerSpeedResponse(stream) if err != nil { return "", 0, 0, err } // Set the congestion accordingly if newCongestion != nil { qs.SetCongestion(newCongestion(resp.Speed.ReceiveBps)) } return resp.Banner, resp.Speed.ReceiveBps, resp.Speed.SendBps, nil } func (c *QUICClient) handleConn(conn net.Conn) { c.onNewTCPConnection(conn.RemoteAddr()) defer conn.Close() stream, err := c.openStreamWithReconnect() if err != nil { c.onTCPConnectionClosed(conn.RemoteAddr(), err) return } defer stream.Close() // Pipes errChan := make(chan error, 2) go func() { // TCP to QUIC errChan <- utils.Pipe(conn, stream, &c.outboundBytes) }() go func() { // QUIC to TCP errChan <- utils.Pipe(stream, conn, &c.inboundBytes) }() // We only need the first error err = <-errChan c.onTCPConnectionClosed(conn.RemoteAddr(), err) }