feat: v2 compatibility

This commit is contained in:
Toby 2022-02-25 17:59:01 -08:00
parent a5e27385c8
commit c0f53ea712
3 changed files with 77 additions and 35 deletions

View File

@ -6,6 +6,7 @@ import (
const ( const (
protocolVersion = uint8(3) protocolVersion = uint8(3)
protocolVersionV2 = uint8(2)
protocolTimeout = 10 * time.Second protocolTimeout = 10 * time.Second
closeErrorCodeGeneric = 0 closeErrorCodeGeneric = 0
@ -64,3 +65,12 @@ func (m udpMessage) HeaderSize() int {
func (m udpMessage) Size() int { func (m udpMessage) Size() int {
return m.HeaderSize() + len(m.Data) return m.HeaderSize() + len(m.Data)
} }
type udpMessageV2 struct {
SessionID uint32
HostLen uint16 `struc:"sizeof=Host"`
Host string
Port uint16
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
}

View File

@ -106,7 +106,7 @@ func (s *Server) handleClient(cs quic.Session) {
return return
} }
// Handle the control stream // Handle the control stream
auth, ok, err := s.handleControlStream(cs, stream) auth, ok, v2, err := s.handleControlStream(cs, stream)
if err != nil { if err != nil {
_ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error")
return return
@ -116,7 +116,7 @@ func (s *Server) handleClient(cs quic.Session) {
return return
} }
// Start accepting streams and messages // Start accepting streams and messages
sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine, sc := newServerClient(v2, cs, s.transport, auth, s.disableUDP, s.aclEngine,
s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc,
s.upCounterVec, s.downCounterVec, s.connGaugeVec) s.upCounterVec, s.downCounterVec, s.connGaugeVec)
err = sc.Run() err = sc.Run()
@ -125,25 +125,26 @@ func (s *Server) handleClient(cs quic.Session) {
} }
// Auth & negotiate speed // Auth & negotiate speed
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) { func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, bool, error) {
// Check version // Check version
vb := make([]byte, 1) vb := make([]byte, 1)
_, err := stream.Read(vb) _, err := stream.Read(vb)
if err != nil { if err != nil {
return nil, false, err return nil, false, false, err
} }
if vb[0] != protocolVersion { if vb[0] != protocolVersion && vb[0] != protocolVersionV2 {
return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion) return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d",
vb[0], protocolVersionV2, protocolVersion)
} }
// Parse client hello // Parse client hello
var ch clientHello var ch clientHello
err = struc.Unpack(stream, &ch) err = struc.Unpack(stream, &ch)
if err != nil { if err != nil {
return nil, false, err return nil, false, false, err
} }
// Speed // Speed
if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 {
return nil, false, errors.New("invalid rate from client") return nil, false, false, errors.New("invalid rate from client")
} }
serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS
if s.sendBPS > 0 && serverSendBPS > s.sendBPS { if s.sendBPS > 0 && serverSendBPS > s.sendBPS {
@ -164,11 +165,11 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt
Message: msg, Message: msg,
}) })
if err != nil { if err != nil {
return nil, false, err return nil, false, false, err
} }
// Set the congestion accordingly // Set the congestion accordingly
if ok && s.congestionFactory != nil { if ok && s.congestionFactory != nil {
cs.SetCongestionControl(s.congestionFactory(serverSendBPS)) cs.SetCongestionControl(s.congestionFactory(serverSendBPS))
} }
return ch.Auth, ok, nil return ch.Auth, ok, vb[0] == protocolVersionV2, nil
} }

View File

@ -19,6 +19,7 @@ import (
const udpBufferSize = 65535 const udpBufferSize = 65535
type serverClient struct { type serverClient struct {
V2 bool
CS quic.Session CS quic.Session
Transport *transport.ServerTransport Transport *transport.ServerTransport
Auth []byte Auth []byte
@ -39,12 +40,13 @@ type serverClient struct {
udpDefragger defragger udpDefragger defragger
} }
func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, func newServerClient(v2 bool, cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc,
CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc, CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc,
UpCounterVec, DownCounterVec *prometheus.CounterVec, UpCounterVec, DownCounterVec *prometheus.CounterVec,
ConnGaugeVec *prometheus.GaugeVec) *serverClient { ConnGaugeVec *prometheus.GaugeVec) *serverClient {
sc := &serverClient{ sc := &serverClient{
V2: v2,
CS: cs, CS: cs,
Transport: transport, Transport: transport,
Auth: auth, Auth: auth,
@ -121,10 +123,27 @@ func (c *serverClient) handleStream(stream quic.Stream) {
func (c *serverClient) handleMessage(msg []byte) { func (c *serverClient) handleMessage(msg []byte) {
var udpMsg udpMessage var udpMsg udpMessage
if c.V2 {
var udpMsgV2 udpMessageV2
err := struc.Unpack(bytes.NewBuffer(msg), &udpMsgV2)
if err != nil {
return
}
udpMsg = udpMessage{
SessionID: udpMsgV2.SessionID,
HostLen: udpMsgV2.HostLen,
Host: udpMsgV2.Host,
Port: udpMsgV2.Port,
FragCount: 1,
DataLen: udpMsgV2.DataLen,
Data: udpMsgV2.Data,
}
} else {
err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
if err != nil { if err != nil {
return return
} }
}
dfMsg := c.udpDefragger.Feed(udpMsg) dfMsg := c.udpDefragger.Feed(udpMsg)
if dfMsg == nil { if dfMsg == nil {
return return
@ -136,6 +155,7 @@ func (c *serverClient) handleMessage(msg []byte) {
// Session found, send the message // Session found, send the message
action, arg := acl.ActionDirect, "" action, arg := acl.ActionDirect, ""
var ipAddr *net.IPAddr var ipAddr *net.IPAddr
var err error
if c.ACLEngine != nil { if c.ACLEngine != nil {
action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host)
} else { } else {
@ -303,6 +323,17 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
for { for {
n, rAddr, err := conn.ReadFromUDP(buf) n, rAddr, err := conn.ReadFromUDP(buf)
if n > 0 { if n > 0 {
var msgBuf bytes.Buffer
if c.V2 {
msg := udpMessageV2{
SessionID: id,
Host: rAddr.IP.String(),
Port: uint16(rAddr.Port),
Data: buf[:n],
}
_ = struc.Pack(&msgBuf, &msg)
_ = c.CS.SendMessage(msgBuf.Bytes())
} else {
msg := udpMessage{ msg := udpMessage{
SessionID: id, SessionID: id,
Host: rAddr.IP.String(), Host: rAddr.IP.String(),
@ -311,7 +342,6 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
Data: buf[:n], Data: buf[:n],
} }
// try no frag first // try no frag first
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &msg) _ = struc.Pack(&msgBuf, &msg)
err = c.CS.SendMessage(msgBuf.Bytes()) err = c.CS.SendMessage(msgBuf.Bytes())
if err != nil { if err != nil {
@ -326,6 +356,7 @@ func (c *serverClient) handleUDP(stream quic.Stream) {
} }
} }
} }
}
if c.DownCounter != nil { if c.DownCounter != nil {
c.DownCounter.Add(float64(n)) c.DownCounter.Add(float64(n))
} }