Implemented UDP for both server & client

This commit is contained in:
Toby 2021-03-27 16:51:15 -07:00
parent 01c7d18211
commit 4bb5982960
8 changed files with 524 additions and 112 deletions

View File

@ -50,6 +50,7 @@ func client(config *clientConfig) {
MaxStreamReceiveWindow: config.ReceiveWindowConn, MaxStreamReceiveWindow: config.ReceiveWindowConn,
MaxConnectionReceiveWindow: config.ReceiveWindow, MaxConnectionReceiveWindow: config.ReceiveWindow,
KeepAlive: true, KeepAlive: true,
EnableDatagrams: true,
} }
if quicConfig.MaxStreamReceiveWindow == 0 { if quicConfig.MaxStreamReceiveWindow == 0 {
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow

View File

@ -36,6 +36,7 @@ func server(config *serverConfig) {
MaxConnectionReceiveWindow: config.ReceiveWindowClient, MaxConnectionReceiveWindow: config.ReceiveWindowClient,
MaxIncomingStreams: int64(config.MaxConnClient), MaxIncomingStreams: int64(config.MaxConnClient),
KeepAlive: true, KeepAlive: true,
EnableDatagrams: true,
} }
if quicConfig.MaxStreamReceiveWindow == 0 { if quicConfig.MaxStreamReceiveWindow == 0 {
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
@ -96,7 +97,8 @@ func server(config *serverConfig) {
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
func(refBPS uint64) congestion.CongestionControl { func(refBPS uint64) congestion.CongestionControl {
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
}, aclEngine, obfuscator, authFunc, tcpRequestFunc, tcpErrorFunc) }, config.DisableUDP, aclEngine, obfuscator, authFunc,
tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc)
if err != nil { if err != nil {
logrus.WithField("error", err).Fatal("Failed to initialize server") logrus.WithField("error", err).Fatal("Failed to initialize server")
} }
@ -130,6 +132,28 @@ func tcpErrorFunc(addr net.Addr, auth []byte, reqAddr string, err error) {
} }
} }
func udpRequestFunc(addr net.Addr, auth []byte, sessionID uint32) {
logrus.WithFields(logrus.Fields{
"src": addr.String(),
"session": sessionID,
}).Debug("UDP request")
}
func udpErrorFunc(addr net.Addr, auth []byte, sessionID uint32, err error) {
if err != io.EOF {
logrus.WithFields(logrus.Fields{
"src": addr.String(),
"session": sessionID,
"error": err,
}).Info("UDP error")
} else {
logrus.WithFields(logrus.Fields{
"src": addr.String(),
"session": sessionID,
}).Debug("UDP EOF")
}
}
func actionToString(action acl.Action, arg string) string { func actionToString(action acl.Action, arg string) string {
switch action { switch action {
case acl.ActionDirect: case acl.ActionDirect:

View File

@ -1,6 +1,7 @@
package core package core
import ( import (
"bytes"
"context" "context"
"crypto/tls" "crypto/tls"
"errors" "errors"
@ -14,7 +15,7 @@ import (
) )
var ( var (
ErrClosed = errors.New("client closed") ErrClosed = errors.New("closed")
) )
type CongestionFactory func(refBPS uint64) congestion.CongestionControl type CongestionFactory func(refBPS uint64) congestion.CongestionControl
@ -32,6 +33,9 @@ type Client struct {
quicSession quic.Session quicSession quic.Session
reconnectMutex sync.Mutex reconnectMutex sync.Mutex
closed bool closed bool
udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]chan *udpMessage
} }
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
@ -90,6 +94,8 @@ func (c *Client) connectToServer() error {
return fmt.Errorf("auth error: %s", msg) return fmt.Errorf("auth error: %s", msg)
} }
// All good // All good
c.udpSessionMap = make(map[uint32]chan *udpMessage)
go c.handleMessage(qs)
c.quicSession = qs c.quicSession = qs
return nil return nil
} }
@ -119,34 +125,59 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool,
return true, sh.Message, nil return true, sh.Message, nil
} }
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) { func (c *Client) handleMessage(qs quic.Session) {
for {
msg, err := qs.ReceiveMessage()
if err != nil {
break
}
var udpMsg udpMessage
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
if err != nil {
continue
}
c.udpSessionMutex.RLock()
ch, ok := c.udpSessionMap[udpMsg.SessionID]
if ok {
select {
case ch <- &udpMsg:
// OK
default:
// Silently drop the message when the channel is full
}
}
c.udpSessionMutex.RUnlock()
}
}
func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) {
c.reconnectMutex.Lock() c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock() defer c.reconnectMutex.Unlock()
if c.closed { if c.closed {
return nil, nil, nil, ErrClosed return nil, nil, ErrClosed
} }
stream, err := c.quicSession.OpenStream() stream, err := c.quicSession.OpenStream()
if err == nil { if err == nil {
// All good // All good
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), nil return c.quicSession, stream, nil
} }
// Something is wrong // Something is wrong
if nErr, ok := err.(net.Error); ok && nErr.Temporary() { if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
// Temporary error, just return // Temporary error, just return
return nil, nil, nil, err return nil, nil, err
} }
// Permanent error, need to reconnect // Permanent error, need to reconnect
if err := c.connectToServer(); err != nil { if err := c.connectToServer(); err != nil {
// Still error, oops // Still error, oops
return nil, nil, nil, err return nil, nil, err
} }
// We are not going to try again even if it still fails the second time // We are not going to try again even if it still fails the second time
stream, err = c.quicSession.OpenStream() stream, err = c.quicSession.OpenStream()
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err return c.quicSession, stream, nil
} }
func (c *Client) DialTCP(addr string) (net.Conn, error) { func (c *Client) DialTCP(addr string) (net.Conn, error) {
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect() session, stream, err := c.openStreamWithReconnect()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -172,11 +203,64 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
} }
return &quicConn{ return &quicConn{
Orig: stream, Orig: stream,
PseudoLocalAddr: localAddr, PseudoLocalAddr: session.LocalAddr(),
PseudoRemoteAddr: remoteAddr, PseudoRemoteAddr: session.RemoteAddr(),
}, nil }, nil
} }
func (c *Client) DialUDP() (UDPConn, error) {
session, stream, err := c.openStreamWithReconnect()
if err != nil {
return nil, err
}
// Send request
err = struc.Pack(stream, &clientRequest{
UDP: true,
})
if err != nil {
_ = stream.Close()
return nil, err
}
// Read response
var sr serverResponse
err = struc.Unpack(stream, &sr)
if err != nil {
_ = stream.Close()
return nil, err
}
if !sr.OK {
_ = stream.Close()
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
}
// Create a session in the map
c.udpSessionMutex.Lock()
nCh := make(chan *udpMessage, 1024)
// Store the current session map for CloseFunc below
// to ensures that we are adding and removing sessions on the same map,
// as reconnecting will reassign the map
sessionMap := c.udpSessionMap
sessionMap[sr.UDPSessionID] = nCh
c.udpSessionMutex.Unlock()
pktConn := &quicPktConn{
Session: session,
Stream: stream,
CloseFunc: func() {
c.udpSessionMutex.Lock()
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
close(ch)
delete(sessionMap, sr.UDPSessionID)
}
c.udpSessionMutex.Unlock()
},
UDPSessionID: sr.UDPSessionID,
MsgCh: nCh,
}
go pktConn.Hold()
return pktConn, nil
}
func (c *Client) Close() error { func (c *Client) Close() error {
c.reconnectMutex.Lock() c.reconnectMutex.Lock()
defer c.reconnectMutex.Unlock() defer c.reconnectMutex.Unlock()
@ -222,3 +306,53 @@ func (w *quicConn) SetReadDeadline(t time.Time) error {
func (w *quicConn) SetWriteDeadline(t time.Time) error { func (w *quicConn) SetWriteDeadline(t time.Time) error {
return w.Orig.SetWriteDeadline(t) return w.Orig.SetWriteDeadline(t)
} }
type UDPConn interface {
ReadFrom() ([]byte, string, error)
WriteTo([]byte, string) error
Close() error
}
type quicPktConn struct {
Session quic.Session
Stream quic.Stream
CloseFunc func()
UDPSessionID uint32
MsgCh <-chan *udpMessage
}
func (c *quicPktConn) Hold() {
// Hold the stream until it's closed
buf := make([]byte, 1024)
for {
_, err := c.Stream.Read(buf)
if err != nil {
break
}
}
_ = c.Close()
}
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
msg := <-c.MsgCh
if msg == nil {
// Closed
return nil, "", ErrClosed
}
return msg.Data, msg.Address, nil
}
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: c.UDPSessionID,
Address: addr,
Data: p,
})
return c.Session.SendMessage(msgBuf.Bytes())
}
func (c *quicPktConn) Close() error {
c.CloseFunc()
return c.Stream.Close()
}

