From 7b841aa203ce036ce24b3a62f7a5e0a9bd7dd991 Mon Sep 17 00:00:00 2001 From: Toby Date: Sun, 18 Apr 2021 18:07:01 -0700 Subject: [PATCH] Protocol version check --- pkg/core/client.go | 7 ++++++- pkg/core/protocol.go | 2 +- pkg/core/server.go | 16 ++++++++++++---- 3 files changed, 19 insertions(+), 6 deletions(-) diff --git a/pkg/core/client.go b/pkg/core/client.go index 822d3f2..b222dda 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -106,8 +106,13 @@ func (c *Client) connectToServer() error { } func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) { + // Send protocol version + _, err := stream.Write([]byte{protocolVersion}) + if err != nil { + return false, "", err + } // Send client hello - err := struc.Pack(stream, &clientHello{ + err = struc.Pack(stream, &clientHello{ Rate: transmissionRate{ SendBPS: c.sendBPS, RecvBPS: c.recvBPS, diff --git a/pkg/core/protocol.go b/pkg/core/protocol.go index 9299044..0a71ba3 100644 --- a/pkg/core/protocol.go +++ b/pkg/core/protocol.go @@ -5,7 +5,7 @@ import ( ) const ( - protocolVersion = uint8(1) + protocolVersion = uint8(2) protocolTimeout = 10 * time.Second closeErrorCodeGeneric = 0 diff --git a/pkg/core/server.go b/pkg/core/server.go index afd5358..57475c6 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -4,16 +4,14 @@ import ( "context" "crypto/tls" "errors" + "fmt" "github.com/lucas-clemente/quic-go" "github.com/lunixbochs/struc" "github.com/prometheus/client_golang/prometheus" "github.com/tobyxdd/hysteria/pkg/acl" "net" - "time" ) -const dialTimeout = 10 * time.Second - type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) @@ -132,8 +130,18 @@ func (s *Server) handleClient(cs quic.Session) { // Auth & negotiate speed func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) { + // Check version + vb := make([]byte, 1) + _, err := stream.Read(vb) + if err != nil { + return nil, false, err + } + if vb[0] != protocolVersion { + return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion) + } + // Parse client hello var ch clientHello - err := struc.Unpack(stream, &ch) + err = struc.Unpack(stream, &ch) if err != nil { return nil, false, err }