Implemented UDP for both server & client

This commit is contained in:
Toby
2021-03-27 16:51:15 -07:00
parent 01c7d18211
commit 4bb5982960
8 changed files with 524 additions and 112 deletions

View File

@@ -1,6 +1,7 @@
package core
import (
"bytes"
"context"
"crypto/tls"
"errors"
@@ -14,7 +15,7 @@ import (
)
var (
ErrClosed = errors.New("client closed")
ErrClosed = errors.New("closed")
)
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
@@ -32,6 +33,9 @@ type Client struct {
quicSession quic.Session
reconnectMutex sync.Mutex
closed bool
udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]chan *udpMessage
}
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
@@ -90,6 +94,8 @@ func (c *Client) connectToServer() error {
return fmt.Errorf("auth error: %s", msg)
}
// All good
c.udpSessionMap = make(map[uint32]chan *udpMessage)
go c.handleMessage(qs)
c.quicSession = qs
return nil
}
@@ -119,34 +125,59 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool,
return true, sh.Message, nil
}
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) {
func (c *Client) handleMessage(qs quic.Session) {
for {
msg, err := qs.ReceiveMessage()
if err != nil {
break
}
var udpMsg udpMessage
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
if err != nil {
continue
}
c.udpSessionMutex.RLock()
ch, ok := c.udpSessionMap[udpMsg.SessionID]
if ok {
select {
case ch <- &udpMsg:
// OK
default:
// Silently drop the message when the channel is full
}
}
c.udpSessionMutex.RUnlock()
}
}
func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
if c.closed {
return nil, nil, nil, ErrClosed
return nil, nil, ErrClosed
}
stream, err := c.quicSession.OpenStream()
if err == nil {
// All good
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), nil
return c.quicSession, stream, nil
}
// Something is wrong
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just return
return nil, nil, nil, err
return nil, nil, err
}
// Permanent error, need to reconnect
if err := c.connectToServer(); err != nil {
// Still error, oops
return nil, nil, nil, err
return nil, nil, err
}
// We are not going to try again even if it still fails the second time
stream, err = c.quicSession.OpenStream()
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err
return c.quicSession, stream, nil
}
func (c *Client) DialTCP(addr string) (net.Conn, error) {
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
session, stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
@@ -172,11 +203,64 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
}
return &quicConn{
Orig: stream,
PseudoLocalAddr: localAddr,
PseudoRemoteAddr: remoteAddr,
PseudoLocalAddr: session.LocalAddr(),
PseudoRemoteAddr: session.RemoteAddr(),
}, nil
}
func (c *Client) DialUDP() (UDPConn, error) {
session, stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: true,
})
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
var sr serverResponse
err = struc.Unpack(stream, &sr)
if err != nil {
_ = stream.Close()
return nil, err
}
if !sr.OK {
_ = stream.Close()
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
}
// Create a session in the map
c.udpSessionMutex.Lock()
nCh := make(chan *udpMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
sessionMap := c.udpSessionMap
sessionMap[sr.UDPSessionID] = nCh
c.udpSessionMutex.Unlock()
pktConn := &quicPktConn{
Session: session,
Stream: stream,
CloseFunc: func() {
c.udpSessionMutex.Lock()
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
close(ch)
delete(sessionMap, sr.UDPSessionID)
}
c.udpSessionMutex.Unlock()
},
UDPSessionID: sr.UDPSessionID,
MsgCh: nCh,
}
go pktConn.Hold()
return pktConn, nil
}
func (c *Client) Close() error {
c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock()
@@ -222,3 +306,53 @@ func (w *quicConn) SetReadDeadline(t time.Time) error {
func (w *quicConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t)
}
type UDPConn interface {
ReadFrom() ([]byte, string, error)
WriteTo([]byte, string) error
Close() error
}
type quicPktConn struct {
Session quic.Session
Stream quic.Stream
CloseFunc func()
UDPSessionID uint32
MsgCh <-chan *udpMessage
}
func (c *quicPktConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.Stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
msg := <-c.MsgCh
if msg == nil {
// Closed
return nil, "", ErrClosed
}
return msg.Data, msg.Address, nil
}
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: c.UDPSessionID,
Address: addr,
Data: p,
})
return c.Session.SendMessage(msgBuf.Bytes())
}
func (c *quicPktConn) Close() error {
c.CloseFunc()
return c.Stream.Close()
}