fix(RPT-03-003): add socket dial timeouts to SOCKS5Client

This commit is contained in:
Toby 2022-11-30 23:10:30 -08:00
parent fcc2f06bc1
commit fc28c01980
3 changed files with 33 additions and 33 deletions

View File

@ -149,14 +149,8 @@ func server(config *serverConfig) {
} }
// SOCKS5 outbound // SOCKS5 outbound
if config.SOCKS5Outbound.Server != "" { if config.SOCKS5Outbound.Server != "" {
ob, err := transport.NewSOCKS5Client(config.SOCKS5Outbound.Server, transport.DefaultServerTransport.SOCKS5Client = transport.NewSOCKS5Client(config.SOCKS5Outbound.Server,
config.SOCKS5Outbound.User, config.SOCKS5Outbound.Password, 10*time.Second) config.SOCKS5Outbound.User, config.SOCKS5Outbound.Password)
if err != nil {
logrus.WithFields(logrus.Fields{
"error": err,
}).Fatal("Failed to initialize SOCKS5 outbound")
}
transport.DefaultServerTransport.SOCKS5Client = ob
} }
// Bind outbound // Bind outbound
if config.BindOutbound.Device != "" { if config.BindOutbound.Device != "" {

View File

@ -80,7 +80,11 @@ func (st *ServerTransport) ResolveIPAddr(address string) (*net.IPAddr, bool, err
func (st *ServerTransport) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { func (st *ServerTransport) DialTCP(raddr *AddrEx) (*net.TCPConn, error) {
if st.SOCKS5Client != nil { if st.SOCKS5Client != nil {
return st.SOCKS5Client.DialTCP(raddr) conn, err := st.SOCKS5Client.DialTCP(raddr)
if err != nil {
return nil, err
}
return conn.(*net.TCPConn), nil
} else { } else {
conn, err := st.Dialer.Dial("tcp", raddr.String()) conn, err := st.Dialer.Dial("tcp", raddr.String())
if err != nil { if err != nil {

View File

@ -10,27 +10,29 @@ import (
"github.com/txthinking/socks5" "github.com/txthinking/socks5"
) )
const (
negTimeout = 8 * time.Second
)
type SOCKS5Client struct { type SOCKS5Client struct {
ServerTCPAddr *net.TCPAddr Dialer *net.Dialer
ServerAddr string
Username string Username string
Password string Password string
NegTimeout time.Duration
} }
func NewSOCKS5Client(serverAddr string, username string, password string, negTimeout time.Duration) (*SOCKS5Client, error) { func NewSOCKS5Client(serverAddr string, username string, password string) *SOCKS5Client {
tcpAddr, err := net.ResolveTCPAddr("tcp", serverAddr)
if err != nil {
return nil, err
}
return &SOCKS5Client{ return &SOCKS5Client{
ServerTCPAddr: tcpAddr, Dialer: &net.Dialer{
Timeout: 8 * time.Second,
},
ServerAddr: serverAddr,
Username: username, Username: username,
Password: password, Password: password,
NegTimeout: negTimeout, }
}, nil
} }
func (c *SOCKS5Client) negotiate(conn *net.TCPConn) error { func (c *SOCKS5Client) negotiate(conn net.Conn) error {
m := []byte{socks5.MethodNone} m := []byte{socks5.MethodNone}
if c.Username != "" && c.Password != "" { if c.Username != "" && c.Password != "" {
m = append(m, socks5.MethodUsernamePassword) m = append(m, socks5.MethodUsernamePassword)
@ -63,7 +65,7 @@ func (c *SOCKS5Client) negotiate(conn *net.TCPConn) error {
return nil return nil
} }
func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Reply, error) { func (c *SOCKS5Client) request(conn net.Conn, r *socks5.Request) (*socks5.Reply, error) {
if _, err := r.WriteTo(conn); err != nil { if _, err := r.WriteTo(conn); err != nil {
return nil, err return nil, err
} }
@ -74,12 +76,12 @@ func (c *SOCKS5Client) request(conn *net.TCPConn, r *socks5.Request) (*socks5.Re
return reply, nil return reply, nil
} }
func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) { func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (net.Conn, error) {
conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) conn, err := c.Dialer.Dial("tcp", c.ServerAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := conn.SetDeadline(time.Now().Add(c.NegTimeout)); err != nil { if err := conn.SetDeadline(time.Now().Add(negTimeout)); err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
@ -112,11 +114,11 @@ func (c *SOCKS5Client) DialTCP(raddr *AddrEx) (*net.TCPConn, error) {
} }
func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) { func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) {
conn, err := net.DialTCP("tcp", nil, c.ServerTCPAddr) conn, err := c.Dialer.Dial("tcp", c.ServerAddr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if err := conn.SetDeadline(time.Now().Add(c.NegTimeout)); err != nil { if err := conn.SetDeadline(time.Now().Add(negTimeout)); err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
@ -145,7 +147,7 @@ func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
} }
udpConn, err := net.DialUDP("udp", nil, udpRelayAddr) udpConn, err := c.Dialer.Dial("udp", udpRelayAddr.String())
if err != nil { if err != nil {
_ = conn.Close() _ = conn.Close()
return nil, err return nil, err
@ -159,8 +161,8 @@ func (c *SOCKS5Client) ListenUDP() (STPacketConn, error) {
} }
type socks5UDPConn struct { type socks5UDPConn struct {
tcpConn *net.TCPConn tcpConn net.Conn
udpConn *net.UDPConn udpConn net.Conn
} }
func (c *socks5UDPConn) hold() { func (c *socks5UDPConn) hold() {