feat(wip): DirectOutbound bind to device

This commit is contained in:
tobyxdd 2023-07-21 17:28:39 -07:00
parent b25fb63d5b
commit 6245f83262
3 changed files with 109 additions and 15 deletions

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"net" "net"
"strconv" "strconv"
"time"
) )
type DirectOutboundMode int type DirectOutboundMode int
@ -14,22 +15,34 @@ const (
DirectOutboundMode46 // Use IPv4 address when available, otherwise IPv6 DirectOutboundMode46 // Use IPv4 address when available, otherwise IPv6
DirectOutboundMode6 // Use IPv6 only, fail if not available DirectOutboundMode6 // Use IPv6 only, fail if not available
DirectOutboundMode4 // Use IPv4 only, fail if not available DirectOutboundMode4 // Use IPv4 only, fail if not available
defaultDialerTimeout = 10 * time.Second
) )
var _ PluggableOutbound = (*DirectOutbound)(nil) // directOutbound is a PluggableOutbound that connects directly to the target
// DirectOutbound is a PluggableOutbound that connects directly to the target
// using the local network (as opposed to using a proxy, for example). // using the local network (as opposed to using a proxy, for example).
// It prefers to use ResolveInfo in AddrEx if available. But if it's nil, // It prefers to use ResolveInfo in AddrEx if available. But if it's nil,
// it will fall back to resolving Host using Go's built-in DNS resolver. // it will fall back to resolving Host using Go's built-in DNS resolver.
type DirectOutbound struct { type directOutbound struct {
Mode DirectOutboundMode Mode DirectOutboundMode
Dialer *net.Dialer Dialer *net.Dialer
DeviceName string // For UDP binding
}
// NewDirectOutboundSimple creates a new directOutbound with the given mode,
// without binding to a specific device. Works on all platforms.
func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound {
return &directOutbound{
Mode: mode,
Dialer: &net.Dialer{
Timeout: defaultDialerTimeout,
},
}
} }
// resolve is our built-in DNS resolver for handling the case when // resolve is our built-in DNS resolver for handling the case when
// AddrEx.ResolveInfo is nil. // AddrEx.ResolveInfo is nil.
func (d *DirectOutbound) resolve(reqAddr *AddrEx) { func (d *directOutbound) resolve(reqAddr *AddrEx) {
ips, err := net.LookupIP(reqAddr.Host) ips, err := net.LookupIP(reqAddr.Host)
if err != nil { if err != nil {
reqAddr.ResolveInfo = &ResolveInfo{Err: err} reqAddr.ResolveInfo = &ResolveInfo{Err: err}
@ -52,7 +65,7 @@ func (d *DirectOutbound) resolve(reqAddr *AddrEx) {
reqAddr.ResolveInfo = r reqAddr.ResolveInfo = r
} }
func (d *DirectOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) { func (d *directOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) {
if reqAddr.ResolveInfo == nil { if reqAddr.ResolveInfo == nil {
// AddrEx.ResolveInfo is nil (no resolver in the pipeline), // AddrEx.ResolveInfo is nil (no resolver in the pipeline),
// we need to resolve the address ourselves. // we need to resolve the address ourselves.
@ -103,7 +116,7 @@ func (d *DirectOutbound) DialTCP(reqAddr *AddrEx) (net.Conn, error) {
} }
} }
func (d *DirectOutbound) dialTCP(ip net.IP, port uint16) (net.Conn, error) { func (d *directOutbound) dialTCP(ip net.IP, port uint16) (net.Conn, error) {
return d.Dialer.Dial("tcp", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) return d.Dialer.Dial("tcp", net.JoinHostPort(ip.String(), strconv.Itoa(int(port))))
} }
@ -115,7 +128,7 @@ type dialResult struct {
// dualStackDialTCP dials the target using both IPv4 and IPv6 addresses simultaneously. // dualStackDialTCP dials the target using both IPv4 and IPv6 addresses simultaneously.
// It returns the first successful connection and drops the other one. // It returns the first successful connection and drops the other one.
// If both connections fail, it returns the last error. // If both connections fail, it returns the last error.
func (d *DirectOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.Conn, error) { func (d *directOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.Conn, error) {
ch := make(chan dialResult, 2) ch := make(chan dialResult, 2)
go func() { go func() {
conn, err := d.dialTCP(ipv4, port) conn, err := d.dialTCP(ipv4, port)
@ -143,7 +156,7 @@ func (d *DirectOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.C
} }
type directOutboundUDPConn struct { type directOutboundUDPConn struct {
*DirectOutbound *directOutbound
*net.UDPConn *net.UDPConn
} }
@ -163,13 +176,13 @@ func (u *directOutboundUDPConn) WriteTo(b []byte, addr *AddrEx) (int, error) {
if addr.ResolveInfo == nil { if addr.ResolveInfo == nil {
// Although practically rare, it is possible to send // Although practically rare, it is possible to send
// UDP packets to a hostname (instead of an IP address). // UDP packets to a hostname (instead of an IP address).
u.DirectOutbound.resolve(addr) u.directOutbound.resolve(addr)
} }
r := addr.ResolveInfo r := addr.ResolveInfo
if r.IPv4 == nil && r.IPv6 == nil { if r.IPv4 == nil && r.IPv6 == nil {
return 0, r.Err return 0, r.Err
} }
switch u.DirectOutbound.Mode { switch u.directOutbound.Mode {
case DirectOutboundModeAuto: case DirectOutboundModeAuto:
// This is a special case. // This is a special case.
// It's not possible to do a "dual stack race dial" for UDP, // It's not possible to do a "dual stack race dial" for UDP,
@ -237,13 +250,20 @@ func (u *directOutboundUDPConn) Close() error {
return u.UDPConn.Close() return u.UDPConn.Close()
} }
func (d *DirectOutbound) ListenUDP() (UDPConn, error) { func (d *directOutbound) ListenUDP() (UDPConn, error) {
c, err := net.ListenUDP("udp", nil) c, err := net.ListenUDP("udp", nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if d.DeviceName != "" {
if err := udpConnBindToDevice(c, d.DeviceName); err != nil {
// Don't forget to close the UDPConn if binding fails
_ = c.Close()
return nil, err
}
}
return &directOutboundUDPConn{ return &directOutboundUDPConn{
DirectOutbound: d, directOutbound: d,
UDPConn: c, UDPConn: c,
}, nil }, nil
} }

View File

@ -0,0 +1,55 @@
package outbounds
import (
"errors"
"net"
"syscall"
)
// NewDirectOutboundBindToDevice creates a new directOutbound with the given mode,
// and binds to the given device. Only works on Linux.
func NewDirectOutboundBindToDevice(mode DirectOutboundMode, deviceName string) (PluggableOutbound, error) {
if err := verifyDeviceName(deviceName); err != nil {
return nil, err
}
return &directOutbound{
Mode: mode,
Dialer: &net.Dialer{
Timeout: defaultDialerTimeout,
Control: func(network, address string, c syscall.RawConn) error {
var errBind error
err := c.Control(func(fd uintptr) {
errBind = syscall.BindToDevice(int(fd), deviceName)
})
if err != nil {
return err
}
return errBind
},
},
DeviceName: deviceName,
}, nil
}
func verifyDeviceName(deviceName string) error {
if deviceName == "" {
return errors.New("device name cannot be empty")
}
_, err := net.InterfaceByName(deviceName)
return err
}
func udpConnBindToDevice(conn *net.UDPConn, deviceName string) error {
sc, err := conn.SyscallConn()
if err != nil {
return err
}
var errBind error
err = sc.Control(func(fd uintptr) {
errBind = syscall.BindToDevice(int(fd), deviceName)
})
if err != nil {
return err
}
return errBind
}

View File

@ -0,0 +1,19 @@
//go:build !linux
package outbounds
import (
"errors"
"net"
)
// NewDirectOutboundBindToDevice creates a new directOutbound with the given mode,
// and binds to the given device. This doesn't work on non-Linux platforms, so this
// is just a stub function that always returns an error.
func NewDirectOutboundBindToDevice(mode DirectOutboundMode, deviceName string) (PluggableOutbound, error) {
return nil, errors.New("binding to device is not supported on this platform")
}
func udpConnBindToDevice(conn *net.UDPConn, deviceName string) error {
return errors.New("binding to device is not supported on this platform")
}