From d6b549cea402aa66d954b40adfd452b096931b77 Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 19 Feb 2022 17:42:02 -0800 Subject: [PATCH] fix: incorrect TProxy UDP implementation (228) --- cmd/client.go | 11 +++- pkg/tproxy/udp_linux.go | 140 +++++++++++++++++----------------------- 2 files changed, 68 insertions(+), 83 deletions(-) diff --git a/cmd/client.go b/cmd/client.go index 72d40ba..83e3728 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -3,12 +3,14 @@ package main import ( "crypto/tls" "crypto/x509" + "errors" "github.com/oschwald/geoip2-golang" "github.com/yosuke-furukawa/json5/encoding/json5" "io" "io/ioutil" "net" "net/http" + "os" "strings" "time" @@ -388,20 +390,23 @@ func client(config *clientConfig) { go func() { rl, err := tproxy.NewUDPTProxy(client, config.UDPTProxy.Listen, time.Duration(config.UDPTProxy.Timeout)*time.Second, - func(addr net.Addr) { + func(addr, reqAddr net.Addr) { logrus.WithFields(logrus.Fields{ "src": addr.String(), + "dst": reqAddr.String(), }).Debug("UDP TProxy request") }, - func(addr net.Addr, err error) { - if err != tproxy.ErrTimeout { + func(addr, reqAddr net.Addr, err error) { + if !errors.Is(err, os.ErrDeadlineExceeded) { logrus.WithFields(logrus.Fields{ "error": err, "src": addr.String(), + "dst": reqAddr.String(), }).Info("UDP TProxy error") } else { logrus.WithFields(logrus.Fields{ "src": addr.String(), + "dst": reqAddr.String(), }).Debug("UDP TProxy session closed") } }) diff --git a/pkg/tproxy/udp_linux.go b/pkg/tproxy/udp_linux.go index d61da27..f28b95b 100644 --- a/pkg/tproxy/udp_linux.go +++ b/pkg/tproxy/udp_linux.go @@ -1,30 +1,26 @@ package tproxy import ( - "errors" "github.com/LiamHaworth/go-tproxy" "github.com/tobyxdd/hysteria/pkg/core" "net" - "sync" - "sync/atomic" "time" ) const udpBufferSize = 65535 -var ErrTimeout = errors.New("inactivity timeout") - type UDPTProxy struct { HyClient *core.Client ListenAddr *net.UDPAddr Timeout time.Duration - ConnFunc func(addr net.Addr) - ErrorFunc func(addr net.Addr, err error) + ConnFunc func(addr, reqAddr net.Addr) + ErrorFunc func(addr, reqAddr net.Addr, err error) } func NewUDPTProxy(hyClient *core.Client, listen string, timeout time.Duration, - connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPTProxy, error) { + connFunc func(addr, reqAddr net.Addr), + errorFunc func(addr, reqAddr net.Addr, err error)) (*UDPTProxy, error) { uAddr, err := net.ResolveUDPAddr("udp", listen) if err != nil { return nil, err @@ -42,89 +38,73 @@ func NewUDPTProxy(hyClient *core.Client, listen string, timeout time.Duration, return r, nil } -type connEntry struct { - LocalConn *net.UDPConn - HyConn core.UDPConn - Deadline atomic.Value -} - func (r *UDPTProxy) ListenAndServe() error { conn, err := tproxy.ListenUDP("udp", r.ListenAddr) if err != nil { return err } defer conn.Close() - // src <-> HyClient UDPConn - connMap := make(map[string]*connEntry) - var connMapMutex sync.RWMutex // Read loop buf := make([]byte, udpBufferSize) for { - n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf) + n, srcAddr, dstAddr, err := tproxy.ReadFromUDP(conn, buf) // Huge Caveat!! This essentially works as TCP's Accept here - won't repeat for the same srcAddr/dstAddr pair - because and only because we have tproxy.DialUDP("udp", dstAddr, srcAddr) to take over the connection below if n > 0 { - connMapMutex.RLock() - entry := connMap[srcAddr.String()] - connMapMutex.RUnlock() - if entry != nil { - // Existing conn - entry.Deadline.Store(time.Now().Add(r.Timeout)) - _ = entry.HyConn.WriteTo(buf[:n], dstAddr.String()) - } else { - // New - r.ConnFunc(srcAddr) - // TODO: Change fixed dstAddr - localConn, err := tproxy.DialUDP("udp", dstAddr, srcAddr) - if err != nil { - r.ErrorFunc(srcAddr, err) - continue - } - hyConn, err := r.HyClient.DialUDP() - if err != nil { - r.ErrorFunc(srcAddr, err) - _ = localConn.Close() - continue - } - // Send - entry := &connEntry{ - LocalConn: localConn, - HyConn: hyConn, - } - entry.Deadline.Store(time.Now().Add(r.Timeout)) - // Add it to the map - connMapMutex.Lock() - connMap[srcAddr.String()] = entry - connMapMutex.Unlock() - // Start remote to local - go func() { - for { - bs, _, err := hyConn.ReadFrom() - if err != nil { - break - } - entry.Deadline.Store(time.Now().Add(r.Timeout)) - _, _ = localConn.Write(bs) - } - }() - // Timeout cleanup routine - go func() { - for { - ttl := entry.Deadline.Load().(time.Time).Sub(time.Now()) - if ttl <= 0 { - // Time to die - connMapMutex.Lock() - _ = localConn.Close() - _ = hyConn.Close() - delete(connMap, srcAddr.String()) - connMapMutex.Unlock() - r.ErrorFunc(srcAddr, ErrTimeout) - return - } else { - time.Sleep(ttl) - } - } - }() - _ = hyConn.WriteTo(buf[:n], dstAddr.String()) + r.ConnFunc(srcAddr, dstAddr) + localConn, err := tproxy.DialUDP("udp", dstAddr, srcAddr) + if err != nil { + r.ErrorFunc(srcAddr, dstAddr, err) + continue } + hyConn, err := r.HyClient.DialUDP() + if err != nil { + r.ErrorFunc(srcAddr, dstAddr, err) + _ = localConn.Close() + continue + } + _ = hyConn.WriteTo(buf[:n], dstAddr.String()) + + errChan := make(chan error, 2) + // Start remote to local + go func() { + for { + bs, _, err := hyConn.ReadFrom() + if err != nil { + errChan <- err + return + } + _, err = localConn.Write(bs) + if err != nil { + errChan <- err + return + } + _ = localConn.SetDeadline(time.Now().Add(r.Timeout)) + } + }() + // Start local to remote + go func() { + for { + _ = localConn.SetDeadline(time.Now().Add(r.Timeout)) + n, err := localConn.Read(buf) + if n > 0 { + err := hyConn.WriteTo(buf[:n], dstAddr.String()) + if err != nil { + errChan <- err + return + } + } + if err != nil { + errChan <- err + return + } + } + }() + // Error cleanup routine + go func() { + err := <-errChan + _ = localConn.Close() + _ = hyConn.Close() + r.ErrorFunc(srcAddr, dstAddr, err) + }() } if err != nil { return err