View File

@ -0,0 +1,58 @@
package core
import (
"crypto/tls"
"github.com/lucas-clemente/quic-go"
"testing"
)
func TestClientUDP(t *testing.T) {
client, err := NewClient("toby.moe:36713", nil, &tls.Config{
NextProtos: []string{"hysteria"},
MinVersion: tls.VersionTLS13,
}, &quic.Config{
EnableDatagrams: true,
}, 125000, 125000, nil, nil)
if err != nil {
t.Fatal(err)
}
conn, err := client.DialUDP()
if err != nil {
t.Fatal("conn DialUDP", err)
}
t.Run("8.8.8.8", func(t *testing.T) {
dnsReq := []byte{0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
err := conn.WriteTo(dnsReq, "8.8.8.8:53")
if err != nil {
t.Error("WriteTo", err)
}
buf, _, err := conn.ReadFrom()
if err != nil {
t.Error("ReadFrom", err)
}
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
t.Error("invalid response")
}
})
t.Run("1.1.1.1", func(t *testing.T) {
dnsReq := []byte{0x66, 0x77, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
err := conn.WriteTo(dnsReq, "1.1.1.1:53")
if err != nil {
t.Error("WriteTo", err)
}
buf, _, err := conn.ReadFrom()
if err != nil {
t.Error("ReadFrom", err)
}
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
t.Error("invalid response")
}
})
t.Run("Close", func(t *testing.T) {
_ = conn.Close()
_, _, err := conn.ReadFrom()
if err != ErrClosed {
t.Error("closed conn not returning the correct error")
}
})
}

