feat: rework multiport address format to support ranges, drop server support (use iptables DNAT instead)

This commit is contained in:
Toby 2022-11-04 11:47:24 -07:00
parent 263ac8d313
commit 2e80334841
5 changed files with 205 additions and 257 deletions

View File

@ -4,8 +4,6 @@ import (
"net"
"strings"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/udphop"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/faketcp"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp"
@ -25,8 +23,8 @@ type (
func NewClientUDPConnFunc(obfsPassword string) ClientPacketConnFunc {
if obfsPassword == "" {
return func(server string) (net.PacketConn, net.Addr, error) {
if isAddrPortHopping(server) {
return udphop.NewObfsUDPHopClientPacketConn(server, nil)
if isMultiPortAddr(server) {
return udp.NewObfsUDPHopClientPacketConn(server, nil)
}
sAddr, err := net.ResolveUDPAddr("udp", server)
if err != nil {
@ -37,9 +35,9 @@ func NewClientUDPConnFunc(obfsPassword string) ClientPacketConnFunc {
}
} else {
return func(server string) (net.PacketConn, net.Addr, error) {
if isAddrPortHopping(server) {
if isMultiPortAddr(server) {
ob := obfs.NewXPlusObfuscator([]byte(obfsPassword))
return udphop.NewObfsUDPHopClientPacketConn(server, ob)
return udp.NewObfsUDPHopClientPacketConn(server, ob)
}
sAddr, err := net.ResolveUDPAddr("udp", server)
if err != nil {
@ -113,9 +111,6 @@ func NewClientFakeTCPConnFunc(obfsPassword string) ClientPacketConnFunc {
func NewServerUDPConnFunc(obfsPassword string) ServerPacketConnFunc {
if obfsPassword == "" {
return func(listen string) (net.PacketConn, error) {
if isAddrPortHopping(listen) {
return udphop.NewObfsUDPHopServerPacketConn(listen, nil)
}
laddrU, err := net.ResolveUDPAddr("udp", listen)
if err != nil {
return nil, err
@ -124,10 +119,6 @@ func NewServerUDPConnFunc(obfsPassword string) ServerPacketConnFunc {
}
} else {
return func(listen string) (net.PacketConn, error) {
if isAddrPortHopping(listen) {
ob := obfs.NewXPlusObfuscator([]byte(obfsPassword))
return udphop.NewObfsUDPHopServerPacketConn(listen, ob)
}
ob := obfs.NewXPlusObfuscator([]byte(obfsPassword))
laddrU, err := net.ResolveUDPAddr("udp", listen)
if err != nil {
@ -188,7 +179,10 @@ func NewServerFakeTCPConnFunc(obfsPassword string) ServerPacketConnFunc {
}
}
func isAddrPortHopping(addr string) bool {
func isMultiPortAddr(addr string) bool {
_, portStr, err := net.SplitHostPort(addr)
return err == nil && strings.Contains(portStr, ",")
if err == nil && (strings.Contains(portStr, ",") || strings.Contains(portStr, "-")) {
return true
}
return false
}

View File

@ -1,17 +1,20 @@
package udphop
package udp
import (
"log"
"math/rand"
"net"
"strconv"
"strings"
"sync"
"time"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp"
)
const (
packetQueueSize = 1024
portHoppingInterval = 30 * time.Second
)
@ -35,6 +38,22 @@ type ObfsUDPHopClientPacketConn struct {
bufPool sync.Pool
}
type udpHopAddr string
func (a *udpHopAddr) Network() string {
return "udp-hop"
}
func (a *udpHopAddr) String() string {
return string(*a)
}
type udpPacket struct {
buf []byte
n int
addr net.Addr
}
func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUDPHopClientPacketConn, net.Addr, error) {
host, ports, err := parseAddr(server)
if err != nil {
@ -53,8 +72,9 @@ func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUD
}
log.Printf("udphop: server address %s", serverAddrs[i])
}
hopAddr := udpHopAddr(server)
conn := &ObfsUDPHopClientPacketConn{
serverAddr: &udpHopAddr{server},
serverAddr: &hopAddr,
serverAddrs: serverAddrs,
obfs: obfs,
addrIndex: rand.Intn(len(serverAddrs)),
@ -71,7 +91,7 @@ func NewObfsUDPHopClientPacketConn(server string, obfs obfs.Obfuscator) (*ObfsUD
return nil, nil, err
}
if obfs != nil {
conn.currentConn = udp.NewObfsUDPConn(curConn, obfs)
conn.currentConn = NewObfsUDPConn(curConn, obfs)
} else {
conn.currentConn = curConn
}
@ -134,7 +154,7 @@ func (c *ObfsUDPHopClientPacketConn) hop() {
}
c.prevConn = c.currentConn
if c.obfs != nil {
c.currentConn = udp.NewObfsUDPConn(newConn, c.obfs)
c.currentConn = NewObfsUDPConn(newConn, c.obfs)
} else {
c.currentConn = newConn
}
@ -147,17 +167,25 @@ func (c *ObfsUDPHopClientPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
for {
select {
case p := <-c.recvQueue:
// Check if the packet is from one of the server addresses
for _, addr := range c.serverAddrs {
if addr.String() == p.addr.String() {
// Copy the packet to the buffer
n := copy(b, p.buf[:p.n])
c.bufPool.Put(p.buf)
return n, c.serverAddr, nil
/*
// Check if the packet is from one of the server addresses
for _, addr := range c.serverAddrs {
if addr.String() == p.addr.String() {
// Copy the packet to the buffer
n := copy(b, p.buf[:p.n])
c.bufPool.Put(p.buf)
return n, c.serverAddr, nil
}
}
}
// Drop the packet, continue
// Drop the packet, continue
c.bufPool.Put(p.buf)
*/
// The above code was causing performance issues when the range is large,
// so we skip the check for now. Should probably still check by using a map
// or something in the future.
n := copy(b, p.buf[:p.n])
c.bufPool.Put(p.buf)
return n, c.serverAddr, nil
case <-c.closeChan:
return 0, nil, net.ErrClosed
}
@ -200,17 +228,62 @@ func (c *ObfsUDPHopClientPacketConn) LocalAddr() net.Addr {
return c.currentConn.LocalAddr()
}
func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error {
// Not implemented
return nil
}
func (c *ObfsUDPHopClientPacketConn) SetReadDeadline(t time.Time) error {
// Not implemented
// Not supported
return nil
}
func (c *ObfsUDPHopClientPacketConn) SetWriteDeadline(t time.Time) error {
// Not implemented
// Not supported
return nil
}
func (c *ObfsUDPHopClientPacketConn) SetDeadline(t time.Time) error {
err := c.SetReadDeadline(t)
if err != nil {
return err
}
return c.SetWriteDeadline(t)
}
// parseAddr parses the multi-port server address and returns the host and ports.
// Supports both comma-separated single ports and dash-separated port ranges.
// Format: "host:port1,port2-port3,port4"
func parseAddr(addr string) (host string, ports []uint16, err error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return "", nil, err
}
portStrs := strings.Split(portStr, ",")
for _, portStr := range portStrs {
if strings.Contains(portStr, "-") {
// Port range
portRange := strings.Split(portStr, "-")
if len(portRange) != 2 {
return "", nil, net.InvalidAddrError("invalid port range")
}
start, err := strconv.ParseUint(portRange[0], 10, 16)
if err != nil {
return "", nil, net.InvalidAddrError("invalid port range")
}
end, err := strconv.ParseUint(portRange[1], 10, 16)
if err != nil {
return "", nil, net.InvalidAddrError("invalid port range")
}
if start > end {
start, end = end, start
}
for i := start; i <= end; i++ {
ports = append(ports, uint16(i))
}
} else {
// Single port
port, err := strconv.ParseUint(portStr, 10, 16)
if err != nil {
return "", nil, net.InvalidAddrError("invalid port")
}
ports = append(ports, uint16(port))
}
}
return host, ports, nil
}

View File

@ -0,0 +1,102 @@
package udp
import (
"reflect"
"testing"
)
func Test_parseAddr(t *testing.T) {
tests := []struct {
name string
addr string
wantHost string
wantPorts []uint16
wantErr bool
}{
{
name: "empty",
addr: "",
wantHost: "",
wantPorts: nil,
wantErr: true,
},
{
name: "host only",
addr: "example.com",
wantHost: "",
wantPorts: nil,
wantErr: true,
},
{
name: "single port",
addr: "example.com:1234",
wantHost: "example.com",
wantPorts: []uint16{1234},
wantErr: false,
},
{
name: "multi ports",
addr: "example.com:1234,5678,9999",
wantHost: "example.com",
wantPorts: []uint16{1234, 5678, 9999},
wantErr: false,
},
{
name: "multi ports with range",
addr: "example.com:1234,5678-5685,9999",
wantHost: "example.com",
wantPorts: []uint16{1234, 5678, 5679, 5680, 5681, 5682, 5683, 5684, 5685, 9999},
wantErr: false,
},
{
name: "range single port",
addr: "example.com:1234-1234",
wantHost: "example.com",
wantPorts: []uint16{1234},
wantErr: false,
},
{
name: "range reversed",
addr: "example.com:8003-8000",
wantHost: "example.com",
wantPorts: []uint16{8000, 8001, 8002, 8003},
wantErr: false,
},
{
name: "invalid port",
addr: "example.com:1234,5678,9999,invalid",
wantHost: "",
wantPorts: nil,
wantErr: true,
},
{
name: "invalid port range",
addr: "example.com:1234,5678,9999,8000-8002-8004",
wantHost: "",
wantPorts: nil,
wantErr: true,
},
{
name: "invalid port range 2",
addr: "example.com:1234,5678,9999,8000-woot",
wantHost: "",
wantPorts: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotHost, gotPorts, err := parseAddr(tt.addr)
if (err != nil) != tt.wantErr {
t.Errorf("parseAddr() error = %v, wantErr %v", err, tt.wantErr)
return
}
if gotHost != tt.wantHost {
t.Errorf("parseAddr() gotHost = %v, want %v", gotHost, tt.wantHost)
}
if !reflect.DeepEqual(gotPorts, tt.wantPorts) {
t.Errorf("parseAddr() gotPorts = %v, want %v", gotPorts, tt.wantPorts)
}
})
}
}

View File

@ -1,52 +0,0 @@
package udphop
import (
"net"
"strconv"
"strings"
)
const (
packetQueueSize = 1024
udpBufferSize = 4096
)
// parseAddr parses the listen address and returns the host and ports.
// Format: "host:port1,port2,port3,..."
func parseAddr(addr string) (host string, ports []uint16, err error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return
}
portsStr := strings.Split(portStr, ",")
if len(portsStr) < 2 {
return "", nil, net.InvalidAddrError("at least two ports required")
}
ports = make([]uint16, len(portsStr))
for i, p := range portsStr {
port, err := strconv.ParseUint(p, 10, 16)
if err != nil {
return "", nil, net.InvalidAddrError("invalid port: " + p)
}
ports[i] = uint16(port)
}
return
}
type udpHopAddr struct {
listen string
}
func (a *udpHopAddr) Network() string {
return "udp-hop"
}
func (a *udpHopAddr) String() string {
return a.listen
}
type udpPacket struct {
buf []byte
n int
addr net.Addr
}

View File

@ -1,169 +0,0 @@
package udphop
import (
"log"
"net"
"strconv"
"sync"
"time"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/obfs"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns/udp"
)
const (
addrMapEntryTTL = time.Minute
)
// ObfsUDPHopServerPacketConn is the UDP port-hopping packet connection for server side.
// It listens on multiple UDP ports and replies to a client using the port it received packet from.
type ObfsUDPHopServerPacketConn struct {
localAddr net.Addr
conns []net.PacketConn
recvQueue chan *udpPacket
closeChan chan struct{}
addrMapMutex sync.RWMutex
addrMap map[string]addrMapEntry
bufPool sync.Pool
}
type addrMapEntry struct {
index int
last time.Time
}
func NewObfsUDPHopServerPacketConn(listen string, obfs obfs.Obfuscator) (*ObfsUDPHopServerPacketConn, error) {
host, ports, err := parseAddr(listen)
if err != nil {
return nil, err
}
conns := make([]net.PacketConn, len(ports))
for i, port := range ports {
addr := net.JoinHostPort(host, strconv.FormatUint(uint64(port), 10))
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
conn, err := net.ListenUDP("udp", udpAddr)
if err != nil {
return nil, err
}
if obfs != nil {
conns[i] = udp.NewObfsUDPConn(conn, obfs)
} else {
conns[i] = conn
}
}
c := &ObfsUDPHopServerPacketConn{
localAddr: &udpHopAddr{listen},
conns: conns,
recvQueue: make(chan *udpPacket, packetQueueSize),
closeChan: make(chan struct{}),
addrMap: make(map[string]addrMapEntry),
bufPool: sync.Pool{
New: func() interface{} {
return make([]byte, udpBufferSize)
},
},
}
c.startRecvRoutines()
go c.addrMapCleanupRoutine()
return c, nil
}
func (c *ObfsUDPHopServerPacketConn) startRecvRoutines() {
for i, conn := range c.conns {
go c.recvRoutine(i, conn)
}
}
func (c *ObfsUDPHopServerPacketConn) recvRoutine(i int, conn net.PacketConn) {
log.Printf("udphop: receiving on %s", conn.LocalAddr())
for {
buf := c.bufPool.Get().([]byte)
n, addr, err := conn.ReadFrom(buf)
if err != nil {
log.Printf("udphop: routine %d read error: %v", i, err)
return
}
// Update addrMap
c.addrMapMutex.Lock()
c.addrMap[addr.String()] = addrMapEntry{i, time.Now()}
c.addrMapMutex.Unlock()
select {
case c.recvQueue <- &udpPacket{buf, n, addr}:
// Packet sent to queue
default:
log.Printf("udphop: recv queue full, dropping packet from %s", addr)
c.bufPool.Put(buf)
}
}
}
func (c *ObfsUDPHopServerPacketConn) addrMapCleanupRoutine() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-ticker.C:
c.addrMapMutex.Lock()
for addr, entry := range c.addrMap {
if time.Since(entry.last) > addrMapEntryTTL {
delete(c.addrMap, addr)
}
}
c.addrMapMutex.Unlock()
case <-c.closeChan:
return
}
}
}
func (c *ObfsUDPHopServerPacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
select {
case p := <-c.recvQueue:
n := copy(b, p.buf[:p.n])
c.bufPool.Put(p.buf)
return n, p.addr, nil
case <-c.closeChan:
return 0, nil, net.ErrClosed
}
}
func (c *ObfsUDPHopServerPacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
// Find index from addrMap
c.addrMapMutex.RLock()
entry := c.addrMap[addr.String()]
c.addrMapMutex.RUnlock()
return c.conns[entry.index].WriteTo(b, addr)
}
func (c *ObfsUDPHopServerPacketConn) Close() error {
for _, conn := range c.conns {
_ = conn.Close() // recvRoutines will exit on error
}
close(c.closeChan)
return nil
}
func (c *ObfsUDPHopServerPacketConn) LocalAddr() net.Addr {
return c.localAddr
}
func (c *ObfsUDPHopServerPacketConn) SetDeadline(t time.Time) error {
// Not implemented
return nil
}
func (c *ObfsUDPHopServerPacketConn) SetReadDeadline(t time.Time) error {
// Not implemented
return nil
}
func (c *ObfsUDPHopServerPacketConn) SetWriteDeadline(t time.Time) error {
// Not implemented
return nil
}