From 8b0a157e0be06b056537a9f7e4ce359b839ee938 Mon Sep 17 00:00:00 2001 From: Toby Date: Mon, 24 Oct 2022 22:47:12 -0700 Subject: [PATCH] chore: hy client should not force UDP addr for quic Dial --- pkg/core/client.go | 24 ++++++++---- pkg/transport/pktconns/funcs.go | 68 +++++++++++++++++++++++---------- 2 files changed, 63 insertions(+), 29 deletions(-) diff --git a/pkg/core/client.go b/pkg/core/client.go index d94f83a..397bd7c 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -25,7 +25,9 @@ import ( var ErrClosed = errors.New("closed") type Client struct { - serverAddr string + serverAddr string + serverName string // QUIC SNI + sendBPS, recvBPS uint64 auth []byte @@ -50,8 +52,18 @@ func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig pktConnFunc pktconns.ClientPacketConnFunc, sendBPS uint64, recvBPS uint64, quicReconnectFunc func(err error), ) (*Client, error) { quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud.DisablePathMTUDiscovery + // QUIC wants server name, but our serverAddr is usually host:port, + // so we try to extract it from serverAddr. + serverName, _, err := net.SplitHostPort(serverAddr) + if err != nil { + // It's possible that we have some weird serverAddr combined with weird PacketConn implementation, + // that doesn't follow the standard host:port format. So it's ok if we run into error here. + // Server name should be set in tlsConfig in that case. + serverName = "" + } c := &Client{ serverAddr: serverAddr, + serverName: serverName, sendBPS: sendBPS, recvBPS: recvBPS, auth: auth, @@ -75,16 +87,12 @@ func (c *Client) connect() error { _ = c.pktConn.Close() } // New connection - pktConn, err := c.pktConnFunc(c.serverAddr) + pktConn, sAddr, err := c.pktConnFunc(c.serverAddr) if err != nil { return err } - serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr) - if err != nil { - _ = pktConn.Close() - return err - } - quicConn, err := quic.Dial(pktConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) + // Dial QUIC + quicConn, err := quic.Dial(pktConn, sAddr, c.serverName, c.tlsConfig, c.quicConfig) if err != nil { _ = pktConn.Close() return err diff --git a/pkg/transport/pktconns/funcs.go b/pkg/transport/pktconns/funcs.go index 4658a3c..a7b55bc 100644 --- a/pkg/transport/pktconns/funcs.go +++ b/pkg/transport/pktconns/funcs.go @@ -10,7 +10,7 @@ import ( ) type ( - ClientPacketConnFunc func(server string) (net.PacketConn, error) + ClientPacketConnFunc func(server string) (net.PacketConn, net.Addr, error) ServerPacketConnFunc func(listen string) (net.PacketConn, error) ) @@ -21,55 +21,81 @@ type ( func NewClientUDPConnFunc(obfsPassword string) ClientPacketConnFunc { if obfsPassword == "" { - return func(server string) (net.PacketConn, error) { - return net.ListenUDP("udp", nil) + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveUDPAddr("udp", server) + if err != nil { + return nil, nil, err + } + udpConn, err := net.ListenUDP("udp", nil) + return udpConn, sAddr, err } } else { - return func(server string) (net.PacketConn, error) { - ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveUDPAddr("udp", server) + if err != nil { + return nil, nil, err + } udpConn, err := net.ListenUDP("udp", nil) if err != nil { - return nil, err + return nil, nil, err } - return udp.NewObfsUDPConn(udpConn, ob), nil + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return udp.NewObfsUDPConn(udpConn, ob), sAddr, nil } } } func NewClientWeChatConnFunc(obfsPassword string) ClientPacketConnFunc { if obfsPassword == "" { - return func(server string) (net.PacketConn, error) { + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveUDPAddr("udp", server) + if err != nil { + return nil, nil, err + } udpConn, err := net.ListenUDP("udp", nil) if err != nil { - return nil, err + return nil, nil, err } - return wechat.NewObfsWeChatUDPConn(udpConn, nil), nil + return wechat.NewObfsWeChatUDPConn(udpConn, nil), sAddr, nil } } else { - return func(server string) (net.PacketConn, error) { - ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveUDPAddr("udp", server) + if err != nil { + return nil, nil, err + } udpConn, err := net.ListenUDP("udp", nil) if err != nil { - return nil, err + return nil, nil, err } - return wechat.NewObfsWeChatUDPConn(udpConn, ob), nil + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return wechat.NewObfsWeChatUDPConn(udpConn, ob), sAddr, nil } } } func NewClientFakeTCPConnFunc(obfsPassword string) ClientPacketConnFunc { if obfsPassword == "" { - return func(server string) (net.PacketConn, error) { - return faketcp.Dial("tcp", server) + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveTCPAddr("tcp", server) + if err != nil { + return nil, nil, err + } + fTCPConn, err := faketcp.Dial("tcp", server) + return fTCPConn, sAddr, err } } else { - return func(server string) (net.PacketConn, error) { - ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) - fakeTCPConn, err := faketcp.Dial("tcp", server) + return func(server string) (net.PacketConn, net.Addr, error) { + sAddr, err := net.ResolveTCPAddr("tcp", server) if err != nil { - return nil, err + return nil, nil, err } - return faketcp.NewObfsFakeTCPConn(fakeTCPConn, ob), nil + fTCPConn, err := faketcp.Dial("tcp", server) + if err != nil { + return nil, nil, err + } + ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) + return faketcp.NewObfsFakeTCPConn(fTCPConn, ob), sAddr, nil } } }