From fc28c0198013fa37e746667dcbdde2d8c89cba94 Mon Sep 17 00:00:00 2001 From: Toby Date: Wed, 30 Nov 2022 23:10:30 -0800 Subject: [PATCH] fix(RPT-03-003): add socket dial timeouts to SOCKS5Client --- app/cmd/server.go | 10 ++------ core/transport/server.go | 6 ++++- core/transport/socks5.go | 50 +++++++++++++++++++++------------------- 3 files changed, 33 insertions(+), 33 deletions(-) diff --git a/app/cmd/server.go b/app/cmd/server.go index f4c533b..c54cbc5 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -149,14 +149,8 @@ func server(config *serverConfig) { } // SOCKS5 outbound if config.SOCKS5Outbound.Server != "" { - ob, err := transport.NewSOCKS5Client(config.SOCKS5Outbound.Server, - config.SOCKS5Outbound.User, config.SOCKS5Outbound.Password, 10*time.Second) - if err != nil { - logrus.WithFields(logrus.Fields{ - "error": err, - }).Fatal("Failed to initialize SOCKS5 outbound") - } - transport.DefaultServerTransport.SOCKS5Client = ob + transport.DefaultServerTransport.SOCKS5Client = transport.NewSOCKS5Client(config.SOCKS5Outbound.Server, + config.SOCKS5Outbound.User, config.SOCKS5Outbound.Password) } // Bind outbound if config.BindOutbound.Device != "" { diff --git a/core/transport/server.go b/core/transport/server.go index dee4af2..42a5025 100644 --- a/core/transport/server.go +++ b/core/transport/server.go @@ -80,7 +80,11 @@ func (st *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, bool, err func (st *ServerTransport) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { if st.SOCKS5Client != nil { - return st.SOCKS5Client.DialTCP(raddr) + conn, err := st.SOCKS5Client.DialTCP(raddr) + if err != nil { + return nil, err + } + return conn.(*net.TCPConn), nil } else { conn, err := st.Dialer.Dial("tcp", raddr.String()) if err != nil { diff --git a/core/transport/socks5.go b/core/transport/socks5.go index ca45ce1..a390b62 100644 --- a/core/transport/socks5.go +++ b/core/transport/socks5.go @@ -10,27 +10,29 @@ import ( "github.com/txthinking/socks5" ) +const ( + negTimeout = 8 * time.Second +) + type SOCKS5Client struct { - ServerTCPAddr *net.TCPAddr - Username string - Password string - NegTimeout time.Duration + Dialer *net.Dialer + ServerAddr string + Username string + Password string } -func NewSOCKS5Client(serverAddr string, username string, password string, negTimeout time.Duration) (*SOCKS5Client, error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", serverAddr) - if err != nil { - return nil, err - } +func NewSOCKS5Client(serverAddr string, username string, password string) *SOCKS5Client { return &SOCKS5Client{ - ServerTCPAddr: tcpAddr, - Username: username, - Password: password, - NegTimeout: negTimeout, - }, nil + Dialer: &net.Dialer{ + Timeout: 8 * time.Second, + }, + ServerAddr: serverAddr, + Username: username, + Password: password, + } } -func (c *SOCKS5Client) negotiate(conn *net.TCPConn) error { +func (c *SOCKS5Client) negotiate(conn net.Conn) error { m := []byte{socks5.MethodNone} if c.Username != "" && c.Password != "" { m = append(m, socks5.MethodUsernamePassword) @@ -63,7 +65,7 @@ func (c *SOCKS5Client) negotiate(conn *net.TCPConn) error { return nil } -func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Reply, error) { +func (c *SOCKS5Client) request(conn net.Conn, r *socks5.Request) (*socks5.Reply, error) { if _, err := r.WriteTo(conn); err != nil { return nil, err } @@ -74,12 +76,12 @@ func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Re return reply, nil } -func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { - conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) +func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (net.Conn, error) { + conn, err := c.Dialer.Dial("tcp", c.ServerAddr) if err != nil { return nil, err } - if err := conn.SetDeadline(time.Now().Add(c.NegTimeout)); err != nil { + if err := conn.SetDeadline(time.Now().Add(negTimeout)); err != nil { _ = conn.Close() return nil, err } @@ -112,11 +114,11 @@ func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { } func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) { - conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) + conn, err := c.Dialer.Dial("tcp", c.ServerAddr) if err != nil { return nil, err } - if err := conn.SetDeadline(time.Now().Add(c.NegTimeout)); err != nil { + if err := conn.SetDeadline(time.Now().Add(negTimeout)); err != nil { _ = conn.Close() return nil, err } @@ -145,7 +147,7 @@ func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) { _ = conn.Close() return nil, err } - udpConn, err := net.DialUDP("udp", nil, udpRelayAddr) + udpConn, err := c.Dialer.Dial("udp", udpRelayAddr.String()) if err != nil { _ = conn.Close() return nil, err @@ -159,8 +161,8 @@ func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) { } type socks5UDPConn struct { - tcpConn *net.TCPConn - udpConn *net.UDPConn + tcpConn net.Conn + udpConn net.Conn } func (c *socks5UDPConn) hold() {