mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-06-19 10:49:59 +00:00
Implemented UDP for both server & client
This commit is contained in:
parent
01c7d18211
commit
4bb5982960
@ -50,6 +50,7 @@ func client(config *clientConfig) {
|
|||||||
MaxStreamReceiveWindow: config.ReceiveWindowConn,
|
MaxStreamReceiveWindow: config.ReceiveWindowConn,
|
||||||
MaxConnectionReceiveWindow: config.ReceiveWindow,
|
MaxConnectionReceiveWindow: config.ReceiveWindow,
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
|
EnableDatagrams: true,
|
||||||
}
|
}
|
||||||
if quicConfig.MaxStreamReceiveWindow == 0 {
|
if quicConfig.MaxStreamReceiveWindow == 0 {
|
||||||
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||||
|
@ -36,6 +36,7 @@ func server(config *serverConfig) {
|
|||||||
MaxConnectionReceiveWindow: config.ReceiveWindowClient,
|
MaxConnectionReceiveWindow: config.ReceiveWindowClient,
|
||||||
MaxIncomingStreams: int64(config.MaxConnClient),
|
MaxIncomingStreams: int64(config.MaxConnClient),
|
||||||
KeepAlive: true,
|
KeepAlive: true,
|
||||||
|
EnableDatagrams: true,
|
||||||
}
|
}
|
||||||
if quicConfig.MaxStreamReceiveWindow == 0 {
|
if quicConfig.MaxStreamReceiveWindow == 0 {
|
||||||
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
quicConfig.MaxStreamReceiveWindow = DefaultMaxReceiveStreamFlowControlWindow
|
||||||
@ -96,7 +97,8 @@ func server(config *serverConfig) {
|
|||||||
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps,
|
||||||
func(refBPS uint64) congestion.CongestionControl {
|
func(refBPS uint64) congestion.CongestionControl {
|
||||||
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS))
|
||||||
}, aclEngine, obfuscator, authFunc, tcpRequestFunc, tcpErrorFunc)
|
}, config.DisableUDP, aclEngine, obfuscator, authFunc,
|
||||||
|
tcpRequestFunc, tcpErrorFunc, udpRequestFunc, udpErrorFunc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logrus.WithField("error", err).Fatal("Failed to initialize server")
|
logrus.WithField("error", err).Fatal("Failed to initialize server")
|
||||||
}
|
}
|
||||||
@ -130,6 +132,28 @@ func tcpErrorFunc(addr net.Addr, auth []byte, reqAddr string, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func udpRequestFunc(addr net.Addr, auth []byte, sessionID uint32) {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"src": addr.String(),
|
||||||
|
"session": sessionID,
|
||||||
|
}).Debug("UDP request")
|
||||||
|
}
|
||||||
|
|
||||||
|
func udpErrorFunc(addr net.Addr, auth []byte, sessionID uint32, err error) {
|
||||||
|
if err != io.EOF {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"src": addr.String(),
|
||||||
|
"session": sessionID,
|
||||||
|
"error": err,
|
||||||
|
}).Info("UDP error")
|
||||||
|
} else {
|
||||||
|
logrus.WithFields(logrus.Fields{
|
||||||
|
"src": addr.String(),
|
||||||
|
"session": sessionID,
|
||||||
|
}).Debug("UDP EOF")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func actionToString(action acl.Action, arg string) string {
|
func actionToString(action acl.Action, arg string) string {
|
||||||
switch action {
|
switch action {
|
||||||
case acl.ActionDirect:
|
case acl.ActionDirect:
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package core
|
package core
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
@ -14,7 +15,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ErrClosed = errors.New("client closed")
|
ErrClosed = errors.New("closed")
|
||||||
)
|
)
|
||||||
|
|
||||||
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
type CongestionFactory func(refBPS uint64) congestion.CongestionControl
|
||||||
@ -32,6 +33,9 @@ type Client struct {
|
|||||||
quicSession quic.Session
|
quicSession quic.Session
|
||||||
reconnectMutex sync.Mutex
|
reconnectMutex sync.Mutex
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
|
udpSessionMutex sync.RWMutex
|
||||||
|
udpSessionMap map[uint32]chan *udpMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||||
@ -90,6 +94,8 @@ func (c *Client) connectToServer() error {
|
|||||||
return fmt.Errorf("auth error: %s", msg)
|
return fmt.Errorf("auth error: %s", msg)
|
||||||
}
|
}
|
||||||
// All good
|
// All good
|
||||||
|
c.udpSessionMap = make(map[uint32]chan *udpMessage)
|
||||||
|
go c.handleMessage(qs)
|
||||||
c.quicSession = qs
|
c.quicSession = qs
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@ -119,34 +125,59 @@ func (c *Client) handleControlStream(qs quic.Session, stream quic.Stream) (bool,
|
|||||||
return true, sh.Message, nil
|
return true, sh.Message, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) openStreamWithReconnect() (quic.Stream, net.Addr, net.Addr, error) {
|
func (c *Client) handleMessage(qs quic.Session) {
|
||||||
|
for {
|
||||||
|
msg, err := qs.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
var udpMsg udpMessage
|
||||||
|
err = struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
c.udpSessionMutex.RLock()
|
||||||
|
ch, ok := c.udpSessionMap[udpMsg.SessionID]
|
||||||
|
if ok {
|
||||||
|
select {
|
||||||
|
case ch <- &udpMsg:
|
||||||
|
// OK
|
||||||
|
default:
|
||||||
|
// Silently drop the message when the channel is full
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.udpSessionMutex.RUnlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) {
|
||||||
c.reconnectMutex.Lock()
|
c.reconnectMutex.Lock()
|
||||||
defer c.reconnectMutex.Unlock()
|
defer c.reconnectMutex.Unlock()
|
||||||
if c.closed {
|
if c.closed {
|
||||||
return nil, nil, nil, ErrClosed
|
return nil, nil, ErrClosed
|
||||||
}
|
}
|
||||||
stream, err := c.quicSession.OpenStream()
|
stream, err := c.quicSession.OpenStream()
|
||||||
if err == nil {
|
if err == nil {
|
||||||
// All good
|
// All good
|
||||||
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), nil
|
return c.quicSession, stream, nil
|
||||||
}
|
}
|
||||||
// Something is wrong
|
// Something is wrong
|
||||||
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
if nErr, ok := err.(net.Error); ok && nErr.Temporary() {
|
||||||
// Temporary error, just return
|
// Temporary error, just return
|
||||||
return nil, nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
// Permanent error, need to reconnect
|
// Permanent error, need to reconnect
|
||||||
if err := c.connectToServer(); err != nil {
|
if err := c.connectToServer(); err != nil {
|
||||||
// Still error, oops
|
// Still error, oops
|
||||||
return nil, nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
// We are not going to try again even if it still fails the second time
|
// We are not going to try again even if it still fails the second time
|
||||||
stream, err = c.quicSession.OpenStream()
|
stream, err = c.quicSession.OpenStream()
|
||||||
return stream, c.quicSession.LocalAddr(), c.quicSession.RemoteAddr(), err
|
return c.quicSession, stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
||||||
stream, localAddr, remoteAddr, err := c.openStreamWithReconnect()
|
session, stream, err := c.openStreamWithReconnect()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -172,11 +203,64 @@ func (c *Client) DialTCP(addr string) (net.Conn, error) {
|
|||||||
}
|
}
|
||||||
return &quicConn{
|
return &quicConn{
|
||||||
Orig: stream,
|
Orig: stream,
|
||||||
PseudoLocalAddr: localAddr,
|
PseudoLocalAddr: session.LocalAddr(),
|
||||||
PseudoRemoteAddr: remoteAddr,
|
PseudoRemoteAddr: session.RemoteAddr(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) DialUDP() (UDPConn, error) {
|
||||||
|
session, stream, err := c.openStreamWithReconnect()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Send request
|
||||||
|
err = struc.Pack(stream, &clientRequest{
|
||||||
|
UDP: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
_ = stream.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
// Read response
|
||||||
|
var sr serverResponse
|
||||||
|
err = struc.Unpack(stream, &sr)
|
||||||
|
if err != nil {
|
||||||
|
_ = stream.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !sr.OK {
|
||||||
|
_ = stream.Close()
|
||||||
|
return nil, fmt.Errorf("connection rejected: %s", sr.Message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a session in the map
|
||||||
|
c.udpSessionMutex.Lock()
|
||||||
|
nCh := make(chan *udpMessage, 1024)
|
||||||
|
// Store the current session map for CloseFunc below
|
||||||
|
// to ensures that we are adding and removing sessions on the same map,
|
||||||
|
// as reconnecting will reassign the map
|
||||||
|
sessionMap := c.udpSessionMap
|
||||||
|
sessionMap[sr.UDPSessionID] = nCh
|
||||||
|
c.udpSessionMutex.Unlock()
|
||||||
|
|
||||||
|
pktConn := &quicPktConn{
|
||||||
|
Session: session,
|
||||||
|
Stream: stream,
|
||||||
|
CloseFunc: func() {
|
||||||
|
c.udpSessionMutex.Lock()
|
||||||
|
if ch, ok := sessionMap[sr.UDPSessionID]; ok {
|
||||||
|
close(ch)
|
||||||
|
delete(sessionMap, sr.UDPSessionID)
|
||||||
|
}
|
||||||
|
c.udpSessionMutex.Unlock()
|
||||||
|
},
|
||||||
|
UDPSessionID: sr.UDPSessionID,
|
||||||
|
MsgCh: nCh,
|
||||||
|
}
|
||||||
|
go pktConn.Hold()
|
||||||
|
return pktConn, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
c.reconnectMutex.Lock()
|
c.reconnectMutex.Lock()
|
||||||
defer c.reconnectMutex.Unlock()
|
defer c.reconnectMutex.Unlock()
|
||||||
@ -222,3 +306,53 @@ func (w *quicConn) SetReadDeadline(t time.Time) error {
|
|||||||
func (w *quicConn) SetWriteDeadline(t time.Time) error {
|
func (w *quicConn) SetWriteDeadline(t time.Time) error {
|
||||||
return w.Orig.SetWriteDeadline(t)
|
return w.Orig.SetWriteDeadline(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type UDPConn interface {
|
||||||
|
ReadFrom() ([]byte, string, error)
|
||||||
|
WriteTo([]byte, string) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
type quicPktConn struct {
|
||||||
|
Session quic.Session
|
||||||
|
Stream quic.Stream
|
||||||
|
CloseFunc func()
|
||||||
|
UDPSessionID uint32
|
||||||
|
MsgCh <-chan *udpMessage
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *quicPktConn) Hold() {
|
||||||
|
// Hold the stream until it's closed
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
_, err := c.Stream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ = c.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *quicPktConn) ReadFrom() ([]byte, string, error) {
|
||||||
|
msg := <-c.MsgCh
|
||||||
|
if msg == nil {
|
||||||
|
// Closed
|
||||||
|
return nil, "", ErrClosed
|
||||||
|
}
|
||||||
|
return msg.Data, msg.Address, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *quicPktConn) WriteTo(p []byte, addr string) error {
|
||||||
|
var msgBuf bytes.Buffer
|
||||||
|
_ = struc.Pack(&msgBuf, &udpMessage{
|
||||||
|
SessionID: c.UDPSessionID,
|
||||||
|
Address: addr,
|
||||||
|
Data: p,
|
||||||
|
})
|
||||||
|
return c.Session.SendMessage(msgBuf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *quicPktConn) Close() error {
|
||||||
|
c.CloseFunc()
|
||||||
|
return c.Stream.Close()
|
||||||
|
}
|
||||||
|
58
pkg/core/client_udp_test.go
Normal file
58
pkg/core/client_udp_test.go
Normal file
@ -0,0 +1,58 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestClientUDP(t *testing.T) {
|
||||||
|
client, err := NewClient("toby.moe:36713", nil, &tls.Config{
|
||||||
|
NextProtos: []string{"hysteria"},
|
||||||
|
MinVersion: tls.VersionTLS13,
|
||||||
|
}, &quic.Config{
|
||||||
|
EnableDatagrams: true,
|
||||||
|
}, 125000, 125000, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
conn, err := client.DialUDP()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal("conn DialUDP", err)
|
||||||
|
}
|
||||||
|
t.Run("8.8.8.8", func(t *testing.T) {
|
||||||
|
dnsReq := []byte{0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
|
||||||
|
err := conn.WriteTo(dnsReq, "8.8.8.8:53")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("WriteTo", err)
|
||||||
|
}
|
||||||
|
buf, _, err := conn.ReadFrom()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("ReadFrom", err)
|
||||||
|
}
|
||||||
|
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
|
||||||
|
t.Error("invalid response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("1.1.1.1", func(t *testing.T) {
|
||||||
|
dnsReq := []byte{0x66, 0x77, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x05, 0x62, 0x61, 0x69, 0x64, 0x75, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01}
|
||||||
|
err := conn.WriteTo(dnsReq, "1.1.1.1:53")
|
||||||
|
if err != nil {
|
||||||
|
t.Error("WriteTo", err)
|
||||||
|
}
|
||||||
|
buf, _, err := conn.ReadFrom()
|
||||||
|
if err != nil {
|
||||||
|
t.Error("ReadFrom", err)
|
||||||
|
}
|
||||||
|
if buf[0] != dnsReq[0] || buf[1] != dnsReq[1] {
|
||||||
|
t.Error("invalid response")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
t.Run("Close", func(t *testing.T) {
|
||||||
|
_ = conn.Close()
|
||||||
|
_, _, err := conn.ReadFrom()
|
||||||
|
if err != ErrClosed {
|
||||||
|
t.Error("closed conn not returning the correct error")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
@ -43,3 +43,11 @@ type serverResponse struct {
|
|||||||
MessageLen uint16 `struc:"sizeof=Message"`
|
MessageLen uint16 `struc:"sizeof=Message"`
|
||||||
Message string
|
Message string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type udpMessage struct {
|
||||||
|
SessionID uint32
|
||||||
|
AddressLen uint16 `struc:"sizeof=Address"`
|
||||||
|
Address string
|
||||||
|
DataLen uint16 `struc:"sizeof=Data"`
|
||||||
|
Data []byte
|
||||||
|
}
|
||||||
|
@ -7,7 +7,6 @@ import (
|
|||||||
"github.com/lucas-clemente/quic-go"
|
"github.com/lucas-clemente/quic-go"
|
||||||
"github.com/lunixbochs/struc"
|
"github.com/lunixbochs/struc"
|
||||||
"github.com/tobyxdd/hysteria/pkg/acl"
|
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||||
"github.com/tobyxdd/hysteria/pkg/utils"
|
|
||||||
"net"
|
"net"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
@ -17,22 +16,28 @@ const dialTimeout = 10 * time.Second
|
|||||||
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
type AuthFunc func(addr net.Addr, auth []byte, sSend uint64, sRecv uint64) (bool, string)
|
||||||
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
|
type TCPRequestFunc func(addr net.Addr, auth []byte, reqAddr string, action acl.Action, arg string)
|
||||||
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
|
type TCPErrorFunc func(addr net.Addr, auth []byte, reqAddr string, err error)
|
||||||
|
type UDPRequestFunc func(addr net.Addr, auth []byte, sessionID uint32)
|
||||||
|
type UDPErrorFunc func(addr net.Addr, auth []byte, sessionID uint32, err error)
|
||||||
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
sendBPS, recvBPS uint64
|
sendBPS, recvBPS uint64
|
||||||
congestionFactory CongestionFactory
|
congestionFactory CongestionFactory
|
||||||
|
disableUDP bool
|
||||||
aclEngine *acl.Engine
|
aclEngine *acl.Engine
|
||||||
|
|
||||||
authFunc AuthFunc
|
authFunc AuthFunc
|
||||||
tcpRequestFunc TCPRequestFunc
|
tcpRequestFunc TCPRequestFunc
|
||||||
tcpErrorFunc TCPErrorFunc
|
tcpErrorFunc TCPErrorFunc
|
||||||
|
udpRequestFunc UDPRequestFunc
|
||||||
|
udpErrorFunc UDPErrorFunc
|
||||||
|
|
||||||
listener quic.Listener
|
listener quic.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
||||||
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, aclEngine *acl.Engine,
|
sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine,
|
||||||
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc) (*Server, error) {
|
obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc,
|
||||||
|
udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc) (*Server, error) {
|
||||||
packetConn, err := net.ListenPacket("udp", addr)
|
packetConn, err := net.ListenPacket("udp", addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@ -53,10 +58,13 @@ func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config,
|
|||||||
sendBPS: sendBPS,
|
sendBPS: sendBPS,
|
||||||
recvBPS: recvBPS,
|
recvBPS: recvBPS,
|
||||||
congestionFactory: congestionFactory,
|
congestionFactory: congestionFactory,
|
||||||
|
disableUDP: disableUDP,
|
||||||
aclEngine: aclEngine,
|
aclEngine: aclEngine,
|
||||||
authFunc: authFunc,
|
authFunc: authFunc,
|
||||||
tcpRequestFunc: tcpRequestFunc,
|
tcpRequestFunc: tcpRequestFunc,
|
||||||
tcpErrorFunc: tcpErrorFunc,
|
tcpErrorFunc: tcpErrorFunc,
|
||||||
|
udpRequestFunc: udpRequestFunc,
|
||||||
|
udpErrorFunc: udpErrorFunc,
|
||||||
}
|
}
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
@ -94,14 +102,10 @@ func (s *Server) handleClient(cs quic.Session) {
|
|||||||
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
|
_ = cs.CloseWithError(closeErrorCodeAuth, "auth error")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Start accepting streams
|
// Start accepting streams and messages
|
||||||
for {
|
sc := newServerClient(cs, auth, s.disableUDP, s.aclEngine,
|
||||||
stream, err := cs.AcceptStream(context.Background())
|
s.tcpRequestFunc, s.tcpErrorFunc, s.udpRequestFunc, s.udpErrorFunc)
|
||||||
if err != nil {
|
sc.Run()
|
||||||
break
|
|
||||||
}
|
|
||||||
go s.handleStream(cs.RemoteAddr(), auth, stream)
|
|
||||||
}
|
|
||||||
_ = cs.CloseWithError(closeErrorCodeGeneric, "")
|
_ = cs.CloseWithError(closeErrorCodeGeneric, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -143,89 +147,3 @@ func (s *Server) handleControlStream(cs quic.Session, stream quic.Stream) ([]byt
|
|||||||
}
|
}
|
||||||
return ch.Auth, ok, nil
|
return ch.Auth, ok, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleStream(remoteAddr net.Addr, auth []byte, stream quic.Stream) {
|
|
||||||
defer stream.Close()
|
|
||||||
// Read request
|
|
||||||
var req clientRequest
|
|
||||||
err := struc.Unpack(stream, &req)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if !req.UDP {
|
|
||||||
// TCP connection
|
|
||||||
s.handleTCP(remoteAddr, auth, stream, req.Address)
|
|
||||||
} else {
|
|
||||||
// UDP connection
|
|
||||||
// TODO
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) handleTCP(remoteAddr net.Addr, auth []byte, stream quic.Stream, reqAddr string) {
|
|
||||||
host, port, err := net.SplitHostPort(reqAddr)
|
|
||||||
if err != nil {
|
|
||||||
_ = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: false,
|
|
||||||
Message: "invalid address",
|
|
||||||
})
|
|
||||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ip := net.ParseIP(host)
|
|
||||||
if ip != nil {
|
|
||||||
// IP request, clear host for ACL engine
|
|
||||||
host = ""
|
|
||||||
}
|
|
||||||
action, arg := acl.ActionDirect, ""
|
|
||||||
if s.aclEngine != nil {
|
|
||||||
action, arg = s.aclEngine.Lookup(host, ip)
|
|
||||||
}
|
|
||||||
s.tcpRequestFunc(remoteAddr, auth, reqAddr, action, arg)
|
|
||||||
|
|
||||||
var conn net.Conn // Connection to be piped
|
|
||||||
switch action {
|
|
||||||
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
|
||||||
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
|
|
||||||
if err != nil {
|
|
||||||
_ = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: false,
|
|
||||||
Message: err.Error(),
|
|
||||||
})
|
|
||||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
case acl.ActionBlock:
|
|
||||||
_ = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: false,
|
|
||||||
Message: "blocked by ACL",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
case acl.ActionHijack:
|
|
||||||
hijackAddr := net.JoinHostPort(arg, port)
|
|
||||||
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
|
||||||
if err != nil {
|
|
||||||
_ = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: false,
|
|
||||||
Message: err.Error(),
|
|
||||||
})
|
|
||||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
_ = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: false,
|
|
||||||
Message: "ACL error",
|
|
||||||
})
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// So far so good if we reach here
|
|
||||||
defer conn.Close()
|
|
||||||
err = struc.Pack(stream, &serverResponse{
|
|
||||||
OK: true,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = utils.Pipe2Way(stream, conn)
|
|
||||||
s.tcpErrorFunc(remoteAddr, auth, reqAddr, err)
|
|
||||||
}
|
|
||||||
|
269
pkg/core/server_client.go
Normal file
269
pkg/core/server_client.go
Normal file
@ -0,0 +1,269 @@
|
|||||||
|
package core
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"github.com/lucas-clemente/quic-go"
|
||||||
|
"github.com/lunixbochs/struc"
|
||||||
|
"github.com/tobyxdd/hysteria/pkg/acl"
|
||||||
|
"github.com/tobyxdd/hysteria/pkg/utils"
|
||||||
|
"net"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
const udpBufferSize = 65535
|
||||||
|
|
||||||
|
type serverClient struct {
|
||||||
|
CS quic.Session
|
||||||
|
Auth []byte
|
||||||
|
ClientAddr net.Addr
|
||||||
|
DisableUDP bool
|
||||||
|
ACLEngine *acl.Engine
|
||||||
|
CTCPRequestFunc TCPRequestFunc
|
||||||
|
CTCPErrorFunc TCPErrorFunc
|
||||||
|
CUDPRequestFunc UDPRequestFunc
|
||||||
|
CUDPErrorFunc UDPErrorFunc
|
||||||
|
|
||||||
|
udpSessionMutex sync.RWMutex
|
||||||
|
udpSessionMap map[uint32]*net.UDPConn
|
||||||
|
nextUDPSessionID uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newServerClient(cs quic.Session, auth []byte, disableUDP bool, ACLEngine *acl.Engine,
|
||||||
|
CTCPRequestFunc TCPRequestFunc, CTCPErrorFunc TCPErrorFunc,
|
||||||
|
CUDPRequestFunc UDPRequestFunc, CUDPErrorFunc UDPErrorFunc) *serverClient {
|
||||||
|
return &serverClient{
|
||||||
|
CS: cs,
|
||||||
|
Auth: auth,
|
||||||
|
ClientAddr: cs.RemoteAddr(),
|
||||||
|
DisableUDP: disableUDP,
|
||||||
|
ACLEngine: ACLEngine,
|
||||||
|
CTCPRequestFunc: CTCPRequestFunc,
|
||||||
|
CTCPErrorFunc: CTCPErrorFunc,
|
||||||
|
CUDPRequestFunc: CUDPRequestFunc,
|
||||||
|
CUDPErrorFunc: CUDPErrorFunc,
|
||||||
|
udpSessionMap: make(map[uint32]*net.UDPConn),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverClient) Run() {
|
||||||
|
if !c.DisableUDP {
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
msg, err := c.CS.ReceiveMessage()
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
c.handleMessage(msg)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
stream, err := c.CS.AcceptStream(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
go c.handleStream(stream)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverClient) handleStream(stream quic.Stream) {
|
||||||
|
defer stream.Close()
|
||||||
|
// Read request
|
||||||
|
var req clientRequest
|
||||||
|
err := struc.Unpack(stream, &req)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !req.UDP {
|
||||||
|
// TCP connection
|
||||||
|
c.handleTCP(stream, req.Address)
|
||||||
|
} else if !c.DisableUDP {
|
||||||
|
// UDP connection
|
||||||
|
c.handleUDP(stream)
|
||||||
|
} else {
|
||||||
|
// UDP disabled
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: "UDP disabled",
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverClient) handleMessage(msg []byte) {
|
||||||
|
var udpMsg udpMessage
|
||||||
|
err := struc.Unpack(bytes.NewBuffer(msg), &udpMsg)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.udpSessionMutex.RLock()
|
||||||
|
conn, ok := c.udpSessionMap[udpMsg.SessionID]
|
||||||
|
c.udpSessionMutex.RUnlock()
|
||||||
|
if ok {
|
||||||
|
// Session found, send the message
|
||||||
|
host, port, err := net.SplitHostPort(udpMsg.Address)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
action, arg := acl.ActionDirect, ""
|
||||||
|
if c.ACLEngine != nil {
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
// IP request, clear host for ACL engine
|
||||||
|
host = ""
|
||||||
|
}
|
||||||
|
action, arg = c.ACLEngine.Lookup(host, ip)
|
||||||
|
}
|
||||||
|
switch action {
|
||||||
|
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", udpMsg.Address)
|
||||||
|
if err == nil {
|
||||||
|
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
|
||||||
|
}
|
||||||
|
case acl.ActionBlock:
|
||||||
|
// Do nothing
|
||||||
|
case acl.ActionHijack:
|
||||||
|
hijackAddr := net.JoinHostPort(arg, port)
|
||||||
|
addr, err := net.ResolveUDPAddr("udp", hijackAddr)
|
||||||
|
if err == nil {
|
||||||
|
_, _ = conn.WriteToUDP(udpMsg.Data, addr)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// Do nothing
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverClient) handleTCP(stream quic.Stream, reqAddr string) {
|
||||||
|
host, port, err := net.SplitHostPort(reqAddr)
|
||||||
|
if err != nil {
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: "invalid address",
|
||||||
|
})
|
||||||
|
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
action, arg := acl.ActionDirect, ""
|
||||||
|
if c.ACLEngine != nil {
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
// IP request, clear host for ACL engine
|
||||||
|
host = ""
|
||||||
|
}
|
||||||
|
action, arg = c.ACLEngine.Lookup(host, ip)
|
||||||
|
}
|
||||||
|
c.CTCPRequestFunc(c.ClientAddr, c.Auth, reqAddr, action, arg)
|
||||||
|
|
||||||
|
var conn net.Conn // Connection to be piped
|
||||||
|
switch action {
|
||||||
|
case acl.ActionDirect, acl.ActionProxy: // Treat proxy as direct on server side
|
||||||
|
conn, err = net.DialTimeout("tcp", reqAddr, dialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
case acl.ActionBlock:
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: "blocked by ACL",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
case acl.ActionHijack:
|
||||||
|
hijackAddr := net.JoinHostPort(arg, port)
|
||||||
|
conn, err = net.DialTimeout("tcp", hijackAddr, dialTimeout)
|
||||||
|
if err != nil {
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: err.Error(),
|
||||||
|
})
|
||||||
|
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: "ACL error",
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// So far so good if we reach here
|
||||||
|
defer conn.Close()
|
||||||
|
err = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
err = utils.Pipe2Way(stream, conn)
|
||||||
|
c.CTCPErrorFunc(c.ClientAddr, c.Auth, reqAddr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *serverClient) handleUDP(stream quic.Stream) {
|
||||||
|
// Like in SOCKS5, the stream here is only used to maintain the UDP session. No need to read anything from it
|
||||||
|
conn, err := net.ListenUDP("udp", nil)
|
||||||
|
if err != nil {
|
||||||
|
_ = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: false,
|
||||||
|
Message: "UDP initialization failed",
|
||||||
|
})
|
||||||
|
c.CUDPErrorFunc(c.ClientAddr, c.Auth, 0, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
var id uint32
|
||||||
|
c.udpSessionMutex.Lock()
|
||||||
|
id = c.nextUDPSessionID
|
||||||
|
c.udpSessionMap[id] = conn
|
||||||
|
c.nextUDPSessionID += 1
|
||||||
|
c.udpSessionMutex.Unlock()
|
||||||
|
|
||||||
|
err = struc.Pack(stream, &serverResponse{
|
||||||
|
OK: true,
|
||||||
|
UDPSessionID: id,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.CUDPRequestFunc(c.ClientAddr, c.Auth, id)
|
||||||
|
|
||||||
|
// Receive UDP packets, send them to the client
|
||||||
|
go func() {
|
||||||
|
buf := make([]byte, udpBufferSize)
|
||||||
|
for {
|
||||||
|
n, rAddr, err := conn.ReadFromUDP(buf)
|
||||||
|
if n > 0 {
|
||||||
|
var msgBuf bytes.Buffer
|
||||||
|
_ = struc.Pack(&msgBuf, &udpMessage{
|
||||||
|
SessionID: id,
|
||||||
|
Address: rAddr.String(),
|
||||||
|
Data: buf[:n],
|
||||||
|
})
|
||||||
|
_ = c.CS.SendMessage(msgBuf.Bytes())
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Hold the stream until it's closed by the client
|
||||||
|
buf := make([]byte, 1024)
|
||||||
|
for {
|
||||||
|
_, err = stream.Read(buf)
|
||||||
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
c.CUDPErrorFunc(c.ClientAddr, c.Auth, id, err)
|
||||||
|
|
||||||
|
// Remove the session
|
||||||
|
c.udpSessionMutex.Lock()
|
||||||
|
delete(c.udpSessionMap, id)
|
||||||
|
c.udpSessionMutex.Unlock()
|
||||||
|
}
|
@ -27,13 +27,13 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ip := net.ParseIP(host)
|
|
||||||
if ip != nil {
|
|
||||||
host = ""
|
|
||||||
}
|
|
||||||
// ACL
|
// ACL
|
||||||
action, arg := acl.ActionProxy, ""
|
action, arg := acl.ActionProxy, ""
|
||||||
if aclEngine != nil {
|
if aclEngine != nil {
|
||||||
|
ip := net.ParseIP(host)
|
||||||
|
if ip != nil {
|
||||||
|
host = ""
|
||||||
|
}
|
||||||
action, arg = aclEngine.Lookup(host, ip)
|
action, arg = aclEngine.Lookup(host, ip)
|
||||||
}
|
}
|
||||||
newDialFunc(addr, action, arg)
|
newDialFunc(addr, action, arg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user