From 6245f83262d73eb207afe5cdf63267e53dc690e3 Mon Sep 17 00:00:00 2001 From: tobyxdd Date: Fri, 21 Jul 2023 17:28:39 -0700 Subject: [PATCH] feat(wip): DirectOutbound bind to device --- extras/outbounds/{direct.go => ob_direct.go} | 50 ++++++++++++------ extras/outbounds/ob_direct_linux.go | 55 ++++++++++++++++++++ extras/outbounds/ob_direct_others.go | 19 +++++++ 3 files changed, 109 insertions(+), 15 deletions(-) rename extras/outbounds/{direct.go => ob_direct.go} (84%) create mode 100644 extras/outbounds/ob_direct_linux.go create mode 100644 extras/outbounds/ob_direct_others.go diff --git a/extras/outbounds/direct.go b/extras/outbounds/ob_direct.go similarity index 84% rename from extras/outbounds/direct.go rename to extras/outbounds/ob_direct.go index cad25b0..09b4cf2 100644 --- a/extras/outbounds/direct.go +++ b/extras/outbounds/ob_direct.go @@ -4,6 +4,7 @@ import ( "errors" "net" "strconv" + "time" ) type DirectOutboundMode int @@ -14,22 +15,34 @@ const ( DirectOutboundMode46 // Use IPv4 address when available, otherwise IPv6 DirectOutboundMode6 // Use IPv6 only, fail if not available DirectOutboundMode4 // Use IPv4 only, fail if not available + + defaultDialerTimeout = 10 * time.Second ) -var _ PluggableOutbound = (*DirectOutbound)(nil) - -// DirectOutbound is a PluggableOutbound that connects directly to the target +// directOutbound is a PluggableOutbound that connects directly to the target // using the local network (as opposed to using a proxy, for example). // It prefers to use ResolveInfo in AddrEx if available. But if it's nil, // it will fall back to resolving Host using Go's built-in DNS resolver. -type DirectOutbound struct { - Mode DirectOutboundMode - Dialer *net.Dialer +type directOutbound struct { + Mode DirectOutboundMode + Dialer *net.Dialer + DeviceName string // For UDP binding +} + +// NewDirectOutboundSimple creates a new directOutbound with the given mode, +// without binding to a specific device. Works on all platforms. +func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound { + return &directOutbound{ + Mode: mode, + Dialer: &net.Dialer{ + Timeout: defaultDialerTimeout, + }, + } } // resolve is our built-in DNS resolver for handling the case when // AddrEx.ResolveInfo is nil. -func (d *DirectOutbound) resolve(reqAddr *AddrEx) { +func (d *directOutbound) resolve(reqAddr *AddrEx) { ips, err := net.LookupIP(reqAddr.Host) if err != nil { reqAddr.ResolveInfo = &ResolveInfo{Err: err} @@ -52,7 +65,7 @@ func (d *DirectOutbound) resolve(reqAddr *AddrEx) { reqAddr.ResolveInfo = r } -func (d *DirectOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) { +func (d *directOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) { if reqAddr.ResolveInfo == nil { // AddrEx.ResolveInfo is nil (no resolver in the pipeline), // we need to resolve the address ourselves. @@ -103,7 +116,7 @@ func (d *DirectOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) { } } -func (d *DirectOutbound) dialTCP(ip net.IP, port uint16) (net.Conn, error) { +func (d *directOutbound) dialTCP(ip net.IP, port uint16) (net.Conn, error) { return d.Dialer.Dial("tcp", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) } @@ -115,7 +128,7 @@ type dialResult struct { // dualStackDialTCP dials the target using both IPv4 and IPv6 addresses simultaneously. // It returns the first successful connection and drops the other one. // If both connections fail, it returns the last error. -func (d *DirectOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.Conn, error) { +func (d *directOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.Conn, error) { ch := make(chan dialResult, 2) go func() { conn, err := d.dialTCP(ipv4, port) @@ -143,7 +156,7 @@ func (d *DirectOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.C } type directOutboundUDPConn struct { - *DirectOutbound + *directOutbound *net.UDPConn } @@ -163,13 +176,13 @@ func (u *directOutboundUDPConn) WriteTo(b []byte, addr *AddrEx) (int, error) { if addr.ResolveInfo == nil { // Although practically rare, it is possible to send // UDP packets to a hostname (instead of an IP address). - u.DirectOutbound.resolve(addr) + u.directOutbound.resolve(addr) } r := addr.ResolveInfo if r.IPv4 == nil && r.IPv6 == nil { return 0, r.Err } - switch u.DirectOutbound.Mode { + switch u.directOutbound.Mode { case DirectOutboundModeAuto: // This is a special case. // It's not possible to do a "dual stack race dial" for UDP, @@ -237,13 +250,20 @@ func (u *directOutboundUDPConn) Close() error { return u.UDPConn.Close() } -func (d *DirectOutbound) ListenUDP() (UDPConn, error) { +func (d *directOutbound) ListenUDP() (UDPConn, error) { c, err := net.ListenUDP("udp", nil) if err != nil { return nil, err } + if d.DeviceName != "" { + if err := udpConnBindToDevice(c, d.DeviceName); err != nil { + // Don't forget to close the UDPConn if binding fails + _ = c.Close() + return nil, err + } + } return &directOutboundUDPConn{ - DirectOutbound: d, + directOutbound: d, UDPConn: c, }, nil } diff --git a/extras/outbounds/ob_direct_linux.go b/extras/outbounds/ob_direct_linux.go new file mode 100644 index 0000000..85fd455 --- /dev/null +++ b/extras/outbounds/ob_direct_linux.go @@ -0,0 +1,55 @@ +package outbounds + +import ( + "errors" + "net" + "syscall" +) + +// NewDirectOutboundBindToDevice creates a new directOutbound with the given mode, +// and binds to the given device. Only works on Linux. +func NewDirectOutboundBindToDevice(mode DirectOutboundMode, deviceName string) (PluggableOutbound, error) { + if err := verifyDeviceName(deviceName); err != nil { + return nil, err + } + return &directOutbound{ + Mode: mode, + Dialer: &net.Dialer{ + Timeout: defaultDialerTimeout, + Control: func(network, address string, c syscall.RawConn) error { + var errBind error + err := c.Control(func(fd uintptr) { + errBind = syscall.BindToDevice(int(fd), deviceName) + }) + if err != nil { + return err + } + return errBind + }, + }, + DeviceName: deviceName, + }, nil +} + +func verifyDeviceName(deviceName string) error { + if deviceName == "" { + return errors.New("device name cannot be empty") + } + _, err := net.InterfaceByName(deviceName) + return err +} + +func udpConnBindToDevice(conn *net.UDPConn, deviceName string) error { + sc, err := conn.SyscallConn() + if err != nil { + return err + } + var errBind error + err = sc.Control(func(fd uintptr) { + errBind = syscall.BindToDevice(int(fd), deviceName) + }) + if err != nil { + return err + } + return errBind +} diff --git a/extras/outbounds/ob_direct_others.go b/extras/outbounds/ob_direct_others.go new file mode 100644 index 0000000..b416c30 --- /dev/null +++ b/extras/outbounds/ob_direct_others.go @@ -0,0 +1,19 @@ +//go:build !linux + +package outbounds + +import ( + "errors" + "net" +) + +// NewDirectOutboundBindToDevice creates a new directOutbound with the given mode, +// and binds to the given device. This doesn't work on non-Linux platforms, so this +// is just a stub function that always returns an error. +func NewDirectOutboundBindToDevice(mode DirectOutboundMode, deviceName string) (PluggableOutbound, error) { + return nil, errors.New("binding to device is not supported on this platform") +} + +func udpConnBindToDevice(conn *net.UDPConn, deviceName string) error { + return errors.New("binding to device is not supported on this platform") +}