From c0f53ea712a66bc5465b1feb1e430fa25c707461 Mon Sep 17 00:00:00 2001 From: Toby Date: Fri, 25 Feb 2022 17:59:01 -0800 Subject: [PATCH] feat: v2 compatibility --- pkg/core/protocol.go | 14 ++++++- pkg/core/server.go | 21 ++++++----- pkg/core/server_client.go | 77 +++++++++++++++++++++++++++------------ 3 files changed, 77 insertions(+), 35 deletions(-) diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index a7ab386..7fd64d7 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -5,8 +5,9 @@ import ( ) const ( - protocolVersion = uint8(3) - protocolTimeout = 10 * time.Second + protocolVersion = uint8(3) + protocolVersionV2 = uint8(2) + protocolTimeout = 10 * time.Second closeErrorCodeGeneric = 0 closeErrorCodeProtocol = 1 @@ -64,3 +65,12 @@ func (m udpMessage) HeaderSize() int { func (m udpMessage) Size() int { return m.HeaderSize() + len(m.Data) } + +type udpMessageV2 struct { + SessionID uint32 + HostLen uint16 `struc:"sizeof=Host"` + Host string + Port uint16 + DataLen uint16 `struc:"sizeof=Data"` + Data []byte +} diff --git a/pkg/core/server.go b/pkg/core/server.go index c142875..de0afe9 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -106,7 +106,7 @@ func (s *Server) handleClient(cs quic.Session) { return } // Handle the control stream - auth, ok, err := s.handleControlStream(cs, stream) + auth, ok, v2, err := s.handleControlStream(cs, stream) if err != nil { _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") return @@ -116,7 +116,7 @@ func (s *Server) handleClient(cs quic.Session) { return } // Start accepting streams and messages - sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine, + sc := newServerClient(v2, cs, s.transport, auth, s.disableUDP, s.aclEngine, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec, s.connGaugeVec) err = sc.Run() @@ -125,25 +125,26 @@ func (s *Server) handleClient(cs quic.Session) { } // Auth & negotiate speed -func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) { +func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, bool, error) { // Check version vb := make([]byte, 1) _, err := stream.Read(vb) if err != nil { - return nil, false, err + return nil, false, false, err } - if vb[0] != protocolVersion { - return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion) + if vb[0] != protocolVersion && vb[0] != protocolVersionV2 { + return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d", + vb[0], protocolVersionV2, protocolVersion) } // Parse client hello var ch clientHello err = struc.Unpack(stream, &ch) if err != nil { - return nil, false, err + return nil, false, false, err } // Speed if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { - return nil, false, errors.New("invalid rate from client") + return nil, false, false, errors.New("invalid rate from client") } serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS if s.sendBPS > 0 && serverSendBPS > s.sendBPS { @@ -164,11 +165,11 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt Message: msg, }) if err != nil { - return nil, false, err + return nil, false, false, err } // Set the congestion accordingly if ok && s.congestionFactory != nil { cs.SetCongestionControl(s.congestionFactory(serverSendBPS)) } - return ch.Auth, ok, nil + return ch.Auth, ok, vb[0] == protocolVersionV2, nil } diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index 79b4521..5b3c26c 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -19,6 +19,7 @@ import ( const udpBufferSize = 65535 type serverClient struct { + V2 bool CS quic.Session Transport *transport.ServerTransport Auth []byte @@ -39,12 +40,13 @@ type serverClient struct { udpDefragger defragger } -func newServerClient(cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, +func newServerClient(v2 bool, cs quic.Session, transport *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc, UpCounterVec, DownCounterVec *prometheus.CounterVec, ConnGaugeVec *prometheus.GaugeVec) *serverClient { sc := &serverClient{ + V2: v2, CS: cs, Transport: transport, Auth: auth, @@ -121,9 +123,26 @@ func (c *serverClient) handleStream(stream quic.Stream) { func (c *serverClient) handleMessage(msg []byte) { var udpMsg udpMessage - err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) - if err != nil { - return + if c.V2 { + var udpMsgV2 udpMessageV2 + err := struc.Unpack(bytes.NewBuffer(msg), &udpMsgV2) + if err != nil { + return + } + udpMsg = udpMessage{ + SessionID: udpMsgV2.SessionID, + HostLen: udpMsgV2.HostLen, + Host: udpMsgV2.Host, + Port: udpMsgV2.Port, + FragCount: 1, + DataLen: udpMsgV2.DataLen, + Data: udpMsgV2.Data, + } + } else { + err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) + if err != nil { + return + } } dfMsg := c.udpDefragger.Feed(udpMsg) if dfMsg == nil { @@ -136,6 +155,7 @@ func (c *serverClient) handleMessage(msg []byte) { // Session found, send the message action, arg := acl.ActionDirect, "" var ipAddr *net.IPAddr + var err error if c.ACLEngine != nil { action, arg, ipAddr, err = c.ACLEngine.ResolveAndMatch(dfMsg.Host) } else { @@ -303,26 +323,37 @@ func (c *serverClient) handleUDP(stream quic.Stream) { for { n, rAddr, err := conn.ReadFromUDP(buf) if n > 0 { - msg := udpMessage{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - FragCount: 1, - Data: buf[:n], - } - // try no frag first var msgBuf bytes.Buffer - _ = struc.Pack(&msgBuf, &msg) - err = c.CS.SendMessage(msgBuf.Bytes()) - if err != nil { - if errSize, ok := err.(quic.ErrMessageToLarge); ok { - // need to frag - msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 - fragMsgs := fragUDPMessage(msg, int(errSize)) - for _, fragMsg := range fragMsgs { - msgBuf.Reset() - _ = struc.Pack(&msgBuf, &fragMsg) - _ = c.CS.SendMessage(msgBuf.Bytes()) + if c.V2 { + msg := udpMessageV2{ + SessionID: id, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + Data: buf[:n], + } + _ = struc.Pack(&msgBuf, &msg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } else { + msg := udpMessage{ + SessionID: id, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + FragCount: 1, + Data: buf[:n], + } + // try no frag first + _ = struc.Pack(&msgBuf, &msg) + err = c.CS.SendMessage(msgBuf.Bytes()) + if err != nil { + if errSize, ok := err.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) + } } } }