diff --git a/app/cmd/client.go b/app/cmd/client.go index 1e9bca5..b6745ab 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -23,6 +23,7 @@ import ( "github.com/apernet/hysteria/app/internal/http" "github.com/apernet/hysteria/app/internal/proxymux" "github.com/apernet/hysteria/app/internal/redirect" + "github.com/apernet/hysteria/app/internal/sockopts" "github.com/apernet/hysteria/app/internal/socks5" "github.com/apernet/hysteria/app/internal/tproxy" "github.com/apernet/hysteria/app/internal/tun" @@ -100,13 +101,20 @@ type clientConfigTLS struct { } type clientConfigQUIC struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` + Sockopts clientConfigQUICSockopts `mapstructure:"sockopts"` +} + +type clientConfigQUICSockopts struct { + BindInterface *string `mapstructure:"bindInterface"` + FirewallMark *uint32 `mapstructure:"fwmark"` + FdControlUnixSocket *string `mapstructure:"fdControlUnixSocket"` } type clientConfigBandwidth struct { @@ -196,6 +204,21 @@ func (c *clientConfig) fillServerAddr(hyConfig *client.Config) error { // fillConnFactory must be called after fillServerAddr, as we have different logic // for ConnFactory depending on whether we have a port hopping address. func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { + so := &sockopts.SocketOptions{ + BindInterface: c.QUIC.Sockopts.BindInterface, + FirewallMark: c.QUIC.Sockopts.FirewallMark, + FdControlUnixSocket: c.QUIC.Sockopts.FdControlUnixSocket, + } + if err := so.CheckSupported(); err != nil { + var unsupportedErr *sockopts.UnsupportedError + if errors.As(err, &unsupportedErr) { + return configError{ + Field: "quic.sockopts." + unsupportedErr.Field, + Err: errors.New("unsupported on this platform"), + } + } + return configError{Field: "quic.sockopts", Err: err} + } // Inner PacketConn var newFunc func(addr net.Addr) (net.PacketConn, error) switch strings.ToLower(c.Transport.Type) { @@ -203,11 +226,11 @@ func (c *clientConfig) fillConnFactory(hyConfig *client.Config) error { if hyConfig.ServerAddr.Network() == "udphop" { hopAddr := hyConfig.ServerAddr.(*udphop.UDPHopAddr) newFunc = func(addr net.Addr) (net.PacketConn, error) { - return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, nil) + return udphop.NewUDPHopPacketConn(hopAddr, c.Transport.UDP.HopInterval, so.ListenUDP) } } else { newFunc = func(addr net.Addr) (net.PacketConn, error) { - return net.ListenUDP("udp", nil) + return so.ListenUDP() } } default: diff --git a/app/cmd/client_test.go b/app/cmd/client_test.go index c586949..10b2d99 100644 --- a/app/cmd/client_test.go +++ b/app/cmd/client_test.go @@ -46,6 +46,11 @@ func TestClientConfig(t *testing.T) { MaxIdleTimeout: 10 * time.Second, KeepAlivePeriod: 4 * time.Second, DisablePathMTUDiscovery: true, + Sockopts: clientConfigQUICSockopts{ + BindInterface: stringRef("eth0"), + FirewallMark: uint32Ref(1234), + FdControlUnixSocket: stringRef("test.sock"), + }, }, Bandwidth: clientConfigBandwidth{ Up: "200 mbps", @@ -189,3 +194,11 @@ func TestClientConfigURI(t *testing.T) { }) } } + +func stringRef(s string) *string { + return &s +} + +func uint32Ref(i uint32) *uint32 { + return &i +} diff --git a/app/cmd/client_test.yaml b/app/cmd/client_test.yaml index 4f919df..e8438f6 100644 --- a/app/cmd/client_test.yaml +++ b/app/cmd/client_test.yaml @@ -26,6 +26,10 @@ quic: maxIdleTimeout: 10s keepAlivePeriod: 4s disablePathMTUDiscovery: true + sockopts: + bindInterface: eth0 + fwmark: 1234 + fdControlUnixSocket: test.sock bandwidth: up: 200 mbps @@ -75,7 +79,7 @@ tun: ipv6: 2001::ffff:ffff:ffff:fff1/126 route: strict: true - ipv4: [0.0.0.0/0] - ipv6: ["2000::/3"] - ipv4Exclude: [192.0.2.1/32] - ipv6Exclude: ["2001:db8::1/128"] + ipv4: [ 0.0.0.0/0 ] + ipv6: [ "2000::/3" ] + ipv4Exclude: [ 192.0.2.1/32 ] + ipv6Exclude: [ "2001:db8::1/128" ] diff --git a/app/go.mod b/app/go.mod index a5025c1..97c5bbb 100644 --- a/app/go.mod +++ b/app/go.mod @@ -16,6 +16,8 @@ require ( github.com/stretchr/testify v1.8.4 github.com/txthinking/socks5 v0.0.0-20230325130024-4230056ae301 go.uber.org/zap v1.24.0 + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db + golang.org/x/sys v0.17.0 ) require ( @@ -54,10 +56,8 @@ require ( go.uber.org/multierr v1.11.0 // indirect go4.org/netipx v0.0.0-20231129151722-fdeea329fbba // indirect golang.org/x/crypto v0.19.0 // indirect - golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect golang.org/x/mod v0.12.0 // indirect golang.org/x/net v0.21.0 // indirect - golang.org/x/sys v0.17.0 // indirect golang.org/x/text v0.14.0 // indirect golang.org/x/tools v0.11.1 // indirect google.golang.org/protobuf v1.33.0 // indirect diff --git a/app/internal/sockopts/fd_control_unix_socket_test.py b/app/internal/sockopts/fd_control_unix_socket_test.py new file mode 100644 index 0000000..e47a6f6 --- /dev/null +++ b/app/internal/sockopts/fd_control_unix_socket_test.py @@ -0,0 +1,65 @@ +import socket +import array +import os +import struct +import sys + + +def serve(path): + try: + os.unlink(path) + except OSError: + if os.path.exists(path): + raise + + server = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + server.bind(path) + server.listen() + print(f"Listening on {path}") + + try: + while True: + connection, client_address = server.accept() + print(f"Client connected") + + try: + # Receiving fd from client + fds = array.array("i") + msg, ancdata, flags, addr = connection.recvmsg(1, socket.CMSG_LEN(struct.calcsize('i'))) + for cmsg_level, cmsg_type, cmsg_data in ancdata: + if cmsg_level == socket.SOL_SOCKET and cmsg_type == socket.SCM_RIGHTS: + fds.frombytes(cmsg_data[:len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) + + fd = fds[0] + + # We make a call to setsockopt(2) here, so client can verify we have received the fd + # In the real scenario, the server would set things like SO_MARK, + # we use SO_RCVBUF as it doesn't require any special capabilities. + nbytes = struct.pack("i", 2500) + fdsocket = fd_to_socket(fd) + fdsocket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, nbytes) + fdsocket.close() + + # The only protocol-like thing specified in the client implementation. + connection.send(b'\x01') + finally: + connection.close() + print("Connection closed") + + except KeyboardInterrupt: + print("Exit") + + finally: + server.close() + os.unlink(path) + + +def fd_to_socket(fd): + return socket.fromfd(fd, socket.AF_UNIX, socket.SOCK_STREAM) + + +if __name__ == "__main__": + if len(sys.argv) < 2: + raise ValueError("unix socket path is required") + + serve(sys.argv[1]) diff --git a/app/internal/sockopts/sockopts.go b/app/internal/sockopts/sockopts.go new file mode 100644 index 0000000..14ee0c0 --- /dev/null +++ b/app/internal/sockopts/sockopts.go @@ -0,0 +1,76 @@ +package sockopts + +import ( + "fmt" + "net" +) + +type SocketOptions struct { + BindInterface *string + FirewallMark *uint32 + FdControlUnixSocket *string +} + +// implemented in platform-specific files +var ( + bindInterfaceFunc func(c *net.UDPConn, device string) error + firewallMarkFunc func(c *net.UDPConn, fwmark uint32) error + fdControlUnixSocketFunc func(c *net.UDPConn, path string) error +) + +func (o *SocketOptions) CheckSupported() (err error) { + if o.BindInterface != nil && bindInterfaceFunc == nil { + return &UnsupportedError{"bindInterface"} + } + if o.FirewallMark != nil && firewallMarkFunc == nil { + return &UnsupportedError{"fwmark"} + } + if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc == nil { + return &UnsupportedError{"fdControlUnixSocket"} + } + return nil +} + +type UnsupportedError struct { + Field string +} + +func (e *UnsupportedError) Error() string { + return fmt.Sprintf("%s is not supported on this platform", e.Field) +} + +func (o *SocketOptions) ListenUDP() (uconn net.PacketConn, err error) { + uconn, err = net.ListenUDP("udp", nil) + if err != nil { + return + } + err = o.applyToUDPConn(uconn.(*net.UDPConn)) + if err != nil { + uconn.Close() + uconn = nil + return + } + return +} + +func (o *SocketOptions) applyToUDPConn(c *net.UDPConn) error { + if o.BindInterface != nil && bindInterfaceFunc != nil { + err := bindInterfaceFunc(c, *o.BindInterface) + if err != nil { + return fmt.Errorf("failed to bind to interface: %w", err) + } + } + if o.FirewallMark != nil && firewallMarkFunc != nil { + err := firewallMarkFunc(c, *o.FirewallMark) + if err != nil { + return fmt.Errorf("failed to set fwmark: %w", err) + } + } + if o.FdControlUnixSocket != nil && fdControlUnixSocketFunc != nil { + err := fdControlUnixSocketFunc(c, *o.FdControlUnixSocket) + if err != nil { + return fmt.Errorf("failed to send fd to control unix socket: %w", err) + } + } + return nil +} diff --git a/app/internal/sockopts/sockopts_linux.go b/app/internal/sockopts/sockopts_linux.go new file mode 100644 index 0000000..d1e5d23 --- /dev/null +++ b/app/internal/sockopts/sockopts_linux.go @@ -0,0 +1,96 @@ +//go:build linux + +package sockopts + +import ( + "fmt" + "net" + "time" + + "golang.org/x/exp/constraints" + "golang.org/x/sys/unix" +) + +const ( + fdControlUnixTimeout = 3 * time.Second +) + +func init() { + bindInterfaceFunc = bindInterfaceImpl + firewallMarkFunc = firewallMarkImpl + fdControlUnixSocketFunc = fdControlUnixSocketImpl +} + +func controlUDPConn(c *net.UDPConn, cb func(fd int) error) (err error) { + rconn, err := c.SyscallConn() + if err != nil { + return + } + cerr := rconn.Control(func(fd uintptr) { + err = cb(int(fd)) + }) + if err != nil { + return + } + if cerr != nil { + err = fmt.Errorf("failed to control fd: %w", cerr) + return + } + return +} + +func bindInterfaceImpl(c *net.UDPConn, device string) error { + return controlUDPConn(c, func(fd int) error { + return unix.BindToDevice(fd, device) + }) +} + +func firewallMarkImpl(c *net.UDPConn, fwmark uint32) error { + return controlUDPConn(c, func(fd int) error { + return unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_MARK, int(fwmark)) + }) +} + +func fdControlUnixSocketImpl(c *net.UDPConn, path string) error { + return controlUDPConn(c, func(fd int) error { + socketFd, err := unix.Socket(unix.AF_UNIX, unix.SOCK_STREAM, 0) + if err != nil { + return fmt.Errorf("failed to create unix socket: %w", err) + } + defer unix.Close(socketFd) + + var timeout unix.Timeval + timeUsec := fdControlUnixTimeout.Microseconds() + castAssignInteger(timeUsec/1e6, &timeout.Sec) + // Specifying the type explicitly is not necessary here, but it makes GoLand happy. + castAssignInteger[int64](timeUsec%1e6, &timeout.Usec) + + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &timeout) + _ = unix.SetsockoptTimeval(socketFd, unix.SOL_SOCKET, unix.SO_SNDTIMEO, &timeout) + + err = unix.Connect(socketFd, &unix.SockaddrUnix{Name: path}) + if err != nil { + return fmt.Errorf("failed to connect: %w", err) + } + + err = unix.Sendmsg(socketFd, nil, unix.UnixRights(fd), nil, 0) + if err != nil { + return fmt.Errorf("failed to send: %w", err) + } + + dummy := []byte{1} + n, err := unix.Read(socketFd, dummy) + if err != nil { + return fmt.Errorf("failed to receive: %w", err) + } + if n != 1 { + return fmt.Errorf("socket closed unexpectedly") + } + + return nil + }) +} + +func castAssignInteger[F, T constraints.Integer](from F, to *T) { + *to = T(from) +} diff --git a/app/internal/sockopts/sockopts_linux_test.go b/app/internal/sockopts/sockopts_linux_test.go new file mode 100644 index 0000000..66614a4 --- /dev/null +++ b/app/internal/sockopts/sockopts_linux_test.go @@ -0,0 +1,53 @@ +//go:build linux + +package sockopts + +import ( + "net" + "os" + "os/exec" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sys/unix" +) + +func Test_fdControlUnixSocketImpl(t *testing.T) { + sockPath := "./fd_control_unix_socket_test.sock" + defer os.Remove(sockPath) + + // Run test server + cmd := exec.Command("python", "fd_control_unix_socket_test.py", sockPath) + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + err := cmd.Start() + if !assert.NoError(t, err) { + return + } + defer cmd.Process.Kill() + + // Wait for the server to start + time.Sleep(1 * time.Second) + + so := SocketOptions{ + FdControlUnixSocket: &sockPath, + } + conn, err := so.ListenUDP() + if !assert.NoError(t, err) { + return + } + defer conn.Close() + + err = controlUDPConn(conn.(*net.UDPConn), func(fd int) (err error) { + rcvbuf, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF) + if err != nil { + return + } + // The test server called setsockopt(fd, SOL_SOCKET, SO_RCVBUF, 2500), + // and kernel will double this value for getsockopt(). + assert.Equal(t, 5000, rcvbuf) + return + }) + assert.NoError(t, err) +} diff --git a/extras/transport/udphop/conn.go b/extras/transport/udphop/conn.go index f20c583..32cc31c 100644 --- a/extras/transport/udphop/conn.go +++ b/extras/transport/udphop/conn.go @@ -44,7 +44,7 @@ type udpPacket struct { Err error } -type ListenUDPFunc func() (net.PacketConn, error) +type ListenUDPFunc = func() (net.PacketConn, error) func NewUDPHopPacketConn(addr *UDPHopAddr, hopInterval time.Duration, listenUDPFunc ListenUDPFunc) (net.PacketConn, error) { if hopInterval == 0 {