diff --git a/extras/outbounds/ob_direct.go b/extras/outbounds/ob_direct.go index f7622b2..ce825ed 100644 --- a/extras/outbounds/ob_direct.go +++ b/extras/outbounds/ob_direct.go @@ -9,6 +9,8 @@ import ( type DirectOutboundMode int +type udpConnState int + const ( DirectOutboundModeAuto DirectOutboundMode = iota // Dual-stack "happy eyeballs"-like mode DirectOutboundMode64 // Use IPv6 address when available, otherwise IPv4 @@ -19,25 +21,99 @@ const ( defaultDialerTimeout = 10 * time.Second ) +const ( + udpConnStateDualStack udpConnState = iota + udpConnStateIPv4 + udpConnStateIPv6 +) + // directOutbound is a PluggableOutbound that connects directly to the target // using the local network (as opposed to using a proxy, for example). // It prefers to use ResolveInfo in AddrEx if available. But if it's nil, // it will fall back to resolving Host using Go's built-in DNS resolver. type directOutbound struct { - Mode DirectOutboundMode - Dialer *net.Dialer - DeviceName string // For UDP binding + Mode DirectOutboundMode + + // Dialer4 and Dialer6 are used for IPv4 and IPv6 TCP connections respectively. + Dialer4 *net.Dialer + Dialer6 *net.Dialer + + // DeviceName & BindIPs are for UDP connections. They don't use dialers, so we + // need to bind them when creating the connection. + DeviceName string + BindIP4 net.IP + BindIP6 net.IP +} + +type noAddressError struct { + IPv4 bool + IPv6 bool +} + +func (e noAddressError) Error() string { + if e.IPv4 && e.IPv6 { + return "no IPv4 or IPv6 address available" + } else if e.IPv4 { + return "no IPv4 address available" + } else if e.IPv6 { + return "no IPv6 address available" + } else { + return "no address available" + } +} + +type invalidOutboundModeError struct{} + +func (e invalidOutboundModeError) Error() string { + return "invalid outbound mode" } // NewDirectOutboundSimple creates a new directOutbound with the given mode, // without binding to a specific device. Works on all platforms. func NewDirectOutboundSimple(mode DirectOutboundMode) PluggableOutbound { + d := &net.Dialer{ + Timeout: defaultDialerTimeout, + } return &directOutbound{ + Mode: mode, + Dialer4: d, + Dialer6: d, + } +} + +// NewDirectOutboundBindToIPs creates a new directOutbound with the given mode, +// and binds to the given IPv4 and IPv6 addresses. Either or both of the addresses +// can be nil, in which case the directOutbound will not bind to a specific address +// for that family. +func NewDirectOutboundBindToIPs(mode DirectOutboundMode, bindIP4, bindIP6 net.IP) (PluggableOutbound, error) { + if bindIP4 != nil && bindIP4.To4() == nil { + return nil, errors.New("bindIP4 must be an IPv4 address") + } + if bindIP6 != nil && bindIP6.To4() != nil { + return nil, errors.New("bindIP6 must be an IPv6 address") + } + ob := &directOutbound{ Mode: mode, - Dialer: &net.Dialer{ + Dialer4: &net.Dialer{ Timeout: defaultDialerTimeout, }, + Dialer6: &net.Dialer{ + Timeout: defaultDialerTimeout, + }, + BindIP4: bindIP4, + BindIP6: bindIP6, } + if bindIP4 != nil { + ob.Dialer4.LocalAddr = &net.TCPAddr{ + IP: bindIP4, + } + } + if bindIP6 != nil { + ob.Dialer6.LocalAddr = &net.TCPAddr{ + IP: bindIP6, + } + } + return ob, nil } // resolve is our built-in DNS resolver for handling the case when @@ -51,7 +127,7 @@ func (d *directOutbound) resolve(reqAddr *AddrEx) { r := &ResolveInfo{} r.IPv4, r.IPv6 = splitIPv4IPv6(ips) if r.IPv4 == nil && r.IPv6 == nil { - r.Err = errors.New("no IPv4 or IPv6 address available") + r.Err = noAddressError{IPv4: true, IPv6: true} } reqAddr.ResolveInfo = r } @@ -94,21 +170,25 @@ func (d *directOutbound) TCP(reqAddr *AddrEx) (net.Conn, error) { if r.IPv6 != nil { return d.dialTCP(r.IPv6, reqAddr.Port) } else { - return nil, errors.New("no IPv6 address available") + return nil, noAddressError{IPv6: true} } case DirectOutboundMode4: if r.IPv4 != nil { return d.dialTCP(r.IPv4, reqAddr.Port) } else { - return nil, errors.New("no IPv4 address available") + return nil, noAddressError{IPv4: true} } default: - return nil, errors.New("invalid DirectOutboundMode") + return nil, invalidOutboundModeError{} } } func (d *directOutbound) dialTCP(ip net.IP, port uint16) (net.Conn, error) { - return d.Dialer.Dial("tcp", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) + if ip.To4() != nil { + return d.Dialer4.Dial("tcp4", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) + } else { + return d.Dialer6.Dial("tcp6", net.JoinHostPort(ip.String(), strconv.Itoa(int(port)))) + } } type dialResult struct { @@ -149,6 +229,7 @@ func (d *directOutbound) dualStackDialTCP(ipv4, ipv6 net.IP, port uint16) (net.C type directOutboundUDPConn struct { *directOutbound *net.UDPConn + State udpConnState } func (u *directOutboundUDPConn) ReadFrom(b []byte) (int, *AddrEx, error) { @@ -165,75 +246,92 @@ func (u *directOutboundUDPConn) ReadFrom(b []byte) (int, *AddrEx, error) { func (u *directOutboundUDPConn) WriteTo(b []byte, addr *AddrEx) (int, error) { if addr.ResolveInfo == nil { - // Although practically rare, it is possible to send - // UDP packets to a hostname (instead of an IP address). u.directOutbound.resolve(addr) } r := addr.ResolveInfo if r.IPv4 == nil && r.IPv6 == nil { return 0, r.Err } - switch u.directOutbound.Mode { - case DirectOutboundModeAuto: - // This is a special case. - // It's not possible to do a "dual stack race dial" for UDP, - // since UDP is connectionless. - // For maximum compatibility, we just behave like DirectOutboundMode46. + if u.State == udpConnStateIPv4 { if r.IPv4 != nil { return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ IP: r.IPv4, Port: int(addr.Port), }) } else { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv6, - Port: int(addr.Port), - }) + return 0, noAddressError{IPv4: true} } - case DirectOutboundMode64: + } else if u.State == udpConnStateIPv6 { if r.IPv6 != nil { return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ IP: r.IPv6, Port: int(addr.Port), }) } else { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv4, - Port: int(addr.Port), - }) + return 0, noAddressError{IPv6: true} } - case DirectOutboundMode46: - if r.IPv4 != nil { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv4, - Port: int(addr.Port), - }) - } else { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv6, - Port: int(addr.Port), - }) + } else { + // Dual stack + switch u.directOutbound.Mode { + case DirectOutboundModeAuto: + // This is a special case. + // We must make a decision here, so we prefer IPv4 for maximum compatibility. + if r.IPv4 != nil { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv4, + Port: int(addr.Port), + }) + } else { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv6, + Port: int(addr.Port), + }) + } + case DirectOutboundMode64: + if r.IPv6 != nil { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv6, + Port: int(addr.Port), + }) + } else { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv4, + Port: int(addr.Port), + }) + } + case DirectOutboundMode46: + if r.IPv4 != nil { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv4, + Port: int(addr.Port), + }) + } else { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv6, + Port: int(addr.Port), + }) + } + case DirectOutboundMode6: + if r.IPv6 != nil { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv6, + Port: int(addr.Port), + }) + } else { + return 0, noAddressError{IPv6: true} + } + case DirectOutboundMode4: + if r.IPv4 != nil { + return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ + IP: r.IPv4, + Port: int(addr.Port), + }) + } else { + return 0, noAddressError{IPv4: true} + } + default: + return 0, invalidOutboundModeError{} } - case DirectOutboundMode6: - if r.IPv6 != nil { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv6, - Port: int(addr.Port), - }) - } else { - return 0, errors.New("no IPv6 address available") - } - case DirectOutboundMode4: - if r.IPv4 != nil { - return u.UDPConn.WriteToUDP(b, &net.UDPAddr{ - IP: r.IPv4, - Port: int(addr.Port), - }) - } else { - return 0, errors.New("no IPv4 address available") - } - default: - return 0, errors.New("invalid DirectOutboundMode") } } @@ -242,19 +340,105 @@ func (u *directOutboundUDPConn) Close() error { } func (d *directOutbound) UDP(reqAddr *AddrEx) (UDPConn, error) { - c, err := net.ListenUDP("udp", nil) - if err != nil { - return nil, err - } - if d.DeviceName != "" { - if err := udpConnBindToDevice(c, d.DeviceName); err != nil { - // Don't forget to close the UDPConn if binding fails - _ = c.Close() + if d.BindIP4 == nil && d.BindIP6 == nil { + // No bind address specified, use default dual stack implementation + c, err := net.ListenUDP("udp", nil) + if err != nil { return nil, err } + if d.DeviceName != "" { + if err := udpConnBindToDevice(c, d.DeviceName); err != nil { + // Don't forget to close the UDPConn if binding fails + _ = c.Close() + return nil, err + } + } + return &directOutboundUDPConn{ + directOutbound: d, + UDPConn: c, + State: udpConnStateDualStack, + }, nil + } else { + // Bind address specified, + // need to check what kind of address is in reqAddr + // to determine which address family to bind to + if reqAddr.ResolveInfo == nil { + d.resolve(reqAddr) + } + r := reqAddr.ResolveInfo + if r.IPv4 == nil && r.IPv6 == nil { + return nil, r.Err + } + var bindIP net.IP // can be nil, in which case we still lock the address family but don't bind to any address + var state udpConnState // either IPv4 or IPv6 + switch d.Mode { + case DirectOutboundModeAuto: + // This is a special case. + // We must make a decision here, so we prefer IPv4 for maximum compatibility. + if r.IPv4 != nil { + bindIP = d.BindIP4 + state = udpConnStateIPv4 + } else { + bindIP = d.BindIP6 + state = udpConnStateIPv6 + } + case DirectOutboundMode64: + if r.IPv6 != nil { + bindIP = d.BindIP6 + state = udpConnStateIPv6 + } else { + bindIP = d.BindIP4 + state = udpConnStateIPv4 + } + case DirectOutboundMode46: + if r.IPv4 != nil { + bindIP = d.BindIP4 + state = udpConnStateIPv4 + } else { + bindIP = d.BindIP6 + state = udpConnStateIPv6 + } + case DirectOutboundMode6: + if r.IPv6 != nil { + bindIP = d.BindIP6 + state = udpConnStateIPv6 + } else { + return nil, noAddressError{IPv6: true} + } + case DirectOutboundMode4: + if r.IPv4 != nil { + bindIP = d.BindIP4 + state = udpConnStateIPv4 + } else { + return nil, noAddressError{IPv4: true} + } + default: + return nil, invalidOutboundModeError{} + } + var network string + var c *net.UDPConn + var err error + if state == udpConnStateIPv4 { + network = "udp4" + } else { + network = "udp6" + } + if bindIP != nil { + c, err = net.ListenUDP(network, &net.UDPAddr{ + IP: bindIP, + }) + } else { + c, err = net.ListenUDP(network, nil) + } + if err != nil { + return nil, err + } + // We don't support binding to both device & address at the same time, + // so d.DeviceName is ignored in this case. + return &directOutboundUDPConn{ + directOutbound: d, + UDPConn: c, + State: state, + }, nil } - return &directOutboundUDPConn{ - directOutbound: d, - UDPConn: c, - }, nil } diff --git a/extras/outbounds/ob_direct_linux.go b/extras/outbounds/ob_direct_linux.go index 85fd455..33b7d09 100644 --- a/extras/outbounds/ob_direct_linux.go +++ b/extras/outbounds/ob_direct_linux.go @@ -12,21 +12,23 @@ func NewDirectOutboundBindToDevice(mode DirectOutboundMode, deviceName string) ( if err := verifyDeviceName(deviceName); err != nil { return nil, err } - return &directOutbound{ - Mode: mode, - Dialer: &net.Dialer{ - Timeout: defaultDialerTimeout, - Control: func(network, address string, c syscall.RawConn) error { - var errBind error - err := c.Control(func(fd uintptr) { - errBind = syscall.BindToDevice(int(fd), deviceName) - }) - if err != nil { - return err - } - return errBind - }, + d := &net.Dialer{ + Timeout: defaultDialerTimeout, + Control: func(network, address string, c syscall.RawConn) error { + var errBind error + err := c.Control(func(fd uintptr) { + errBind = syscall.BindToDevice(int(fd), deviceName) + }) + if err != nil { + return err + } + return errBind }, + } + return &directOutbound{ + Mode: mode, + Dialer4: d, + Dialer6: d, DeviceName: deviceName, }, nil }