diff --git a/cmd/client.go b/cmd/client.go index 83622fa..d7b39e9 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -181,10 +181,10 @@ func client(config *clientConfig) { }() } - if len(config.Relay.Listen) > 0 { + if len(config.TCPRelay.Listen) > 0 { go func() { - rl, err := relay.NewTCPRelay(client, config.Relay.Listen, config.Relay.Remote, - time.Duration(config.Relay.Timeout)*time.Second, + rl, err := relay.NewTCPRelay(client, config.TCPRelay.Listen, config.TCPRelay.Remote, + time.Duration(config.TCPRelay.Timeout)*time.Second, func(addr net.Addr) { logrus.WithFields(logrus.Fields{ "src": addr.String(), @@ -201,12 +201,40 @@ func client(config *clientConfig) { "src": addr.String(), }).Debug("TCP relay EOF") } - }) if err != nil { logrus.WithField("error", err).Fatal("Failed to initialize TCP relay") } - logrus.WithField("addr", config.Relay.Listen).Info("TCP relay up and running") + logrus.WithField("addr", config.TCPRelay.Listen).Info("TCP relay up and running") + errChan <- rl.ListenAndServe() + }() + } + + if len(config.UDPRelay.Listen) > 0 { + go func() { + rl, err := relay.NewUDPRelay(client, config.UDPRelay.Listen, config.UDPRelay.Remote, + time.Duration(config.UDPRelay.Timeout)*time.Second, + func(addr net.Addr) { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + }).Debug("UDP relay request") + }, + func(addr net.Addr, err error) { + if err != relay.ErrTimeout { + logrus.WithFields(logrus.Fields{ + "error": err, + "src": addr.String(), + }).Info("UDP relay error") + } else { + logrus.WithFields(logrus.Fields{ + "src": addr.String(), + }).Debug("UDP relay session closed") + } + }) + if err != nil { + logrus.WithField("error", err).Fatal("Failed to initialize UDP relay") + } + logrus.WithField("addr", config.UDPRelay.Listen).Info("UDP relay up and running") errChan <- rl.ListenAndServe() }() } diff --git a/cmd/config.go b/cmd/config.go index 1e9c543..5a76cdf 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -80,11 +80,16 @@ type clientConfig struct { Cert string `json:"cert"` Key string `json:"key"` } `json:"http"` - Relay struct { + TCPRelay struct { Listen string `json:"listen"` Remote string `json:"remote"` Timeout int `json:"timeout"` - } `json:"relay"` + } `json:"relay_tcp"` + UDPRelay struct { + Listen string `json:"listen"` + Remote string `json:"remote"` + Timeout int `json:"timeout"` + } `json:"relay_udp"` ACL string `json:"acl"` Obfs string `json:"obfs"` Auth []byte `json:"auth"` @@ -96,10 +101,11 @@ type clientConfig struct { } func (c *clientConfig) Check() error { - if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && len(c.Relay.Listen) == 0 { - return errors.New("no SOCKS5, HTTP or relay listen address") + if len(c.SOCKS5.Listen) == 0 && len(c.HTTP.Listen) == 0 && + len(c.TCPRelay.Listen) == 0 && len(c.UDPRelay.Listen) == 0 { + return errors.New("no SOCKS5, HTTP, TCP relay or UDP relay listen address") } - if len(c.Relay.Listen) > 0 && len(c.Relay.Remote) == 0 { + if len(c.TCPRelay.Listen) > 0 && len(c.TCPRelay.Remote) == 0 { return errors.New("no relay remote address") } if c.SOCKS5.Timeout != 0 && c.SOCKS5.Timeout <= 4 { @@ -108,8 +114,11 @@ func (c *clientConfig) Check() error { if c.HTTP.Timeout != 0 && c.HTTP.Timeout <= 4 { return errors.New("invalid HTTP timeout") } - if c.Relay.Timeout != 0 && c.Relay.Timeout <= 4 { - return errors.New("invalid relay timeout") + if c.TCPRelay.Timeout != 0 && c.TCPRelay.Timeout <= 4 { + return errors.New("invalid TCP relay timeout") + } + if c.UDPRelay.Timeout != 0 && c.UDPRelay.Timeout <= 4 { + return errors.New("invalid UDP relay timeout") } if len(c.Server) == 0 { return errors.New("no server address") diff --git a/pkg/relay/udp.go b/pkg/relay/udp.go new file mode 100644 index 0000000..bf01851 --- /dev/null +++ b/pkg/relay/udp.go @@ -0,0 +1,134 @@ +package relay + +import ( + "errors" + "github.com/tobyxdd/hysteria/pkg/core" + "net" + "sync" + "sync/atomic" + "time" +) + +const udpBufferSize = 65535 + +const udpMinTimeout = 4 * time.Second + +var ErrTimeout = errors.New("inactivity timeout") + +type UDPRelay struct { + HyClient *core.Client + ListenAddr *net.UDPAddr + Remote string + Timeout time.Duration + + ConnFunc func(addr net.Addr) + ErrorFunc func(addr net.Addr, err error) +} + +func NewUDPRelay(hyClient *core.Client, listen, remote string, timeout time.Duration, + connFunc func(addr net.Addr), errorFunc func(addr net.Addr, err error)) (*UDPRelay, error) { + uAddr, err := net.ResolveUDPAddr("udp", listen) + if err != nil { + return nil, err + } + r := &UDPRelay{ + HyClient: hyClient, + ListenAddr: uAddr, + Remote: remote, + Timeout: timeout, + ConnFunc: connFunc, + ErrorFunc: errorFunc, + } + if timeout == 0 { + r.Timeout = 1 * time.Minute + } else if timeout < udpMinTimeout { + r.Timeout = udpMinTimeout + } + return r, nil +} + +type cmEntry struct { + HyConn core.UDPConn + Addr *net.UDPAddr + LastActiveTime atomic.Value +} + +func (r *UDPRelay) ListenAndServe() error { + conn, err := net.ListenUDP("udp", r.ListenAddr) + if err != nil { + return err + } + defer conn.Close() + // src <-> HyClient UDPConn + connMap := make(map[string]*cmEntry) + var connMapMutex sync.RWMutex + // Timeout cleanup routine + stopChan := make(chan bool) + go func() { + ticker := time.NewTicker(udpMinTimeout) + defer ticker.Stop() + for { + select { + case <-stopChan: + return + case t := <-ticker.C: + allowedLAT := t.Add(-r.Timeout) + connMapMutex.Lock() + for k, v := range connMap { + if v.LastActiveTime.Load().(time.Time).Before(allowedLAT) { + // Timeout + r.ErrorFunc(v.Addr, ErrTimeout) + _ = v.HyConn.Close() + delete(connMap, k) + } + } + connMapMutex.Unlock() + } + } + }() + // Read loop + buf := make([]byte, udpBufferSize) + for { + n, rAddr, err := conn.ReadFromUDP(buf) + if n > 0 { + connMapMutex.RLock() + cme := connMap[rAddr.String()] + connMapMutex.RUnlock() + if cme != nil { + // Existing conn + cme.LastActiveTime.Store(time.Now()) + _ = cme.HyConn.WriteTo(buf[:n], r.Remote) + } else { + // New + r.ConnFunc(rAddr) + hyConn, err := r.HyClient.DialUDP() + if err != nil { + r.ErrorFunc(rAddr, err) + } else { + // Add it to the map + ent := &cmEntry{HyConn: hyConn, Addr: rAddr} + ent.LastActiveTime.Store(time.Now()) + connMapMutex.Lock() + connMap[rAddr.String()] = ent + connMapMutex.Unlock() + // Start remote to local + go func() { + for { + bs, _, err := hyConn.ReadFrom() + if err != nil { + break + } + ent.LastActiveTime.Store(time.Now()) + _, _ = conn.WriteToUDP(bs, rAddr) + } + }() + // Send the packet + _ = hyConn.WriteTo(buf[:n], r.Remote) + } + } + } + if err != nil { + return err + } + } +}