Protocol version check

This commit is contained in:
Toby 2021-04-18 18:07:01 -07:00
parent fc4d573f3d
commit 7b841aa203
3 changed files with 19 additions and 6 deletions

View File

@ -106,8 +106,13 @@ func (c *Client) connectToServer() error {
} }
func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) { func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool, string, error) {
// Send protocol version
_, err := stream.Write([]byte{protocolVersion})
if err != nil {
return false, "", err
}
// Send client hello // Send client hello
err := struc.Pack(stream, &clientHello{ err = struc.Pack(stream, &clientHello{
Rate: transmissionRate{ Rate: transmissionRate{
SendBPS: c.sendBPS, SendBPS: c.sendBPS,
RecvBPS: c.recvBPS, RecvBPS: c.recvBPS,

View File

@ -5,7 +5,7 @@ import (
) )
const ( const (
protocolVersion = uint8(1) protocolVersion = uint8(2)
protocolTimeout = 10 * time.Second protocolTimeout = 10 * time.Second
closeErrorCodeGeneric = 0 closeErrorCodeGeneric = 0

View File

@ -4,16 +4,14 @@ import (
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lunixbochs/struc" "github.com/lunixbochs/struc"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
"github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/acl"
"net" "net"
"time"
) )
const dialTimeout = 10 * time.Second
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
@ -132,8 +130,18 @@ func (s *Server) handleClient(cs quic.Session) {
// Auth & negotiate speed // Auth & negotiate speed
func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) { func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byte, bool, error) {
// Check version
vb := make([]byte, 1)
_, err := stream.Read(vb)
if err != nil {
return nil, false, err
}
if vb[0] != protocolVersion {
return nil, false, fmt.Errorf("unsupported protocol version %d, expecting %d", vb[0], protocolVersion)
}
// Parse client hello
var ch clientHello var ch clientHello
err := struc.Unpack(stream, &ch) err = struc.Unpack(stream, &ch)
if err != nil { if err != nil {
return nil, false, err return nil, false, err
} }