diff --git a/go.mod b/go.mod index 1aad654..45371d7 100644 --- a/go.mod +++ b/go.mod @@ -74,4 +74,4 @@ require ( gopkg.in/yaml.v2 v2.4.0 // indirect ) -replace github.com/lucas-clemente/quic-go => github.com/tobyxdd/quic-go v0.25.1-0.20220220071449-45d21d89d5d4 +replace github.com/lucas-clemente/quic-go => github.com/tobyxdd/quic-go v0.25.1-0.20220224051149-310bd1bfaf1f diff --git a/go.sum b/go.sum index 6b6ac3c..311a20c 100644 --- a/go.sum +++ b/go.sum @@ -502,8 +502,8 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/subosito/gotenv v1.2.0 h1:Slr1R9HxAlEKefgq5jn9U+DnETlIUa6HfgEzj0g5d7s= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tarm/serial v0.0.0-20180830185346-98f6abe2eb07/go.mod h1:kDXzergiv9cbyO7IOYJZWg1U88JhDg3PB6klq9Hg2pA= -github.com/tobyxdd/quic-go v0.25.1-0.20220220071449-45d21d89d5d4 h1:2oHsHe9vfX0djOpa6y0OahD/wmtdbrLN6o/idWT7hWo= -github.com/tobyxdd/quic-go v0.25.1-0.20220220071449-45d21d89d5d4/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= +github.com/tobyxdd/quic-go v0.25.1-0.20220224051149-310bd1bfaf1f h1:phddE/foYEnuZOgKfmfYVcnXBvV5cjhGrqBtSETqvQ0= +github.com/tobyxdd/quic-go v0.25.1-0.20220224051149-310bd1bfaf1f/go.mod h1:YtzP8bxRVCBlO77yRanE264+fY/T2U9ZlW1AaHOsMOg= github.com/tv42/httpunix v0.0.0-20150427012821-b75d8614f926/go.mod h1:9ESjWnEqriFuLhtthL60Sar/7RFoluCcXsuvEwTV5KM= github.com/txthinking/runnergroup v0.0.0-20210326110939-37fc67d0da7c h1:6WIrmZPMl2Q61vozy5MfJNfD6CAgivGFgqvXsrho8GM= github.com/txthinking/runnergroup v0.0.0-20210326110939-37fc67d0da7c/go.mod h1:CLUSJbazqETbaR+i0YAhXBICV9TrKH93pziccMhmhpM= diff --git a/pkg/core/client.go b/pkg/core/client.go index de22aa8..bc3b753 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -13,6 +13,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/pmtud_fix" "github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/utils" + "math/rand" "net" "strconv" "sync" @@ -43,6 +44,7 @@ type Client struct { udpSessionMutex sync.RWMutex udpSessionMap map[uint32]chan *udpMessage + udpDefragger defragger } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, @@ -137,11 +139,15 @@ func (c *Client) handleMessage(qs quic.Session) { if err != nil { continue } + dfMsg := c.udpDefragger.Feed(udpMsg) + if dfMsg == nil { + continue + } c.udpSessionMutex.RLock() - ch, ok := c.udpSessionMap[udpMsg.SessionID] + ch, ok := c.udpSessionMap[dfMsg.SessionID] if ok { select { - case ch <- &udpMsg: + case ch <- dfMsg: // OK default: // Silently drop the message when the channel is full @@ -353,14 +359,38 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { if err != nil { return err } - var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &udpMessage{ + msg := udpMessage{ SessionID: c.UDPSessionID, Host: host, Port: port, + FragCount: 1, Data: p, - }) - return c.Session.SendMessage(msgBuf.Bytes()) + } + // try no frag first + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &msg) + err = c.Session.SendMessage(msgBuf.Bytes()) + if err != nil { + if errSize, ok := err.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + err = c.Session.SendMessage(msgBuf.Bytes()) + if err != nil { + return err + } + } + return nil + } else { + // some other error + return err + } + } else { + return nil + } } func (c *quicPktConn) Close() error { diff --git a/pkg/core/frag.go b/pkg/core/frag.go new file mode 100644 index 0000000..7a38774 --- /dev/null +++ b/pkg/core/frag.go @@ -0,0 +1,67 @@ +package core + +func fragUDPMessage(m udpMessage, maxSize int) []udpMessage { + if m.Size() <= maxSize { + return []udpMessage{m} + } + fullPayload := m.Data + maxPayloadSize := maxSize - m.HeaderSize() + off := 0 + fragID := uint8(0) + fragCount := uint8((len(fullPayload) + maxPayloadSize - 1) / maxPayloadSize) // round up + var frags []udpMessage + for off < len(fullPayload) { + payloadSize := len(fullPayload) - off + if payloadSize > maxPayloadSize { + payloadSize = maxPayloadSize + } + frag := m + frag.FragID = fragID + frag.FragCount = fragCount + frag.DataLen = uint16(payloadSize) + frag.Data = fullPayload[off : off+payloadSize] + frags = append(frags, frag) + off += payloadSize + fragID++ + } + return frags +} + +type defragger struct { + msgID uint16 + frags []*udpMessage + count uint8 +} + +func (d *defragger) Feed(m udpMessage) *udpMessage { + if m.FragCount <= 1 { + return &m + } + if m.FragID >= m.FragCount { + // wtf is this? + return nil + } + if m.MsgID != d.msgID { + // new message, clear previous state + d.msgID = m.MsgID + d.frags = make([]*udpMessage, m.FragCount) + d.count = 1 + d.frags[m.FragID] = &m + } else if d.frags[m.FragID] == nil { + d.frags[m.FragID] = &m + d.count++ + if int(d.count) == len(d.frags) { + // all fragments received, assemble + var data []byte + for _, frag := range d.frags { + data = append(data, frag.Data...) + } + m.DataLen = uint16(len(data)) + m.Data = data + m.FragID = 0 + m.FragCount = 1 + return &m + } + } + return nil +} diff --git a/pkg/core/frag_test.go b/pkg/core/frag_test.go new file mode 100644 index 0000000..f2f2462 --- /dev/null +++ b/pkg/core/frag_test.go @@ -0,0 +1,346 @@ +package core + +import ( + "reflect" + "testing" +) + +func Test_fragUDPMessage(t *testing.T) { + type args struct { + m udpMessage + maxSize int + } + tests := []struct { + name string + args args + want []udpMessage + }{ + { + "no frag", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 5, + Data: []byte("hello"), + }, + 100, + }, + []udpMessage{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 5, + Data: []byte("hello"), + }, + }, + }, + { + "2 frags", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 5, + Data: []byte("hello"), + }, + 22, + }, + []udpMessage{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 2, + DataLen: 4, + Data: []byte("hell"), + }, + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 1, + FragCount: 2, + DataLen: 1, + Data: []byte("o"), + }, + }, + }, + { + "4 frags", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 20, + Data: []byte("wow wow wow lol lmao"), + }, + 23, + }, + []udpMessage{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 4, + DataLen: 5, + Data: []byte("wow w"), + }, + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 1, + FragCount: 4, + DataLen: 5, + Data: []byte("ow wo"), + }, + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 2, + FragCount: 4, + DataLen: 5, + Data: []byte("w lol"), + }, + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 3, + FragCount: 4, + DataLen: 5, + Data: []byte(" lmao"), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := fragUDPMessage(tt.args.m, tt.args.maxSize); !reflect.DeepEqual(got, tt.want) { + t.Errorf("fragUDPMessage() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_defragger_Feed(t *testing.T) { + d := &defragger{} + type args struct { + m udpMessage + } + tests := []struct { + name string + args args + want *udpMessage + }{ + { + "no frag", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 5, + Data: []byte("hello"), + }, + }, + &udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 123, + FragID: 0, + FragCount: 1, + DataLen: 5, + Data: []byte("hello"), + }, + }, + { + "frag 1 - 1/3", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 666, + FragID: 0, + FragCount: 3, + DataLen: 5, + Data: []byte("hello"), + }, + }, + nil, + }, + { + "frag 1 - 2/3", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 666, + FragID: 1, + FragCount: 3, + DataLen: 8, + Data: []byte(" shitty "), + }, + }, + nil, + }, + { + "frag 1 - 3/3", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 666, + FragID: 2, + FragCount: 3, + DataLen: 7, + Data: []byte("world!!"), + }, + }, + &udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 666, + FragID: 0, + FragCount: 1, + DataLen: 20, + Data: []byte("hello shitty world!!"), + }, + }, + { + "frag 2 - 1/2", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 777, + FragID: 0, + FragCount: 2, + DataLen: 5, + Data: []byte("hello"), + }, + }, + nil, + }, + { + "frag 3 - 2/2", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 778, + FragID: 1, + FragCount: 2, + DataLen: 5, + Data: []byte(" moto"), + }, + }, + nil, + }, + { + "frag 2 - 2/2", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 777, + FragID: 1, + FragCount: 2, + DataLen: 5, + Data: []byte(" moto"), + }, + }, + nil, + }, + { + "frag 2 - 1/2 re", + args{ + udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 777, + FragID: 0, + FragCount: 2, + DataLen: 5, + Data: []byte("hello"), + }, + }, + &udpMessage{ + SessionID: 123, + HostLen: 4, + Host: "test", + Port: 123, + MsgID: 777, + FragID: 0, + FragCount: 1, + DataLen: 10, + Data: []byte("hello moto"), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := d.Feed(tt.args.m); !reflect.DeepEqual(got, tt.want) { + t.Errorf("Feed() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index d7ccc6b..a7ab386 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -5,7 +5,7 @@ import ( ) const ( - protocolVersion = uint8(2) + protocolVersion = uint8(3) protocolTimeout = 10 * time.Second closeErrorCodeGeneric = 0 @@ -50,6 +50,17 @@ type udpMessage struct { HostLen uint16 `struc:"sizeof=Host"` Host string Port uint16 + MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented + FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented + FragCount uint8 // must be 1 when not fragmented DataLen uint16 `struc:"sizeof=Data"` Data []byte } + +func (m udpMessage) HeaderSize() int { + return 4 + 2 + len(m.Host) + 2 + 2 + 1 + 1 + 2 +} + +func (m udpMessage) Size() int { + return m.HeaderSize() + len(m.Data) +} diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 9a96046..79b4521 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -10,6 +10,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/utils" + "math/rand" "net" "strconv" "sync" @@ -35,6 +36,7 @@ type serverClient struct { udpSessionMutex sync.RWMutex udpSessionMap map[uint32]*net.UDPConn nextUDPSessionID uint32 + udpDefragger defragger } func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, @@ -123,43 +125,47 @@ func (c *serverClient) handleMessage(msg []byte) { if err != nil { return } + dfMsg := c.udpDefragger.Feed(udpMsg) + if dfMsg == nil { + return + } c.udpSessionMutex.RLock() - conn, ok := c.udpSessionMap[udpMsg.SessionID] + conn, ok := c.udpSessionMap[dfMsg.SessionID] c.udpSessionMutex.RUnlock() if ok { // Session found, send the message action, arg := acl.ActionDirect, "" var ipAddr *net.IPAddr if c.ACLEngine != nil { - action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) + action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) } else { - ipAddr, err = c.Transport.ResolveIPAddr(udpMsg.Host) + ipAddr, err = c.Transport.ResolveIPAddr(dfMsg.Host) } if err != nil { return } switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - _, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ + _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ IP: ipAddr.IP, - Port: int(udpMsg.Port), + Port: int(dfMsg.Port), Zone: ipAddr.Zone, }) if c.UpCounter != nil { - c.UpCounter.Add(float64(len(udpMsg.Data))) + c.UpCounter.Add(float64(len(dfMsg.Data))) } case acl.ActionBlock: // Do nothing case acl.ActionHijack: hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) if err == nil { - _, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ + _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ IP: hijackIPAddr.IP, - Port: int(udpMsg.Port), + Port: int(dfMsg.Port), Zone: hijackIPAddr.Zone, }) if c.UpCounter != nil { - c.UpCounter.Add(float64(len(udpMsg.Data))) + c.UpCounter.Add(float64(len(dfMsg.Data))) } } default: @@ -297,14 +303,29 @@ func (c *serverClient) handleUDP(stream quic.Stream) { for { n, rAddr, err := conn.ReadFromUDP(buf) if n > 0 { - var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &udpMessage{ + msg := udpMessage{ SessionID: id, Host: rAddr.IP.String(), Port: uint16(rAddr.Port), + FragCount: 1, Data: buf[:n], - }) - _ = c.CS.SendMessage(msgBuf.Bytes()) + } + // try no frag first + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &msg) + err = c.CS.SendMessage(msgBuf.Bytes()) + if err != nil { + if errSize, ok := err.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } + } + } if c.DownCounter != nil { c.DownCounter.Add(float64(n)) }