From 70fd2ffc0d61c5eecb2e77743d63d617268044d8 Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 24 Apr 2021 15:36:19 -0700 Subject: [PATCH] ACL for TCP TProxy --- cmd/client.go | 9 ++--- pkg/core/client.go | 16 ++------- pkg/http/server.go | 11 +++--- pkg/tproxy/tcp_linux.go | 76 ++++++++++++++++++++++++++++++++++++----- pkg/tproxy/tcp_stub.go | 6 ++-- pkg/utils/misc.go | 18 ++++++++++ 6 files changed, 100 insertions(+), 36 deletions(-) create mode 100644 pkg/utils/misc.go diff --git a/cmd/client.go b/cmd/client.go index 8284448..8591f48 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -243,11 +243,12 @@ func client(config *clientConfig) { if len(config.TCPTProxy.Listen) > 0 { go func() { rl, err := tproxy.NewTCPTProxy(client, config.TCPTProxy.Listen, - time.Duration(config.TCPTProxy.Timeout)*time.Second, - func(addr, reqAddr net.Addr) { + time.Duration(config.TCPTProxy.Timeout)*time.Second, aclEngine, + func(addr, reqAddr net.Addr, action acl.Action, arg string) { logrus.WithFields(logrus.Fields{ - "src": addr.String(), - "dst": reqAddr.String(), + "action": actionToString(action, arg), + "src": addr.String(), + "dst": reqAddr.String(), }).Debug("TCP TProxy request") }, func(addr, reqAddr net.Addr, err error) { diff --git a/pkg/core/client.go b/pkg/core/client.go index 1efaaaf..d5ae791 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -195,7 +195,7 @@ func (c *Client) openStreamWithReconnect() (quic.Session, quic.Stream, error) { } func (c *Client) DialTCP(addr string) (net.Conn, error) { - host, port, err := splitHostPort(addr) + host, port, err := utils.SplitHostPort(addr) if err != nil { return nil, err } @@ -366,7 +366,7 @@ func (c *quicPktConn) ReadFrom() ([]byte, string, error) { } func (c *quicPktConn) WriteTo(p []byte, addr string) error { - host, port, err := splitHostPort(addr) + host, port, err := utils.SplitHostPort(addr) if err != nil { return err } @@ -384,15 +384,3 @@ func (c *quicPktConn) Close() error { c.CloseFunc() return c.Stream.Close() } - -func splitHostPort(hostport string) (string, uint16, error) { - host, port, err := net.SplitHostPort(hostport) - if err != nil { - return "", 0, err - } - portUint, err := strconv.ParseUint(port, 10, 16) - if err != nil { - return "", 0, err - } - return host, uint16(portUint), err -} diff --git a/pkg/http/server.go b/pkg/http/server.go index f9c8dbe..81299ce 100644 --- a/pkg/http/server.go +++ b/pkg/http/server.go @@ -3,6 +3,7 @@ package http import ( "errors" "fmt" + "github.com/tobyxdd/hysteria/pkg/utils" "net" "net/http" "strconv" @@ -24,11 +25,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng proxy.Tr = &http.Transport{ Dial: func(network, addr string) (net.Conn, error) { // Parse addr string - host, port, err := net.SplitHostPort(addr) - if err != nil { - return nil, err - } - portUint, err := strconv.ParseUint(port, 10, 16) + host, port, err := utils.SplitHostPort(addr) if err != nil { return nil, err } @@ -49,7 +46,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng } return net.DialTCP(network, nil, &net.TCPAddr{ IP: ipAddr.IP, - Port: int(portUint), + Port: int(port), Zone: ipAddr.Zone, }) case acl.ActionProxy: @@ -57,7 +54,7 @@ func NewProxyHTTPServer(hyClient *core.Client, idleTimeout time.Duration, aclEng case acl.ActionBlock: return nil, errors.New("blocked by ACL") case acl.ActionHijack: - return net.Dial(network, net.JoinHostPort(arg, port)) + return net.Dial(network, net.JoinHostPort(arg, strconv.Itoa(int(port)))) default: return nil, fmt.Errorf("unknown action %d", action) } diff --git a/pkg/tproxy/tcp_linux.go b/pkg/tproxy/tcp_linux.go index bf9163a..b89f3de 100644 --- a/pkg/tproxy/tcp_linux.go +++ b/pkg/tproxy/tcp_linux.go @@ -1,10 +1,14 @@ package tproxy import ( + "errors" + "fmt" "github.com/LiamHaworth/go-tproxy" + "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/core" "github.com/tobyxdd/hysteria/pkg/utils" "net" + "strconv" "time" ) @@ -12,13 +16,15 @@ type TCPTProxy struct { HyClient *core.Client ListenAddr *net.TCPAddr Timeout time.Duration + ACLEngine *acl.Engine - ConnFunc func(addr, reqAddr net.Addr) + ConnFunc func(addr, reqAddr net.Addr, action acl.Action, arg string) ErrorFunc func(addr, reqAddr net.Addr, err error) } -func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, - connFunc func(addr, reqAddr net.Addr), errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) { +func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, aclEngine *acl.Engine, + connFunc func(addr, reqAddr net.Addr, action acl.Action, arg string), + errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) { tAddr, err := net.ResolveTCPAddr("tcp", listen) if err != nil { return nil, err @@ -27,6 +33,7 @@ func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, HyClient: hyClient, ListenAddr: tAddr, Timeout: timeout, + ACLEngine: aclEngine, ConnFunc: connFunc, ErrorFunc: errorFunc, } @@ -49,15 +56,66 @@ func (r *TCPTProxy) ListenAndServe() error { // Under TPROXY mode, we are effectively acting as the remote server // So our LocalAddr is actually the target to which the user is trying to connect // and our RemoteAddr is the local address where the user initiates the connection - r.ConnFunc(c.RemoteAddr(), c.LocalAddr()) - rc, err := r.HyClient.DialTCP(c.LocalAddr().String()) + host, port, err := utils.SplitHostPort(c.LocalAddr().String()) if err != nil { - r.ErrorFunc(c.RemoteAddr(), c.LocalAddr(), err) return } - defer rc.Close() - err = utils.PipePairWithTimeout(c, rc, r.Timeout) - r.ErrorFunc(c.RemoteAddr(), c.LocalAddr(), err) + action, arg := acl.ActionProxy, "" + var ipAddr *net.IPAddr + var resErr error + if r.ACLEngine != nil { + action, arg, ipAddr, resErr = r.ACLEngine.ResolveAndMatch(host) + // Doesn't always matter if the resolution fails, as we may send it through HyClient + } + r.ConnFunc(c.RemoteAddr(), c.LocalAddr(), action, arg) + var closeErr error + defer func() { + r.ErrorFunc(c.RemoteAddr(), c.LocalAddr(), closeErr) + }() + // Handle according to the action + switch action { + case acl.ActionDirect: + if resErr != nil { + closeErr = resErr + return + } + rc, err := net.DialTCP("tcp", nil, &net.TCPAddr{ + IP: ipAddr.IP, + Port: int(port), + Zone: ipAddr.Zone, + }) + if err != nil { + closeErr = err + return + } + defer rc.Close() + closeErr = utils.PipePairWithTimeout(c, rc, r.Timeout) + return + case acl.ActionProxy: + rc, err := r.HyClient.DialTCP(c.LocalAddr().String()) + if err != nil { + closeErr = err + return + } + defer rc.Close() + closeErr = utils.PipePairWithTimeout(c, rc, r.Timeout) + return + case acl.ActionBlock: + closeErr = errors.New("blocked in ACL") + return + case acl.ActionHijack: + rc, err := net.Dial("tcp", net.JoinHostPort(arg, strconv.Itoa(int(port)))) + if err != nil { + closeErr = err + return + } + defer rc.Close() + closeErr = utils.PipePairWithTimeout(c, rc, r.Timeout) + return + default: + closeErr = fmt.Errorf("unknown action %d", action) + return + } }() } } diff --git a/pkg/tproxy/tcp_stub.go b/pkg/tproxy/tcp_stub.go index d89057e..d27706b 100644 --- a/pkg/tproxy/tcp_stub.go +++ b/pkg/tproxy/tcp_stub.go @@ -4,6 +4,7 @@ package tproxy import ( "errors" + "github.com/tobyxdd/hysteria/pkg/acl" "github.com/tobyxdd/hysteria/pkg/core" "net" "time" @@ -11,8 +12,9 @@ import ( type TCPTProxy struct{} -func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, - connFunc func(addr, reqAddr net.Addr), errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) { +func NewTCPTProxy(hyClient *core.Client, listen string, timeout time.Duration, aclEngine *acl.Engine, + connFunc func(addr, reqAddr net.Addr, action acl.Action, arg string), + errorFunc func(addr, reqAddr net.Addr, err error)) (*TCPTProxy, error) { return nil, errors.New("not supported on the current system") } diff --git a/pkg/utils/misc.go b/pkg/utils/misc.go new file mode 100644 index 0000000..13db0e7 --- /dev/null +++ b/pkg/utils/misc.go @@ -0,0 +1,18 @@ +package utils + +import ( + "net" + "strconv" +) + +func SplitHostPort(hostport string) (string, uint16, error) { + host, port, err := net.SplitHostPort(hostport) + if err != nil { + return "", 0, err + } + portUint, err := strconv.ParseUint(port, 10, 16) + if err != nil { + return "", 0, err + } + return host, uint16(portUint), err +}