From 244d0d43a95c0643665e19d617aa200015361145 Mon Sep 17 00:00:00 2001 From: Toby Date: Tue, 2 Mar 2021 16:25:36 -0800 Subject: [PATCH] PipePairWithTimeout --- pkg/relay/relay.go | 50 +--------------------------------------- pkg/socks5/server.go | 54 +++----------------------------------------- pkg/utils/pipe.go | 49 ++++++++++++++++++++++++++++++++++++++++ 3 files changed, 53 insertions(+), 100 deletions(-) diff --git a/pkg/relay/relay.go b/pkg/relay/relay.go index ca89c28..ded126c 100644 --- a/pkg/relay/relay.go +++ b/pkg/relay/relay.go @@ -3,7 +3,6 @@ package relay import ( "github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/utils" - "io" "net" "time" ) @@ -58,55 +57,8 @@ func (r *Relay) ListenAndServe() error { return } defer rc.Close() - err = pipePair(c, rc, r.Timeout) + err = utils.PipePairWithTimeout(c, rc, r.Timeout) r.ErrorFunc(c.RemoteAddr(), err) }(c) } } - -func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, timeout time.Duration) error { - errChan := make(chan error, 2) - // TCP to stream - go func() { - buf := make([]byte, utils.PipeBufferSize) - for { - if timeout != 0 { - _ = conn.SetDeadline(time.Now().Add(timeout)) - } - rn, err := conn.Read(buf) - if rn > 0 { - _, err := stream.Write(buf[:rn]) - if err != nil { - errChan <- err - return - } - } - if err != nil { - errChan <- err - return - } - } - }() - // Stream to TCP - go func() { - buf := make([]byte, utils.PipeBufferSize) - for { - rn, err := stream.Read(buf) - if rn > 0 { - _, err := conn.Write(buf[:rn]) - if err != nil { - errChan <- err - return - } - if timeout != 0 { - _ = conn.SetDeadline(time.Now().Add(timeout)) - } - } - if err != nil { - errChan <- err - return - } - } - }() - return <-errChan -} diff --git a/pkg/socks5/server.go b/pkg/socks5/server.go index 094e766..cadff7d 100644 --- a/pkg/socks5/server.go +++ b/pkg/socks5/server.go @@ -7,7 +7,6 @@ import ( "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/utils" - "io" "strconv" ) @@ -173,7 +172,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { } defer rc.Close() _ = sendReply(c, socks5.RepSuccess) - closeErr = pipePair(c, rc, s.TCPTimeout) + closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout) return nil case acl.ActionProxy: rc, err := s.HyClient.DialTCP(addr) @@ -184,7 +183,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { } defer rc.Close() _ = sendReply(c, socks5.RepSuccess) - closeErr = pipePair(c, rc, s.TCPTimeout) + closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout) return nil case acl.ActionBlock: _ = sendReply(c, socks5.RepHostUnreachable) @@ -199,7 +198,7 @@ func (s *Server) handleTCP(c *net.TCPConn, r *socks5.Request) error { } defer rc.Close() _ = sendReply(c, socks5.RepSuccess) - closeErr = pipePair(c, rc, s.TCPTimeout) + closeErr = utils.PipePairWithTimeout(c, rc, s.TCPTimeout) return nil default: _ = sendReply(c, socks5.RepServerFailure) @@ -223,50 +222,3 @@ func parseRequestAddress(r *socks5.Request) (domain string, ip net.IP, port stri return "", r.DstAddr, p, net.JoinHostPort(net.IP(r.DstAddr).String(), p) } } - -func pipePair(conn *net.TCPConn, stream io.ReadWriteCloser, timeout time.Duration) error { - errChan := make(chan error, 2) - // TCP to stream - go func() { - buf := make([]byte, utils.PipeBufferSize) - for { - if timeout != 0 { - _ = conn.SetDeadline(time.Now().Add(timeout)) - } - rn, err := conn.Read(buf) - if rn > 0 { - _, err := stream.Write(buf[:rn]) - if err != nil { - errChan <- err - return - } - } - if err != nil { - errChan <- err - return - } - } - }() - // Stream to TCP - go func() { - buf := make([]byte, utils.PipeBufferSize) - for { - rn, err := stream.Read(buf) - if rn > 0 { - _, err := conn.Write(buf[:rn]) - if err != nil { - errChan <- err - return - } - if timeout != 0 { - _ = conn.SetDeadline(time.Now().Add(timeout)) - } - } - if err != nil { - errChan <- err - return - } - } - }() - return <-errChan -} diff --git a/pkg/utils/pipe.go b/pkg/utils/pipe.go index 47318f9..1232167 100644 --- a/pkg/utils/pipe.go +++ b/pkg/utils/pipe.go @@ -2,6 +2,8 @@ package utils import ( "io" + "net" + "time" ) const PipeBufferSize = 65536 @@ -33,3 +35,50 @@ func Pipe2Way(rw1, rw2 io.ReadWriter) error { // We only need the first error return <-errChan } + +func PipePairWithTimeout(conn *net.TCPConn, stream io.ReadWriteCloser, timeout time.Duration) error { + errChan := make(chan error, 2) + // TCP to stream + go func() { + buf := make([]byte, PipeBufferSize) + for { + if timeout != 0 { + _ = conn.SetDeadline(time.Now().Add(timeout)) + } + rn, err := conn.Read(buf) + if rn > 0 { + _, err := stream.Write(buf[:rn]) + if err != nil { + errChan <- err + return + } + } + if err != nil { + errChan <- err + return + } + } + }() + // Stream to TCP + go func() { + buf := make([]byte, PipeBufferSize) + for { + rn, err := stream.Read(buf) + if rn > 0 { + _, err := conn.Write(buf[:rn]) + if err != nil { + errChan <- err + return + } + if timeout != 0 { + _ = conn.SetDeadline(time.Now().Add(timeout)) + } + } + if err != nil { + errChan <- err + return + } + } + }() + return <-errChan +}