wip: core client & server rework

This commit is contained in:
Toby
2022-10-22 11:45:46 -07:00
parent ca3de154ba
commit e3c3088596
19 changed files with 368 additions and 304 deletions

View File

@@ -12,11 +12,11 @@ import (
"sync"
"time"
"github.com/HyNetwork/hysteria/pkg/transport/pktconns"
"github.com/HyNetwork/hysteria/pkg/congestion"
"github.com/HyNetwork/hysteria/pkg/obfs"
"github.com/HyNetwork/hysteria/pkg/pmtud_fix"
"github.com/HyNetwork/hysteria/pkg/transport"
"github.com/HyNetwork/hysteria/pkg/utils"
"github.com/lucas-clemente/quic-go"
"github.com/lunixbochs/struc"
@@ -25,18 +25,18 @@ import (
var ErrClosed = errors.New("closed")
type Client struct {
transport *transport.ClientTransport
serverAddr string
protocol string
sendBPS, recvBPS uint64
auth []byte
obfuscator obfs.Obfuscator
tlsConfig *tls.Config
quicConfig *quic.Config
quicSession quic.Connection
pktConnFunc pktconns.ClientPacketConnFunc
reconnectMutex sync.Mutex
pktConn net.PacketConn
quicConn quic.Connection
closed bool
udpSessionMutex sync.RWMutex
@@ -46,59 +46,78 @@ type Client struct {
quicReconnectFunc func(err error)
}
func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
transport *transport.ClientTransport, sendBPS uint64, recvBPS uint64, obfuscator obfs.Obfuscator,
quicReconnectFunc func(err error),
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
pktConnFunc pktconns.ClientPacketConnFunc, sendBPS uint64, recvBPS uint64, quicReconnectFunc func(err error),
) (*Client, error) {
quicConfig.DisablePathMTUDiscovery = quicConfig.DisablePathMTUDiscovery || pmtud_fix.DisablePathMTUDiscovery
c := &Client{
transport: transport,
serverAddr: serverAddr,
protocol: protocol,
sendBPS: sendBPS,
recvBPS: recvBPS,
auth: auth,
obfuscator: obfuscator,
tlsConfig: tlsConfig,
quicConfig: quicConfig,
pktConnFunc: pktConnFunc,
quicReconnectFunc: quicReconnectFunc,
}
if err := c.connectToServer(); err != nil {
if err := c.connect(); err != nil {
return nil, err
}
return c, nil
}
func (c *Client) connectToServer() error {
qs, err := c.transport.QUICDial(c.protocol, c.serverAddr, c.tlsConfig, c.quicConfig, c.obfuscator)
func (c *Client) connect() error {
// Clear previous connection
if c.quicConn != nil {
_ = c.quicConn.CloseWithError(0, "")
}
if c.pktConn != nil {
_ = c.pktConn.Close()
}
// New connection
pktConn, err := c.pktConnFunc(c.serverAddr)
if err != nil {
return err
}
serverUDPAddr, err := net.ResolveUDPAddr("udp", c.serverAddr)
if err != nil {
_ = pktConn.Close()
return err
}
quicConn, err := quic.Dial(pktConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig)
if err != nil {
_ = pktConn.Close()
return err
}
// Control stream
ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout)
stream, err := qs.OpenStreamSync(ctx)
stream, err := quicConn.OpenStreamSync(ctx)
ctxCancel()
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
_ = quicConn.CloseWithError(closeErrorCodeProtocol, "protocol error")
_ = pktConn.Close()
return err
}
ok, msg, err := c.handleControlStream(qs, stream)
ok, msg, err := c.handleControlStream(quicConn, stream)
if err != nil {
_ = qs.CloseWithError(closeErrorCodeProtocol, "protocol error")
_ = quicConn.CloseWithError(closeErrorCodeProtocol, "protocol error")
_ = pktConn.Close()
return err
}
if !ok {
_ = qs.CloseWithError(closeErrorCodeAuth, "auth error")
_ = quicConn.CloseWithError(closeErrorCodeAuth, "auth error")
_ = pktConn.Close()
return fmt.Errorf("auth error: %s", msg)
}
// All good
c.udpSessionMap = make(map[uint32]chan *udpMessage)
go c.handleMessage(qs)
c.quicSession = qs
go c.handleMessage(quicConn)
c.pktConn = pktConn
c.quicConn = quicConn
return nil
}
func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bool, string, error) {
func (c *Client) handleControlStream(qc quic.Connection, stream quic.Stream) (bool, string, error) {
// Send protocol version
_, err := stream.Write([]byte{protocolVersion})
if err != nil {
@@ -123,14 +142,14 @@ func (c *Client) handleControlStream(qs quic.Connection, stream quic.Stream) (bo
}
// Set the congestion accordingly
if sh.OK {
qs.SetCongestionControl(congestion.NewBrutalSender(sh.Rate.RecvBPS))
qc.SetCongestionControl(congestion.NewBrutalSender(sh.Rate.RecvBPS))
}
return sh.OK, sh.Message, nil
}
func (c *Client) handleMessage(qs quic.Connection) {
func (c *Client) handleMessage(qc quic.Connection) {
for {
msg, err := qs.ReceiveMessage()
msg, err := qc.ReceiveMessage()
if err != nil {
break
}
@@ -163,10 +182,10 @@ func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error)
if c.closed {
return nil, nil, ErrClosed
}
stream, err := c.quicSession.OpenStream()
stream, err := c.quicConn.OpenStream()
if err == nil {
// All good
return c.quicSession, &wrappedQUICStream{stream}, nil
return c.quicConn, &wrappedQUICStream{stream}, nil
}
// Something is wrong
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
@@ -175,13 +194,13 @@ func (c *Client) openStreamWithReconnect() (quic.Connection, quic.Stream, error)
}
c.quicReconnectFunc(err)
// Permanent error, need to reconnect
if err := c.connectToServer(); err != nil {
if err := c.connect(); err != nil {
// Still error, oops
return nil, nil, err
}
// We are not going to try again even if it still fails the second time
stream, err = c.quicSession.OpenStream()
return c.quicSession, &wrappedQUICStream{stream}, err
stream, err = c.quicConn.OpenStream()
return c.quicConn, &wrappedQUICStream{stream}, err
}
func (c *Client) DialTCP(addr string) (net.Conn, error) {
@@ -250,7 +269,7 @@ func (c *Client) DialUDP() (UDPConn, error) {
c.udpSessionMutex.Lock()
nCh := make(chan *udpMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// to ensure that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
sessionMap := c.udpSessionMap
sessionMap[sr.UDPSessionID] = nCh
@@ -277,7 +296,8 @@ func (c *Client) DialUDP() (UDPConn, error) {
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
err := c.quicSession.CloseWithError(closeErrorCodeGeneric, "")
err := c.quicConn.CloseWithError(closeErrorCodeGeneric, "")
_ = c.pktConn.Close()
c.closed = true
return err
}