View File

@ -43,3 +43,11 @@ type serverResponse struct {
MessageLen uint16 `struc:"sizeof=Message"` MessageLen uint16 `struc:"sizeof=Message"`
Message string Message string
} }
type udpMessage struct {
SessionID uint32
AddressLen uint16 `struc:"sizeof=Address"`
Address string
DataLen uint16 `struc:"sizeof=Data"`
Data []byte
}

View File

@ -7,7 +7,6 @@ import (
"github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go"
"github.com/lunixbochs/struc" "github.com/lunixbochs/struc"
"github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/utils"
"net" "net"
"time" "time"
) )
@ -17,22 +16,28 @@ const dialTimeout = 10 * time.Second
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string) type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string) type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error) type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)
type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)
type Server struct { type Server struct {
sendBPS, recvBPS uint64 sendBPS, recvBPS uint64
congestionFactory CongestionFactory congestionFactory CongestionFactory
disableUDP bool
aclEngine *acl.Engine aclEngine *acl.Engine
authFunc AuthFunc authFunc AuthFunc
tcpRequestFunc TCPRequestFunc tcpRequestFunc TCPRequestFunc
tcpErrorFunc TCPErrorFunc tcpErrorFunc TCPErrorFunc
udpRequestFunc UDPRequestFunc
udpErrorFunc UDPErrorFunc
listener quic.Listener listener quic.Listener
} }
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc) (*Server, error) { obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc) (*Server, error) {
packetConn, err := net.ListenPacket("udp", addr) packetConn, err := net.ListenPacket("udp", addr)
if err != nil { if err != nil {
return nil, err return nil, err
@ -53,10 +58,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
sendBPS: sendBPS, sendBPS: sendBPS,
recvBPS: recvBPS, recvBPS: recvBPS,
congestionFactory: congestionFactory, congestionFactory: congestionFactory,
disableUDP: disableUDP,
aclEngine: aclEngine, aclEngine: aclEngine,
authFunc: authFunc, authFunc: authFunc,
tcpRequestFunc: tcpRequestFunc, tcpRequestFunc: tcpRequestFunc,
tcpErrorFunc: tcpErrorFunc, tcpErrorFunc: tcpErrorFunc,
udpRequestFunc: udpRequestFunc,
udpErrorFunc: udpErrorFunc,
} }
return s, nil return s, nil
} }
@ -94,14 +102,10 @@ func (s *Server) handleClient(cs quic.Session) {
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error") _ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
return return
} }
// Start accepting streams // Start accepting streams and messages
for { sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine,
stream, err := cs.AcceptStream(context.Background()) s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc)
if err != nil { sc.Run()
break
}
go s.handleStream(cs.RemoteAddr(), auth, stream)
}
_ = cs.CloseWithError(closeErrorCodeGeneric, "") _ = cs.CloseWithError(closeErrorCodeGeneric, "")
} }
@ -143,89 +147,3 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt
} }
return ch.Auth, ok, nil return ch.Auth, ok, nil
} }
func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) {
defer stream.Close()
// Read request
var req clientRequest
err := struc.Unpack(stream, &req)
if err != nil {
return
}
if !req.UDP {
// TCP connection
s.handleTCP(remoteAddr, auth, stream, req.Address)
} else {
// UDP connection
// TODO
}
}
func (s *Server) handleTCP(remoteAddr net.Addr, auth []byte, stream quic.Stream, reqAddr string) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "invalid address",
})
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
return
}
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg := acl.ActionDirect, ""
if s.aclEngine != nil {
action, arg = s.aclEngine.Lookup(host, ip)
}
s.tcpRequestFunc(remoteAddr, auth, reqAddr, action, arg)
var conn net.Conn // Connection to be piped
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
return
}
case acl.ActionBlock:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "blocked by ACL",
})
return
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
return
}
default:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "ACL error",
})
return
}
// So far so good if we reach here
defer conn.Close()
err = struc.Pack(stream, &serverResponse{
OK: true,
})
if err != nil {
return
}
err = utils.Pipe2Way(stream, conn)
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
}

