diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 7fd64d7..a7ab386 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -5,9 +5,8 @@ import ( ) const ( - protocolVersion = uint8(3) - protocolVersionV2 = uint8(2) - protocolTimeout = 10 * time.Second + protocolVersion = uint8(3) + protocolTimeout = 10 * time.Second closeErrorCodeGeneric = 0 closeErrorCodeProtocol = 1 @@ -65,12 +64,3 @@ func (m udpMessage) HeaderSize() int { func (m udpMessage) Size() int { return m.HeaderSize() + len(m.Data) } - -type udpMessageV2 struct { - SessionID uint32 - HostLen uint16 `struc:"sizeof=Host"` - Host string - Port uint16 - DataLen uint16 `struc:"sizeof=Data"` - Data []byte -} diff --git a/pkg/core/server.go b/pkg/core/server.go index 8fb9930..5419879 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -121,7 +121,7 @@ func (s *Server) handleClient(cs quic.Connection) { return } // Handle the control stream - auth, ok, v2, err := s.handleControlStream(cs, stream) + auth, ok, err := s.handleControlStream(cs, stream) if err != nil { _ = cs.CloseWithError(closeErrorCodeProtocol, "protocol error") return @@ -131,7 +131,7 @@ func (s *Server) handleClient(cs quic.Connection) { return } // Start accepting streams and messages - sc := newServerClient(v2, cs, s.transport, auth, s.disableUDP, s.aclEngine, + sc := newServerClient(cs, s.transport, auth, s.disableUDP, s.aclEngine, s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc, s.upCounterVec, s.downCounterVec, s.connGaugeVec) err = sc.Run() @@ -140,26 +140,25 @@ func (s *Server) handleClient(cs quic.Connection) { } // Auth & negotiate speed -func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([]byte, bool, bool, error) { +func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([]byte, bool, error) { // Check version vb := make([]byte, 1) _, err := stream.Read(vb) if err != nil { - return nil, false, false, err + return nil, false, err } - if vb[0] != protocolVersion && vb[0] != protocolVersionV2 { - return nil, false, false, fmt.Errorf("unsupported protocol version %d, expecting %d/%d", - vb[0], protocolVersionV2, protocolVersion) + if vb[0] != protocolVersion { + return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion) } // Parse client hello var ch clientHello err = struc.Unpack(stream, &ch) if err != nil { - return nil, false, false, err + return nil, false, err } // Speed if ch.Rate.SendBPS == 0 || ch.Rate.RecvBPS == 0 { - return nil, false, false, errors.New("invalid rate from client") + return nil, false, errors.New("invalid rate from client") } serverSendBPS, serverRecvBPS := ch.Rate.RecvBPS, ch.Rate.SendBPS if s.sendBPS > 0 && serverSendBPS > s.sendBPS { @@ -180,11 +179,11 @@ func (s *Server) handleControlStream(cs quic.Connection, stream quic.Stream) ([] Message: msg, }) if err != nil { - return nil, false, false, err + return nil, false, err } // Set the congestion accordingly if ok { cs.SetCongestionControl(congestion.NewBrutalSender(serverSendBPS)) } - return ch.Auth, ok, vb[0] == protocolVersionV2, nil + return ch.Auth, ok, nil } diff --git a/pkg/core/server_client.go b/pkg/core/server_client.go index e157ea5..2dd76dc 100644 --- a/pkg/core/server_client.go +++ b/pkg/core/server_client.go @@ -20,7 +20,6 @@ import ( const udpBufferSize = 65535 type serverClient struct { - V2 bool CS quic.Connection Transport *transport.ServerTransport Auth []byte @@ -41,14 +40,13 @@ type serverClient struct { udpDefragger defragger } -func newServerClient(v2 bool, cs quic.Connection, tr *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, +func newServerClient(cs quic.Connection, tr *transport.ServerTransport, auth []byte, disableUDP bool, ACLEngine *acl.Engine, CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc, CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc, UpCounterVec, DownCounterVec *prometheus.CounterVec, ConnGaugeVec *prometheus.GaugeVec, ) *serverClient { sc := &serverClient{ - V2: v2, CS: cs, Transport: tr, Auth: auth, @@ -125,26 +123,9 @@ func (c *serverClient) handleStream(stream quic.Stream) { func (c *serverClient) handleMessage(msg []byte) { var udpMsg udpMessage - if c.V2 { - var udpMsgV2 udpMessageV2 - err := struc.Unpack(bytes.NewBuffer(msg), &udpMsgV2) - if err != nil { - return - } - udpMsg = udpMessage{ - SessionID: udpMsgV2.SessionID, - HostLen: udpMsgV2.HostLen, - Host: udpMsgV2.Host, - Port: udpMsgV2.Port, - FragCount: 1, - DataLen: udpMsgV2.DataLen, - Data: udpMsgV2.Data, - } - } else { - err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) - if err != nil { - return - } + err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg) + if err != nil { + return } dfMsg := c.udpDefragger.Feed(udpMsg) if dfMsg == nil { @@ -340,36 +321,25 @@ func (c *serverClient) handleUDP(stream quic.Stream) { n, rAddr, err := conn.ReadFromUDP(buf) if n > 0 { var msgBuf bytes.Buffer - if c.V2 { - msg := udpMessageV2{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - Data: buf[:n], - } - _ = struc.Pack(&msgBuf, &msg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } else { - msg := udpMessage{ - SessionID: id, - Host: rAddr.IP.String(), - Port: uint16(rAddr.Port), - FragCount: 1, - Data: buf[:n], - } - // try no frag first - _ = struc.Pack(&msgBuf, &msg) - sendErr := c.CS.SendMessage(msgBuf.Bytes()) - if sendErr != nil { - if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { - // need to frag - msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 - fragMsgs := fragUDPMessage(msg, int(errSize)) - for _, fragMsg := range fragMsgs { - msgBuf.Reset() - _ = struc.Pack(&msgBuf, &fragMsg) - _ = c.CS.SendMessage(msgBuf.Bytes()) - } + msg := udpMessage{ + SessionID: id, + Host: rAddr.IP.String(), + Port: uint16(rAddr.Port), + FragCount: 1, + Data: buf[:n], + } + // try no frag first + _ = struc.Pack(&msgBuf, &msg) + sendErr := c.CS.SendMessage(msgBuf.Bytes()) + if sendErr != nil { + if errSize, ok := sendErr.(quic.ErrMessageToLarge); ok { + // need to frag + msg.MsgID = uint16(rand.Intn(0xFFFF)) + 1 // msgID must be > 0 when fragCount > 1 + fragMsgs := fragUDPMessage(msg, int(errSize)) + for _, fragMsg := range fragMsgs { + msgBuf.Reset() + _ = struc.Pack(&msgBuf, &fragMsg) + _ = c.CS.SendMessage(msgBuf.Bytes()) } } }