feat: full frag support

This commit is contained in:
Toby 2022-02-25 17:08:54 -08:00
parent 6efa976a56
commit 1789db9ade
3 changed files with 71 additions and 20 deletions

View File

@ -13,6 +13,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/pmtud_fix" "github.com/tobyxdd/hysteria/pkg/pmtud_fix"
"github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/transport"
"github.com/tobyxdd/hysteria/pkg/utils" "github.com/tobyxdd/hysteria/pkg/utils"
"math/rand"
"net" "net"
"strconv" "strconv"
"sync" "sync"
@ -43,6 +44,7 @@ type Client struct {
udpSessionMutex sync.RWMutex udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]chan *udpMessage udpSessionMap map[uint32]chan *udpMessage
udpDefragger defragger
} }
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
@ -137,11 +139,15 @@ func (c *Client) handleMessage(qs quic.Session) {
if err != nil { if err != nil {
continue continue
} }
dfMsg := c.udpDefragger.Feed(udpMsg)
if dfMsg == nil {
continue
}
c.udpSessionMutex.RLock() c.udpSessionMutex.RLock()
ch, ok := c.udpSessionMap[udpMsg.SessionID] ch, ok := c.udpSessionMap[dfMsg.SessionID]
if ok { if ok {
select { select {
case ch <- &udpMsg: case ch <- dfMsg:
// OK // OK
default: default:
// Silently drop the message when the channel is full // Silently drop the message when the channel is full
@ -353,14 +359,38 @@ func (c *quicPktConn) WriteTo(p []byte, addr string) error {
if err != nil { if err != nil {
return err return err
} }
var msgBuf bytes.Buffer msg := udpMessage{
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: c.UDPSessionID, SessionID: c.UDPSessionID,
Host: host, Host: host,
Port: port, Port: port,
FragCount: 1,
Data: p, Data: p,
}) }
return c.Session.SendMessage(msgBuf.Bytes()) // try no frag first
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &msg)
err = c.Session.SendMessage(msgBuf.Bytes())
if err != nil {
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
// need to frag
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
fragMsgs := fragUDPMessage(msg, int(errSize))
for _, fragMsg := range fragMsgs {
msgBuf.Reset()
_ = struc.Pack(&msgBuf, &fragMsg)
err = c.Session.SendMessage(msgBuf.Bytes())
if err != nil {
return err
}
}
return nil
} else {
// some other error
return err
}
} else {
return nil
}
} }
func (c *quicPktConn) Close() error { func (c *quicPktConn) Close() error {

View File

@ -50,7 +50,7 @@ type udpMessage struct {
HostLen uint16 `struc:"sizeof=Host"` HostLen uint16 `struc:"sizeof=Host"`
Host string Host string
Port uint16 Port uint16
MsgID uint16 // doesn't matter when not fragmented MsgID uint16 // doesn't matter when not fragmented, but must not be 0 when fragmented
FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented FragID uint8 // doesn't matter when not fragmented, starts at 0 when fragmented
FragCount uint8 // must be 1 when not fragmented FragCount uint8 // must be 1 when not fragmented
DataLen uint16 `struc:"sizeof=Data"` DataLen uint16 `struc:"sizeof=Data"`

View File

@ -10,6 +10,7 @@ import (
"github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/transport" "github.com/tobyxdd/hysteria/pkg/transport"
"github.com/tobyxdd/hysteria/pkg/utils" "github.com/tobyxdd/hysteria/pkg/utils"
"math/rand"
"net" "net"
"strconv" "strconv"
"sync" "sync"
@ -35,6 +36,7 @@ type serverClient struct {
udpSessionMutex sync.RWMutex udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]*net.UDPConn udpSessionMap map[uint32]*net.UDPConn
nextUDPSessionID uint32 nextUDPSessionID uint32
udpDefragger defragger
} }
func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
@ -123,43 +125,47 @@ func (c *serverClient) handleMessage(msg []byte) {
if err != nil { if err != nil {
return return
} }
dfMsg := c.udpDefragger.Feed(udpMsg)
if dfMsg == nil {
return
}
c.udpSessionMutex.RLock() c.udpSessionMutex.RLock()
conn, ok := c.udpSessionMap[udpMsg.SessionID] conn, ok := c.udpSessionMap[dfMsg.SessionID]
c.udpSessionMutex.RUnlock() c.udpSessionMutex.RUnlock()
if ok { if ok {
// Session found, send the message // Session found, send the message
action, arg := acl.ActionDirect, "" action, arg := acl.ActionDirect, ""
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(udpMsg.Host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host)
} else { } else {
ipAddr, err = c.Transport.ResolveIPAddr(udpMsg.Host) ipAddr, err = c.Transport.ResolveIPAddr(dfMsg.Host)
} }
if err != nil { if err != nil {
return return
} }
switch action { switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
_, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{
IP: ipAddr.IP, IP: ipAddr.IP,
Port: int(udpMsg.Port), Port: int(dfMsg.Port),
Zone: ipAddr.Zone, Zone: ipAddr.Zone,
}) })
if c.UpCounter != nil { if c.UpCounter != nil {
c.UpCounter.Add(float64(len(udpMsg.Data))) c.UpCounter.Add(float64(len(dfMsg.Data)))
} }
case acl.ActionBlock: case acl.ActionBlock:
// Do nothing // Do nothing
case acl.ActionHijack: case acl.ActionHijack:
hijackIPAddr, err := c.Transport.ResolveIPAddr(arg) hijackIPAddr, err := c.Transport.ResolveIPAddr(arg)
if err == nil { if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, &net.UDPAddr{ _, _ = conn.WriteToUDP(dfMsg.Data, &net.UDPAddr{
IP: hijackIPAddr.IP, IP: hijackIPAddr.IP,
Port: int(udpMsg.Port), Port: int(dfMsg.Port),
Zone: hijackIPAddr.Zone, Zone: hijackIPAddr.Zone,
}) })
if c.UpCounter != nil { if c.UpCounter != nil {
c.UpCounter.Add(float64(len(udpMsg.Data))) c.UpCounter.Add(float64(len(dfMsg.Data)))
} }
} }
default: default:
@ -297,14 +303,29 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
for { for {
n, rAddr, err := conn.ReadFromUDP(buf) n, rAddr, err := conn.ReadFromUDP(buf)
if n > 0 { if n > 0 {
var msgBuf bytes.Buffer msg := udpMessage{
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: id, SessionID: id,
Host: rAddr.IP.String(), Host: rAddr.IP.String(),
Port: uint16(rAddr.Port), Port: uint16(rAddr.Port),
FragCount: 1,
Data: buf[:n], Data: buf[:n],
}) }
_ = c.CS.SendMessage(msgBuf.Bytes()) // try no frag first
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &msg)
err = c.CS.SendMessage(msgBuf.Bytes())
if err != nil {
if errSize, ok := err.(quic.ErrMessageToLarge); ok {
// need to frag
msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1
fragMsgs := fragUDPMessage(msg, int(errSize))
for _, fragMsg := range fragMsgs {
msgBuf.Reset()
_ = struc.Pack(&msgBuf, &fragMsg)
_ = c.CS.SendMessage(msgBuf.Bytes())
}
}
}
if c.DownCounter != nil { if c.DownCounter != nil {
c.DownCounter.Add(float64(n)) c.DownCounter.Add(float64(n))
} }