From 2e803348419dd9853c4fe5881c9cd668ccce5476 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 4 Nov 2022 11:47:24 -0700 Subject: [PATCH] feat: rework multiport address format to support ranges, drop server support (use iptables DNAT instead) --- pkg/transport/pktconns/funcs.go | 24 +-- .../pktconns/{udphop/client.go => udp/hop.go} | 115 +++++++++--- pkg/transport/pktconns/udp/hop_test.go | 102 +++++++++++ pkg/transport/pktconns/udphop/common.go | 52 ------ pkg/transport/pktconns/udphop/server.go | 169 ------------------ 5 files changed, 205 insertions(+), 257 deletions(-) rename pkg/transport/pktconns/{udphop/client.go => udp/hop.go} (66%) create mode 100644 pkg/transport/pktconns/udp/hop_test.go delete mode 100644 pkg/transport/pktconns/udphop/common.go delete mode 100644 pkg/transport/pktconns/udphop/server.go diff --git a/pkg/transport/pktconns/funcs.go b/pkg/transport/pktconns/funcs.go index 30918f9..2df10aa 100644 --- a/pkg/transport/pktconns/funcs.go +++ b/pkg/transport/pktconns/funcs.go @@ -4,8 +4,6 @@ import ( "net" "strings" - "github.com/HyNetwork/hysteria/pkg/transport/pktconns/udphop" - "github.com/HyNetwork/hysteria/pkg/transport/pktconns/faketcp" "github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs" "github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp" @@ -25,8 +23,8 @@ type ( func NewClientUDPConnFunc(obfsPassword string) ClientPacketConnFunc { if obfsPassword == "" { return func(server string) (net.PacketConn, net.Addr, error) { - if isAddrPortHopping(server) { - return udphop.NewObfsUDPHopClientPacketConn(server, nil) + if isMultiPortAddr(server) { + return udp.NewObfsUDPHopClientPacketConn(server, nil) } sAddr, err := net.ResolveUDPAddr("udp", server) if err != nil { @@ -37,9 +35,9 @@ func NewClientUDPConnFunc(obfsPassword string) ClientPacketConnFunc { } } else { return func(server string) (net.PacketConn, net.Addr, error) { - if isAddrPortHopping(server) { + if isMultiPortAddr(server) { ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) - return udphop.NewObfsUDPHopClientPacketConn(server, ob) + return udp.NewObfsUDPHopClientPacketConn(server, ob) } sAddr, err := net.ResolveUDPAddr("udp", server) if err != nil { @@ -113,9 +111,6 @@ func NewClientFakeTCPConnFunc(obfsPassword string) ClientPacketConnFunc { func NewServerUDPConnFunc(obfsPassword string) ServerPacketConnFunc { if obfsPassword == "" { return func(listen string) (net.PacketConn, error) { - if isAddrPortHopping(listen) { - return udphop.NewObfsUDPHopServerPacketConn(listen, nil) - } laddrU, err := net.ResolveUDPAddr("udp", listen) if err != nil { return nil, err @@ -124,10 +119,6 @@ func NewServerUDPConnFunc(obfsPassword string) ServerPacketConnFunc { } } else { return func(listen string) (net.PacketConn, error) { - if isAddrPortHopping(listen) { - ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) - return udphop.NewObfsUDPHopServerPacketConn(listen, ob) - } ob := obfs.NewXPlusObfuscator([]byte(obfsPassword)) laddrU, err := net.ResolveUDPAddr("udp", listen) if err != nil { @@ -188,7 +179,10 @@ func NewServerFakeTCPConnFunc(obfsPassword string) ServerPacketConnFunc { } } -func isAddrPortHopping(addr string) bool { +func isMultiPortAddr(addr string) bool { _, portStr, err := net.SplitHostPort(addr) - return err == nil && strings.Contains(portStr, ",") + if err == nil && (strings.Contains(portStr, ",") || strings.Contains(portStr, "-")) { + return true + } + return false } diff --git a/pkg/transport/pktconns/udphop/client.go b/pkg/transport/pktconns/udp/hop.go similarity index 66% rename from pkg/transport/pktconns/udphop/client.go rename to pkg/transport/pktconns/udp/hop.go index 8a81856..9404733 100644 --- a/pkg/transport/pktconns/udphop/client.go +++ b/pkg/transport/pktconns/udp/hop.go @@ -1,17 +1,20 @@ -package udphop +package udp import ( "log" "math/rand" "net" + "strconv" + "strings" "sync" "time" "github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs" - "github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp" ) const ( + packetQueueSize = 1024 + portHoppingInterval = 30 * time.Second ) @@ -35,6 +38,22 @@ type ObfsUDPHopClientPacketConn struct { bufPool sync.Pool } +type udpHopAddr string + +func (a *udpHopAddr) Network() string { + return "udp-hop" +} + +func (a *udpHopAddr) String() string { + return string(*a) +} + +type udpPacket struct { + buf []byte + n int + addr net.Addr +} + func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUDPHopClientPacketConn, net.Addr, error) { host, ports, err := parseAddr(server) if err != nil { @@ -53,8 +72,9 @@ func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUD } log.Printf("udphop: server address %s", serverAddrs[i]) } + hopAddr := udpHopAddr(server) conn := &ObfsUDPHopClientPacketConn{ - serverAddr: &udpHopAddr{server}, + serverAddr: &hopAddr, serverAddrs: serverAddrs, obfs: obfs, addrIndex: rand.Intn(len(serverAddrs)), @@ -71,7 +91,7 @@ func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUD return nil, nil, err } if obfs != nil { - conn.currentConn = udp.NewObfsUDPConn(curConn, obfs) + conn.currentConn = NewObfsUDPConn(curConn, obfs) } else { conn.currentConn = curConn } @@ -134,7 +154,7 @@ func (c *ObfsUDPHopClientPacketConn) hop() { } c.prevConn = c.currentConn if c.obfs != nil { - c.currentConn = udp.NewObfsUDPConn(newConn, c.obfs) + c.currentConn = NewObfsUDPConn(newConn, c.obfs) } else { c.currentConn = newConn } @@ -147,17 +167,25 @@ func (c *ObfsUDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { for { select { case p := <-c.recvQueue: - // Check if the packet is from one of the server addresses - for _, addr := range c.serverAddrs { - if addr.String() == p.addr.String() { - // Copy the packet to the buffer - n := copy(b, p.buf[:p.n]) - c.bufPool.Put(p.buf) - return n, c.serverAddr, nil + /* + // Check if the packet is from one of the server addresses + for _, addr := range c.serverAddrs { + if addr.String() == p.addr.String() { + // Copy the packet to the buffer + n := copy(b, p.buf[:p.n]) + c.bufPool.Put(p.buf) + return n, c.serverAddr, nil + } } - } - // Drop the packet, continue + // Drop the packet, continue + c.bufPool.Put(p.buf) + */ + // The above code was causing performance issues when the range is large, + // so we skip the check for now. Should probably still check by using a map + // or something in the future. + n := copy(b, p.buf[:p.n]) c.bufPool.Put(p.buf) + return n, c.serverAddr, nil case <-c.closeChan: return 0, nil, net.ErrClosed } @@ -200,17 +228,62 @@ func (c *ObfsUDPHopClientPacketConn) LocalAddr() net.Addr { return c.currentConn.LocalAddr() } -func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error { - // Not implemented - return nil -} - func (c *ObfsUDPHopClientPacketConn) SetReadDeadline(t time.Time) error { - // Not implemented + // Not supported return nil } func (c *ObfsUDPHopClientPacketConn) SetWriteDeadline(t time.Time) error { - // Not implemented + // Not supported return nil } + +func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error { + err := c.SetReadDeadline(t) + if err != nil { + return err + } + return c.SetWriteDeadline(t) +} + +// parseAddr parses the multi-port server address and returns the host and ports. +// Supports both comma-separated single ports and dash-separated port ranges. +// Format: "host:port1,port2-port3,port4" +func parseAddr(addr string) (host string, ports []uint16, err error) { + host, portStr, err := net.SplitHostPort(addr) + if err != nil { + return "", nil, err + } + portStrs := strings.Split(portStr, ",") + for _, portStr := range portStrs { + if strings.Contains(portStr, "-") { + // Port range + portRange := strings.Split(portStr, "-") + if len(portRange) != 2 { + return "", nil, net.InvalidAddrError("invalid port range") + } + start, err := strconv.ParseUint(portRange[0], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + end, err := strconv.ParseUint(portRange[1], 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port range") + } + if start > end { + start, end = end, start + } + for i := start; i <= end; i++ { + ports = append(ports, uint16(i)) + } + } else { + // Single port + port, err := strconv.ParseUint(portStr, 10, 16) + if err != nil { + return "", nil, net.InvalidAddrError("invalid port") + } + ports = append(ports, uint16(port)) + } + } + return host, ports, nil +} diff --git a/pkg/transport/pktconns/udp/hop_test.go b/pkg/transport/pktconns/udp/hop_test.go new file mode 100644 index 0000000..cacf5a5 --- /dev/null +++ b/pkg/transport/pktconns/udp/hop_test.go @@ -0,0 +1,102 @@ +package udp + +import ( + "reflect" + "testing" +) + +func Test_parseAddr(t *testing.T) { + tests := []struct { + name string + addr string + wantHost string + wantPorts []uint16 + wantErr bool + }{ + { + name: "empty", + addr: "", + wantHost: "", + wantPorts: nil, + wantErr: true, + }, + { + name: "host only", + addr: "example.com", + wantHost: "", + wantPorts: nil, + wantErr: true, + }, + { + name: "single port", + addr: "example.com:1234", + wantHost: "example.com", + wantPorts: []uint16{1234}, + wantErr: false, + }, + { + name: "multi ports", + addr: "example.com:1234,5678,9999", + wantHost: "example.com", + wantPorts: []uint16{1234, 5678, 9999}, + wantErr: false, + }, + { + name: "multi ports with range", + addr: "example.com:1234,5678-5685,9999", + wantHost: "example.com", + wantPorts: []uint16{1234, 5678, 5679, 5680, 5681, 5682, 5683, 5684, 5685, 9999}, + wantErr: false, + }, + { + name: "range single port", + addr: "example.com:1234-1234", + wantHost: "example.com", + wantPorts: []uint16{1234}, + wantErr: false, + }, + { + name: "range reversed", + addr: "example.com:8003-8000", + wantHost: "example.com", + wantPorts: []uint16{8000, 8001, 8002, 8003}, + wantErr: false, + }, + { + name: "invalid port", + addr: "example.com:1234,5678,9999,invalid", + wantHost: "", + wantPorts: nil, + wantErr: true, + }, + { + name: "invalid port range", + addr: "example.com:1234,5678,9999,8000-8002-8004", + wantHost: "", + wantPorts: nil, + wantErr: true, + }, + { + name: "invalid port range 2", + addr: "example.com:1234,5678,9999,8000-woot", + wantHost: "", + wantPorts: nil, + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotHost, gotPorts, err := parseAddr(tt.addr) + if (err != nil) != tt.wantErr { + t.Errorf("parseAddr() error = %v, wantErr %v", err, tt.wantErr) + return + } + if gotHost != tt.wantHost { + t.Errorf("parseAddr() gotHost = %v, want %v", gotHost, tt.wantHost) + } + if !reflect.DeepEqual(gotPorts, tt.wantPorts) { + t.Errorf("parseAddr() gotPorts = %v, want %v", gotPorts, tt.wantPorts) + } + }) + } +} diff --git a/pkg/transport/pktconns/udphop/common.go b/pkg/transport/pktconns/udphop/common.go deleted file mode 100644 index 9205233..0000000 --- a/pkg/transport/pktconns/udphop/common.go +++ /dev/null @@ -1,52 +0,0 @@ -package udphop - -import ( - "net" - "strconv" - "strings" -) - -const ( - packetQueueSize = 1024 - udpBufferSize = 4096 -) - -// parseAddr parses the listen address and returns the host and ports. -// Format: "host:port1,port2,port3,..." -func parseAddr(addr string) (host string, ports []uint16, err error) { - host, portStr, err := net.SplitHostPort(addr) - if err != nil { - return - } - portsStr := strings.Split(portStr, ",") - if len(portsStr) < 2 { - return "", nil, net.InvalidAddrError("at least two ports required") - } - ports = make([]uint16, len(portsStr)) - for i, p := range portsStr { - port, err := strconv.ParseUint(p, 10, 16) - if err != nil { - return "", nil, net.InvalidAddrError("invalid port: " + p) - } - ports[i] = uint16(port) - } - return -} - -type udpHopAddr struct { - listen string -} - -func (a *udpHopAddr) Network() string { - return "udp-hop" -} - -func (a *udpHopAddr) String() string { - return a.listen -} - -type udpPacket struct { - buf []byte - n int - addr net.Addr -} diff --git a/pkg/transport/pktconns/udphop/server.go b/pkg/transport/pktconns/udphop/server.go deleted file mode 100644 index ecc13a2..0000000 --- a/pkg/transport/pktconns/udphop/server.go +++ /dev/null @@ -1,169 +0,0 @@ -package udphop - -import ( - "log" - "net" - "strconv" - "sync" - "time" - - "github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs" - "github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp" -) - -const ( - addrMapEntryTTL = time.Minute -) - -// ObfsUDPHopServerPacketConn is the UDP port-hopping packet connection for server side. -// It listens on multiple UDP ports and replies to a client using the port it received packet from. -type ObfsUDPHopServerPacketConn struct { - localAddr net.Addr - conns []net.PacketConn - - recvQueue chan *udpPacket - closeChan chan struct{} - - addrMapMutex sync.RWMutex - addrMap map[string]addrMapEntry - - bufPool sync.Pool -} - -type addrMapEntry struct { - index int - last time.Time -} - -func NewObfsUDPHopServerPacketConn(listen string, obfs obfs.Obfuscator) (*ObfsUDPHopServerPacketConn, error) { - host, ports, err := parseAddr(listen) - if err != nil { - return nil, err - } - conns := make([]net.PacketConn, len(ports)) - for i, port := range ports { - addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10)) - udpAddr, err := net.ResolveUDPAddr("udp", addr) - if err != nil { - return nil, err - } - conn, err := net.ListenUDP("udp", udpAddr) - if err != nil { - return nil, err - } - if obfs != nil { - conns[i] = udp.NewObfsUDPConn(conn, obfs) - } else { - conns[i] = conn - } - } - c := &ObfsUDPHopServerPacketConn{ - localAddr: &udpHopAddr{listen}, - conns: conns, - recvQueue: make(chan *udpPacket, packetQueueSize), - closeChan: make(chan struct{}), - addrMap: make(map[string]addrMapEntry), - bufPool: sync.Pool{ - New: func() interface{} { - return make([]byte, udpBufferSize) - }, - }, - } - c.startRecvRoutines() - go c.addrMapCleanupRoutine() - return c, nil -} - -func (c *ObfsUDPHopServerPacketConn) startRecvRoutines() { - for i, conn := range c.conns { - go c.recvRoutine(i, conn) - } -} - -func (c *ObfsUDPHopServerPacketConn) recvRoutine(i int, conn net.PacketConn) { - log.Printf("udphop: receiving on %s", conn.LocalAddr()) - for { - buf := c.bufPool.Get().([]byte) - n, addr, err := conn.ReadFrom(buf) - if err != nil { - log.Printf("udphop: routine %d read error: %v", i, err) - return - } - // Update addrMap - c.addrMapMutex.Lock() - c.addrMap[addr.String()] = addrMapEntry{i, time.Now()} - c.addrMapMutex.Unlock() - select { - case c.recvQueue <- &udpPacket{buf, n, addr}: - // Packet sent to queue - default: - log.Printf("udphop: recv queue full, dropping packet from %s", addr) - c.bufPool.Put(buf) - } - } -} - -func (c *ObfsUDPHopServerPacketConn) addrMapCleanupRoutine() { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-ticker.C: - c.addrMapMutex.Lock() - for addr, entry := range c.addrMap { - if time.Since(entry.last) > addrMapEntryTTL { - delete(c.addrMap, addr) - } - } - c.addrMapMutex.Unlock() - case <-c.closeChan: - return - } - } -} - -func (c *ObfsUDPHopServerPacketConn) ReadFrom(b []byte) (int, net.Addr, error) { - select { - case p := <-c.recvQueue: - n := copy(b, p.buf[:p.n]) - c.bufPool.Put(p.buf) - return n, p.addr, nil - case <-c.closeChan: - return 0, nil, net.ErrClosed - } -} - -func (c *ObfsUDPHopServerPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) { - // Find index from addrMap - c.addrMapMutex.RLock() - entry := c.addrMap[addr.String()] - c.addrMapMutex.RUnlock() - return c.conns[entry.index].WriteTo(b, addr) -} - -func (c *ObfsUDPHopServerPacketConn) Close() error { - for _, conn := range c.conns { - _ = conn.Close() // recvRoutines will exit on error - } - close(c.closeChan) - return nil -} - -func (c *ObfsUDPHopServerPacketConn) LocalAddr() net.Addr { - return c.localAddr -} - -func (c *ObfsUDPHopServerPacketConn) SetDeadline(t time.Time) error { - // Not implemented - return nil -} - -func (c *ObfsUDPHopServerPacketConn) SetReadDeadline(t time.Time) error { - // Not implemented - return nil -} - -func (c *ObfsUDPHopServerPacketConn) SetWriteDeadline(t time.Time) error { - // Not implemented - return nil -}