269
pkg/core/server_client.go Normal file
View File

@ -0,0 +1,269 @@
package core
import (
"bytes"
"context"
"github.com/lucas-clemente/quic-go"
"github.com/lunixbochs/struc"
"github.com/tobyxdd/hysteria/pkg/acl"
"github.com/tobyxdd/hysteria/pkg/utils"
"net"
"sync"
)
const udpBufferSize = 65535
type serverClient struct {
CS quic.Session
Auth []byte
ClientAddr net.Addr
DisableUDP bool
ACLEngine *acl.Engine
CTCPRequestFunc TCPRequestFunc
CTCPErrorFunc TCPErrorFunc
CUDPRequestFunc UDPRequestFunc
CUDPErrorFunc UDPErrorFunc
udpSessionMutex sync.RWMutex
udpSessionMap map[uint32]*net.UDPConn
nextUDPSessionID uint32
}
func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc,
CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc) *serverClient {
return &serverClient{
CS: cs,
Auth: auth,
ClientAddr: cs.RemoteAddr(),
DisableUDP: disableUDP,
ACLEngine: ACLEngine,
CTCPRequestFunc: CTCPRequestFunc,
CTCPErrorFunc: CTCPErrorFunc,
CUDPRequestFunc: CUDPRequestFunc,
CUDPErrorFunc: CUDPErrorFunc,
udpSessionMap: make(map[uint32]*net.UDPConn),
}
}
func (c *serverClient) Run() {
if !c.DisableUDP {
go func() {
for {
msg, err := c.CS.ReceiveMessage()
if err != nil {
break
}
c.handleMessage(msg)
}
}()
}
for {
stream, err := c.CS.AcceptStream(context.Background())
if err != nil {
break
}
go c.handleStream(stream)
}
}
func (c *serverClient) handleStream(stream quic.Stream) {
defer stream.Close()
// Read request
var req clientRequest
err := struc.Unpack(stream, &req)
if err != nil {
return
}
if !req.UDP {
// TCP connection
c.handleTCP(stream, req.Address)
} else if !c.DisableUDP {
// UDP connection
c.handleUDP(stream)
} else {
// UDP disabled
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "UDP disabled",
})
}
}
func (c *serverClient) handleMessage(msg []byte) {
var udpMsg udpMessage
err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
if err != nil {
return
}
c.udpSessionMutex.RLock()
conn, ok := c.udpSessionMap[udpMsg.SessionID]
c.udpSessionMutex.RUnlock()
if ok {
// Session found, send the message
host, port, err := net.SplitHostPort(udpMsg.Address)
if err != nil {
return
}
action, arg := acl.ActionDirect, ""
if c.ACLEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg = c.ACLEngine.Lookup(host, ip)
}
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
addr, err := net.ResolveUDPAddr("udp", udpMsg.Address)
if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
}
case acl.ActionBlock:
// Do nothing
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
addr, err := net.ResolveUDPAddr("udp", hijackAddr)
if err == nil {
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
}
default:
// Do nothing
}
}
}
func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
host, port, err := net.SplitHostPort(reqAddr)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "invalid address",
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
return
}
action, arg := acl.ActionDirect, ""
if c.ACLEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
// IP request, clear host for ACL engine
host = ""
}
action, arg = c.ACLEngine.Lookup(host, ip)
}
c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg)
var conn net.Conn // Connection to be piped
switch action {
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
return
}
case acl.ActionBlock:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "blocked by ACL",
})
return
case acl.ActionHijack:
hijackAddr := net.JoinHostPort(arg, port)
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: err.Error(),
})
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
return
}
default:
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "ACL error",
})
return
}
// So far so good if we reach here
defer conn.Close()
err = struc.Pack(stream, &serverResponse{
OK: true,
})
if err != nil {
return
}
err = utils.Pipe2Way(stream, conn)
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
}
func (c *serverClient) handleUDP(stream quic.Stream) {
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
conn, err := net.ListenUDP("udp", nil)
if err != nil {
_ = struc.Pack(stream, &serverResponse{
OK: false,
Message: "UDP initialization failed",
})
c.CUDPErrorFunc(c.ClientAddr, c.Auth, 0, err)
return
}
defer conn.Close()
var id uint32
c.udpSessionMutex.Lock()
id = c.nextUDPSessionID
c.udpSessionMap[id] = conn
c.nextUDPSessionID += 1
c.udpSessionMutex.Unlock()
err = struc.Pack(stream, &serverResponse{
OK: true,
UDPSessionID: id,
})
if err != nil {
return
}
c.CUDPRequestFunc(c.ClientAddr, c.Auth, id)
// Receive UDP packets, send them to the client
go func() {
buf := make([]byte, udpBufferSize)
for {
n, rAddr, err := conn.ReadFromUDP(buf)
if n > 0 {
var msgBuf bytes.Buffer
_ = struc.Pack(&msgBuf, &udpMessage{
SessionID: id,
Address: rAddr.String(),
Data: buf[:n],
})
_ = c.CS.SendMessage(msgBuf.Bytes())
}
if err != nil {
break
}
}
}()
// Hold the stream until it's closed by the client
buf := make([]byte, 1024)
for {
_, err = stream.Read(buf)
if err != nil {
break
}
}
c.CUDPErrorFunc(c.ClientAddr, c.Auth, id, err)
// Remove the session
c.udpSessionMutex.Lock()
delete(c.udpSessionMap, id)
c.udpSessionMutex.Unlock()
}

View File

@ -27,13 +27,13 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
if err != nil { if err != nil {
return nil, err return nil, err
} }
ip := net.ParseIP(host)
if ip != nil {
host = ""
}
// ACL // ACL
action, arg := acl.ActionProxy, "" action, arg := acl.ActionProxy, ""
if aclEngine != nil { if aclEngine != nil {
ip := net.ParseIP(host)
if ip != nil {
host = ""
}
action, arg = aclEngine.Lookup(host, ip) action, arg = aclEngine.Lookup(host, ip)
} }
newDialFunc(addr, action, arg) newDialFunc(addr, action, arg)