diff --git a/pkg/core/client.go b/pkg/core/client.go index de22aa8..bc3b753 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -13,6 +13,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/pmtud_fix" "github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/utils" + "math/rand" "net" "strconv" "sync" @@ -43,6 +44,7 @@ type Client struct { udpSessionMutex sync.RWMutex udpSessionMap map[uint32]chan *udpMessage + udpDefragger defragger } func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, @@ -137,11 +139,15 @@ func (c *Client) handleMessage(qs quic.Session) { if err != nil { continue } + dfMsg := c.udpDefragger.Feed(udpMsg) + if dfMsg == nil { + continue + } c.udpSessionMutex.RLock() - ch, ok := c.udpSessionMap[udpMsg.SessionID] + ch, ok := c.udpSessionMap[dfMsg.SessionID] if ok { select { - case ch <- &udpMsg: + case ch <- dfMsg: // OK default: // Silently drop the message when the channel is full @@ -353,14 +359,38 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error { if err != nil { return err } - var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &udpMessage{ + msg := udpMessage{ SessionID: c.UDPSessionID, Host: host, Port: port, + FragCount: 1, Data: p, - }) - return c.Session.SendMessage(msgBuf.Bytes()) + } + // try no frag first + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &msg) + err = c.Session.SendMessage(msgBuf.Bytes()) + if err != nil { + if errSize, ok := err.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + err = c.Session.SendMessage(msgBuf.Bytes()) + if err != nil { + return err + } + } + return nil + } else { + // some other error + return err + } + } else { + return nil + } } func (c *quicPktConn) Close() error { diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index d85b97b..a7ab386 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -50,7 +50,7 @@ type udpMessage struct { HostLen uint16 `struc:"sizeof=Host"` Host string Port uint16 - MsgID uint16 // doesn't matter when not fragmented + MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented FragCount uint8 // must be 1 when not fragmented DataLen uint16 `struc:"sizeof=Data"` diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 9a96046..79b4521 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -10,6 +10,7 @@ import ( "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/utils" + "math/rand" "net" "strconv" "sync" @@ -35,6 +36,7 @@ type serverClient struct { udpSessionMutex sync.RWMutex udpSessionMap map[uint32]*net.UDPConn nextUDPSessionID uint32 + udpDefragger defragger } func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, @@ -123,43 +125,47 @@ func (c *serverClient) handleMessage(msg []byte) { if err != nil { return } + dfMsg := c.udpDefragger.Feed(udpMsg) + if dfMsg == nil { + return + } c.udpSessionMutex.RLock() - conn, ok := c.udpSessionMap[udpMsg.SessionID] + conn, ok := c.udpSessionMap[dfMsg.SessionID] c.udpSessionMutex.RUnlock() if ok { // Session found, send the message action, arg := acl.ActionDirect, "" var ipAddr *net.IPAddr if c.ACLEngine != nil { - action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) + action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) } else { - ipAddr, err = c.Transport.ResolveIPAddr(udpMsg.Host) + ipAddr, err = c.Transport.ResolveIPAddr(dfMsg.Host) } if err != nil { return } switch action { case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side - _, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ + _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ IP: ipAddr.IP, - Port: int(udpMsg.Port), + Port: int(dfMsg.Port), Zone: ipAddr.Zone, }) if c.UpCounter != nil { - c.UpCounter.Add(float64(len(udpMsg.Data))) + c.UpCounter.Add(float64(len(dfMsg.Data))) } case acl.ActionBlock: // Do nothing case acl.ActionHijack: hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) if err == nil { - _, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ + _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{ IP: hijackIPAddr.IP, - Port: int(udpMsg.Port), + Port: int(dfMsg.Port), Zone: hijackIPAddr.Zone, }) if c.UpCounter != nil { - c.UpCounter.Add(float64(len(udpMsg.Data))) + c.UpCounter.Add(float64(len(dfMsg.Data))) } } default: @@ -297,14 +303,29 @@ func (c *serverClient) handleUDP(stream quic.Stream) { for { n, rAddr, err := conn.ReadFromUDP(buf) if n > 0 { - var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &udpMessage{ + msg := udpMessage{ SessionID: id, Host: rAddr.IP.String(), Port: uint16(rAddr.Port), + FragCount: 1, Data: buf[:n], - }) - _ = c.CS.SendMessage(msgBuf.Bytes()) + } + // try no frag first + var msgBuf bytes.Buffer + _ = struc.Pack(&msgBuf, &msg) + err = c.CS.SendMessage(msgBuf.Bytes()) + if err != nil { + if errSize, ok := err.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } + } + } if c.DownCounter != nil { c.DownCounter.Add(float64(n)) }