mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-06-19 10:49:59 +00:00
Implemented UDP for both server & client
This commit is contained in:
parent
01c7d18211
commit
4bb5982960
@ -50,6 +50,7 @@ func client(config *clientConfig) {
|
||||
MaxStreamReceiveWindow: config.ReceiveWindowConn,
|
||||
MaxConnectionReceiveWindow: config.ReceiveWindow,
|
||||
KeepAlive: true,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
if quicConfig.MaxStreamReceiveWindow == 0 {
|
||||
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
|
@ -36,6 +36,7 @@ func server(config *serverConfig) {
|
||||
MaxConnectionReceiveWindow: config.ReceiveWindowClient,
|
||||
MaxIncomingStreams: int64(config.MaxConnClient),
|
||||
KeepAlive: true,
|
||||
EnableDatagrams: true,
|
||||
}
|
||||
if quicConfig.MaxStreamReceiveWindow == 0 {
|
||||
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||
@ -96,7 +97,8 @@ func server(config *serverConfig) {
|
||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||
func(refBPS uint64) congestion.CongestionControl {
|
||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||
}, aclEngine, obfuscator, authFunc, tcpRequestFunc, tcpErrorFunc)
|
||||
}, config.DisableUDP, aclEngine, obfuscator, authFunc,
|
||||
tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc)
|
||||
if err != nil {
|
||||
logrus.WithField("error", err).Fatal("Failed to initialize server")
|
||||
}
|
||||
@ -130,6 +132,28 @@ func tcpErrorFunc(addr net.Addr, auth []byte, reqAddr string, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func udpRequestFunc(addr net.Addr, auth []byte, sessionID uint32) {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": addr.String(),
|
||||
"session": sessionID,
|
||||
}).Debug("UDP request")
|
||||
}
|
||||
|
||||
func udpErrorFunc(addr net.Addr, auth []byte, sessionID uint32, err error) {
|
||||
if err != io.EOF {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": addr.String(),
|
||||
"session": sessionID,
|
||||
"error": err,
|
||||
}).Info("UDP error")
|
||||
} else {
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"src": addr.String(),
|
||||
"session": sessionID,
|
||||
}).Debug("UDP EOF")
|
||||
}
|
||||
}
|
||||
|
||||
func actionToString(action acl.Action, arg string) string {
|
||||
switch action {
|
||||
case acl.ActionDirect:
|
||||
|
@ -1,6 +1,7 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
@ -14,7 +15,7 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
ErrClosed = errors.New("client closed")
|
||||
ErrClosed = errors.New("closed")
|
||||
)
|
||||
|
||||
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
||||
@ -32,6 +33,9 @@ type Client struct {
|
||||
quicSession quic.Session
|
||||
reconnectMutex sync.Mutex
|
||||
closed bool
|
||||
|
||||
udpSessionMutex sync.RWMutex
|
||||
udpSessionMap map[uint32]chan *udpMessage
|
||||
}
|
||||
|
||||
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
@ -90,6 +94,8 @@ func (c *Client) connectToServer() error {
|
||||
return fmt.Errorf("auth error: %s", msg)
|
||||
}
|
||||
// All good
|
||||
c.udpSessionMap = make(map[uint32]chan *udpMessage)
|
||||
go c.handleMessage(qs)
|
||||
c.quicSession = qs
|
||||
return nil
|
||||
}
|
||||
@ -119,34 +125,59 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool,
|
||||
return true, sh.Message, nil
|
||||
}
|
||||
|
||||
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) {
|
||||
func (c *Client) handleMessage(qs quic.Session) {
|
||||
for {
|
||||
msg, err := qs.ReceiveMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
var udpMsg udpMessage
|
||||
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
c.udpSessionMutex.RLock()
|
||||
ch, ok := c.udpSessionMap[udpMsg.SessionID]
|
||||
if ok {
|
||||
select {
|
||||
case ch <- &udpMsg:
|
||||
// OK
|
||||
default:
|
||||
// Silently drop the message when the channel is full
|
||||
}
|
||||
}
|
||||
c.udpSessionMutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
if c.closed {
|
||||
return nil, nil, nil, ErrClosed
|
||||
return nil, nil, ErrClosed
|
||||
}
|
||||
stream, err := c.quicSession.OpenStream()
|
||||
if err == nil {
|
||||
// All good
|
||||
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), nil
|
||||
return c.quicSession, stream, nil
|
||||
}
|
||||
// Something is wrong
|
||||
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
||||
// Temporary error, just return
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
// Permanent error, need to reconnect
|
||||
if err := c.connectToServer(); err != nil {
|
||||
// Still error, oops
|
||||
return nil, nil, nil, err
|
||||
return nil, nil, err
|
||||
}
|
||||
// We are not going to try again even if it still fails the second time
|
||||
stream, err = c.quicSession.OpenStream()
|
||||
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err
|
||||
return c.quicSession, stream, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
||||
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
|
||||
session, stream, err := c.openStreamWithReconnect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -172,11 +203,64 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
||||
}
|
||||
return &quicConn{
|
||||
Orig: stream,
|
||||
PseudoLocalAddr: localAddr,
|
||||
PseudoRemoteAddr: remoteAddr,
|
||||
PseudoLocalAddr: session.LocalAddr(),
|
||||
PseudoRemoteAddr: session.RemoteAddr(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (c *Client) DialUDP() (UDPConn, error) {
|
||||
session, stream, err := c.openStreamWithReconnect()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Send request
|
||||
err = struc.Pack(stream, &clientRequest{
|
||||
UDP: true,
|
||||
})
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
// Read response
|
||||
var sr serverResponse
|
||||
err = struc.Unpack(stream, &sr)
|
||||
if err != nil {
|
||||
_ = stream.Close()
|
||||
return nil, err
|
||||
}
|
||||
if !sr.OK {
|
||||
_ = stream.Close()
|
||||
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||
}
|
||||
|
||||
// Create a session in the map
|
||||
c.udpSessionMutex.Lock()
|
||||
nCh := make(chan *udpMessage, 1024)
|
||||
// Store the current session map for CloseFunc below
|
||||
// to ensures that we are adding and removing sessions on the same map,
|
||||
// as reconnecting will reassign the map
|
||||
sessionMap := c.udpSessionMap
|
||||
sessionMap[sr.UDPSessionID] = nCh
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
pktConn := &quicPktConn{
|
||||
Session: session,
|
||||
Stream: stream,
|
||||
CloseFunc: func() {
|
||||
c.udpSessionMutex.Lock()
|
||||
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
|
||||
close(ch)
|
||||
delete(sessionMap, sr.UDPSessionID)
|
||||
}
|
||||
c.udpSessionMutex.Unlock()
|
||||
},
|
||||
UDPSessionID: sr.UDPSessionID,
|
||||
MsgCh: nCh,
|
||||
}
|
||||
go pktConn.Hold()
|
||||
return pktConn, nil
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
c.reconnectMutex.Lock()
|
||||
defer c.reconnectMutex.Unlock()
|
||||
@ -222,3 +306,53 @@ func (w *quicConn) SetReadDeadline(t time.Time) error {
|
||||
func (w *quicConn) SetWriteDeadline(t time.Time) error {
|
||||
return w.Orig.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
type UDPConn interface {
|
||||
ReadFrom() ([]byte, string, error)
|
||||
WriteTo([]byte, string) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
type quicPktConn struct {
|
||||
Session quic.Session
|
||||
Stream quic.Stream
|
||||
CloseFunc func()
|
||||
UDPSessionID uint32
|
||||
MsgCh <-chan *udpMessage
|
||||
}
|
||||
|
||||
func (c *quicPktConn) Hold() {
|
||||
// Hold the stream until it's closed
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err := c.Stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
_ = c.Close()
|
||||
}
|
||||
|
||||
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
|
||||
msg := <-c.MsgCh
|
||||
if msg == nil {
|
||||
// Closed
|
||||
return nil, "", ErrClosed
|
||||
}
|
||||
return msg.Data, msg.Address, nil
|
||||
}
|
||||
|
||||
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
|
||||
var msgBuf bytes.Buffer
|
||||
_ = struc.Pack(&msgBuf, &udpMessage{
|
||||
SessionID: c.UDPSessionID,
|
||||
Address: addr,
|
||||
Data: p,
|
||||
})
|
||||
return c.Session.SendMessage(msgBuf.Bytes())
|
||||
}
|
||||
|
||||
func (c *quicPktConn) Close() error {
|
||||
c.CloseFunc()
|
||||
return c.Stream.Close()
|
||||
}
|
||||
|
58
pkg/core/client_udp_test.go
Normal file
58
pkg/core/client_udp_test.go
Normal file
@ -0,0 +1,58 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClientUDP(t *testing.T) {
|
||||
client, err := NewClient("toby.moe:36713", nil, &tls.Config{
|
||||
NextProtos: []string{"hysteria"},
|
||||
MinVersion: tls.VersionTLS13,
|
||||
}, &quic.Config{
|
||||
EnableDatagrams: true,
|
||||
}, 125000, 125000, nil, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
conn, err := client.DialUDP()
|
||||
if err != nil {
|
||||
t.Fatal("conn DialUDP", err)
|
||||
}
|
||||
t.Run("8.8.8.8", func(t *testing.T) {
|
||||
dnsReq := []byte{0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
|
||||
err := conn.WriteTo(dnsReq, "8.8.8.8:53")
|
||||
if err != nil {
|
||||
t.Error("WriteTo", err)
|
||||
}
|
||||
buf, _, err := conn.ReadFrom()
|
||||
if err != nil {
|
||||
t.Error("ReadFrom", err)
|
||||
}
|
||||
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
|
||||
t.Error("invalid response")
|
||||
}
|
||||
})
|
||||
t.Run("1.1.1.1", func(t *testing.T) {
|
||||
dnsReq := []byte{0x66, 0x77, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
|
||||
err := conn.WriteTo(dnsReq, "1.1.1.1:53")
|
||||
if err != nil {
|
||||
t.Error("WriteTo", err)
|
||||
}
|
||||
buf, _, err := conn.ReadFrom()
|
||||
if err != nil {
|
||||
t.Error("ReadFrom", err)
|
||||
}
|
||||
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
|
||||
t.Error("invalid response")
|
||||
}
|
||||
})
|
||||
t.Run("Close", func(t *testing.T) {
|
||||
_ = conn.Close()
|
||||
_, _, err := conn.ReadFrom()
|
||||
if err != ErrClosed {
|
||||
t.Error("closed conn not returning the correct error")
|
||||
}
|
||||
})
|
||||
}
|
@ -43,3 +43,11 @@ type serverResponse struct {
|
||||
MessageLen uint16 `struc:"sizeof=Message"`
|
||||
Message string
|
||||
}
|
||||
|
||||
type udpMessage struct {
|
||||
SessionID uint32
|
||||
AddressLen uint16 `struc:"sizeof=Address"`
|
||||
Address string
|
||||
DataLen uint16 `struc:"sizeof=Data"`
|
||||
Data []byte
|
||||
}
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lunixbochs/struc"
|
||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
@ -17,22 +16,28 @@ const dialTimeout = 10 * time.Second
|
||||
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
||||
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
|
||||
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
|
||||
type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)
|
||||
type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)
|
||||
|
||||
type Server struct {
|
||||
sendBPS, recvBPS uint64
|
||||
congestionFactory CongestionFactory
|
||||
disableUDP bool
|
||||
aclEngine *acl.Engine
|
||||
|
||||
authFunc AuthFunc
|
||||
tcpRequestFunc TCPRequestFunc
|
||||
tcpErrorFunc TCPErrorFunc
|
||||
udpRequestFunc UDPRequestFunc
|
||||
udpErrorFunc UDPErrorFunc
|
||||
|
||||
listener quic.Listener
|
||||
}
|
||||
|
||||
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine,
|
||||
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc) (*Server, error) {
|
||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
|
||||
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
|
||||
udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc) (*Server, error) {
|
||||
packetConn, err := net.ListenPacket("udp", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -53,10 +58,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||
sendBPS: sendBPS,
|
||||
recvBPS: recvBPS,
|
||||
congestionFactory: congestionFactory,
|
||||
disableUDP: disableUDP,
|
||||
aclEngine: aclEngine,
|
||||
authFunc: authFunc,
|
||||
tcpRequestFunc: tcpRequestFunc,
|
||||
tcpErrorFunc: tcpErrorFunc,
|
||||
udpRequestFunc: udpRequestFunc,
|
||||
udpErrorFunc: udpErrorFunc,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
@ -94,14 +102,10 @@ func (s *Server) handleClient(cs quic.Session) {
|
||||
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
|
||||
return
|
||||
}
|
||||
// Start accepting streams
|
||||
for {
|
||||
stream, err := cs.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
go s.handleStream(cs.RemoteAddr(), auth, stream)
|
||||
}
|
||||
// Start accepting streams and messages
|
||||
sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine,
|
||||
s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc)
|
||||
sc.Run()
|
||||
_ = cs.CloseWithError(closeErrorCodeGeneric, "")
|
||||
}
|
||||
|
||||
@ -143,89 +147,3 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt
|
||||
}
|
||||
return ch.Auth, ok, nil
|
||||
}
|
||||
|
||||
func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) {
|
||||
defer stream.Close()
|
||||
// Read request
|
||||
var req clientRequest
|
||||
err := struc.Unpack(stream, &req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !req.UDP {
|
||||
// TCP connection
|
||||
s.handleTCP(remoteAddr, auth, stream, req.Address)
|
||||
} else {
|
||||
// UDP connection
|
||||
// TODO
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleTCP(remoteAddr net.Addr, auth []byte, stream quic.Stream, reqAddr string) {
|
||||
host, port, err := net.SplitHostPort(reqAddr)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "invalid address",
|
||||
})
|
||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
// IP request, clear host for ACL engine
|
||||
host = ""
|
||||
}
|
||||
action, arg := acl.ActionDirect, ""
|
||||
if s.aclEngine != nil {
|
||||
action, arg = s.aclEngine.Lookup(host, ip)
|
||||
}
|
||||
s.tcpRequestFunc(remoteAddr, auth, reqAddr, action, arg)
|
||||
|
||||
var conn net.Conn // Connection to be piped
|
||||
switch action {
|
||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
case acl.ActionBlock:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "blocked by ACL",
|
||||
})
|
||||
return
|
||||
case acl.ActionHijack:
|
||||
hijackAddr := net.JoinHostPort(arg, port)
|
||||
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "ACL error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// So far so good if we reach here
|
||||
defer conn.Close()
|
||||
err = struc.Pack(stream, &serverResponse{
|
||||
OK: true,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = utils.Pipe2Way(stream, conn)
|
||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
||||
}
|
||||
|
269
pkg/core/server_client.go
Normal file
269
pkg/core/server_client.go
Normal file
@ -0,0 +1,269 @@
|
||||
package core
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"github.com/lucas-clemente/quic-go"
|
||||
"github.com/lunixbochs/struc"
|
||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||
"net"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const udpBufferSize = 65535
|
||||
|
||||
type serverClient struct {
|
||||
CS quic.Session
|
||||
Auth []byte
|
||||
ClientAddr net.Addr
|
||||
DisableUDP bool
|
||||
ACLEngine *acl.Engine
|
||||
CTCPRequestFunc TCPRequestFunc
|
||||
CTCPErrorFunc TCPErrorFunc
|
||||
CUDPRequestFunc UDPRequestFunc
|
||||
CUDPErrorFunc UDPErrorFunc
|
||||
|
||||
udpSessionMutex sync.RWMutex
|
||||
udpSessionMap map[uint32]*net.UDPConn
|
||||
nextUDPSessionID uint32
|
||||
}
|
||||
|
||||
func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
|
||||
CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc,
|
||||
CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc) *serverClient {
|
||||
return &serverClient{
|
||||
CS: cs,
|
||||
Auth: auth,
|
||||
ClientAddr: cs.RemoteAddr(),
|
||||
DisableUDP: disableUDP,
|
||||
ACLEngine: ACLEngine,
|
||||
CTCPRequestFunc: CTCPRequestFunc,
|
||||
CTCPErrorFunc: CTCPErrorFunc,
|
||||
CUDPRequestFunc: CUDPRequestFunc,
|
||||
CUDPErrorFunc: CUDPErrorFunc,
|
||||
udpSessionMap: make(map[uint32]*net.UDPConn),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverClient) Run() {
|
||||
if !c.DisableUDP {
|
||||
go func() {
|
||||
for {
|
||||
msg, err := c.CS.ReceiveMessage()
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
c.handleMessage(msg)
|
||||
}
|
||||
}()
|
||||
}
|
||||
for {
|
||||
stream, err := c.CS.AcceptStream(context.Background())
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
go c.handleStream(stream)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverClient) handleStream(stream quic.Stream) {
|
||||
defer stream.Close()
|
||||
// Read request
|
||||
var req clientRequest
|
||||
err := struc.Unpack(stream, &req)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if !req.UDP {
|
||||
// TCP connection
|
||||
c.handleTCP(stream, req.Address)
|
||||
} else if !c.DisableUDP {
|
||||
// UDP connection
|
||||
c.handleUDP(stream)
|
||||
} else {
|
||||
// UDP disabled
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "UDP disabled",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverClient) handleMessage(msg []byte) {
|
||||
var udpMsg udpMessage
|
||||
err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.udpSessionMutex.RLock()
|
||||
conn, ok := c.udpSessionMap[udpMsg.SessionID]
|
||||
c.udpSessionMutex.RUnlock()
|
||||
if ok {
|
||||
// Session found, send the message
|
||||
host, port, err := net.SplitHostPort(udpMsg.Address)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
action, arg := acl.ActionDirect, ""
|
||||
if c.ACLEngine != nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
// IP request, clear host for ACL engine
|
||||
host = ""
|
||||
}
|
||||
action, arg = c.ACLEngine.Lookup(host, ip)
|
||||
}
|
||||
switch action {
|
||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||
addr, err := net.ResolveUDPAddr("udp", udpMsg.Address)
|
||||
if err == nil {
|
||||
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
|
||||
}
|
||||
case acl.ActionBlock:
|
||||
// Do nothing
|
||||
case acl.ActionHijack:
|
||||
hijackAddr := net.JoinHostPort(arg, port)
|
||||
addr, err := net.ResolveUDPAddr("udp", hijackAddr)
|
||||
if err == nil {
|
||||
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
|
||||
}
|
||||
default:
|
||||
// Do nothing
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
|
||||
host, port, err := net.SplitHostPort(reqAddr)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "invalid address",
|
||||
})
|
||||
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
action, arg := acl.ActionDirect, ""
|
||||
if c.ACLEngine != nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
// IP request, clear host for ACL engine
|
||||
host = ""
|
||||
}
|
||||
action, arg = c.ACLEngine.Lookup(host, ip)
|
||||
}
|
||||
c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg)
|
||||
|
||||
var conn net.Conn // Connection to be piped
|
||||
switch action {
|
||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
case acl.ActionBlock:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "blocked by ACL",
|
||||
})
|
||||
return
|
||||
case acl.ActionHijack:
|
||||
hijackAddr := net.JoinHostPort(arg, port)
|
||||
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: err.Error(),
|
||||
})
|
||||
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||
return
|
||||
}
|
||||
default:
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "ACL error",
|
||||
})
|
||||
return
|
||||
}
|
||||
// So far so good if we reach here
|
||||
defer conn.Close()
|
||||
err = struc.Pack(stream, &serverResponse{
|
||||
OK: true,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = utils.Pipe2Way(stream, conn)
|
||||
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||
}
|
||||
|
||||
func (c *serverClient) handleUDP(stream quic.Stream) {
|
||||
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
|
||||
conn, err := net.ListenUDP("udp", nil)
|
||||
if err != nil {
|
||||
_ = struc.Pack(stream, &serverResponse{
|
||||
OK: false,
|
||||
Message: "UDP initialization failed",
|
||||
})
|
||||
c.CUDPErrorFunc(c.ClientAddr, c.Auth, 0, err)
|
||||
return
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
var id uint32
|
||||
c.udpSessionMutex.Lock()
|
||||
id = c.nextUDPSessionID
|
||||
c.udpSessionMap[id] = conn
|
||||
c.nextUDPSessionID += 1
|
||||
c.udpSessionMutex.Unlock()
|
||||
|
||||
err = struc.Pack(stream, &serverResponse{
|
||||
OK: true,
|
||||
UDPSessionID: id,
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
c.CUDPRequestFunc(c.ClientAddr, c.Auth, id)
|
||||
|
||||
// Receive UDP packets, send them to the client
|
||||
go func() {
|
||||
buf := make([]byte, udpBufferSize)
|
||||
for {
|
||||
n, rAddr, err := conn.ReadFromUDP(buf)
|
||||
if n > 0 {
|
||||
var msgBuf bytes.Buffer
|
||||
_ = struc.Pack(&msgBuf, &udpMessage{
|
||||
SessionID: id,
|
||||
Address: rAddr.String(),
|
||||
Data: buf[:n],
|
||||
})
|
||||
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||
}
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Hold the stream until it's closed by the client
|
||||
buf := make([]byte, 1024)
|
||||
for {
|
||||
_, err = stream.Read(buf)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
c.CUDPErrorFunc(c.ClientAddr, c.Auth, id, err)
|
||||
|
||||
// Remove the session
|
||||
c.udpSessionMutex.Lock()
|
||||
delete(c.udpSessionMap, id)
|
||||
c.udpSessionMutex.Unlock()
|
||||
}
|
@ -27,13 +27,13 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
host = ""
|
||||
}
|
||||
// ACL
|
||||
action, arg := acl.ActionProxy, ""
|
||||
if aclEngine != nil {
|
||||
ip := net.ParseIP(host)
|
||||
if ip != nil {
|
||||
host = ""
|
||||
}
|
||||
action, arg = aclEngine.Lookup(host, ip)
|
||||
}
|
||||
newDialFunc(addr, action, arg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user