diff --git a/README.md b/README.md index 552e6c0..4bdbacd 100644 --- a/README.md +++ b/README.md @@ -155,6 +155,7 @@ encryption. If you need a proxy, just use our proxy modes. ```json5 { "listen": ":36712", // Listen address + "protocol": "faketcp", // Blank or "udp" for UDP mode, "faketcp" for TCP "masquerade", see below for details "acme": { "domains": [ "your.domain.com", @@ -185,7 +186,8 @@ encryption. If you need a proxy, just use our proxy modes. "recv_window_client": 67108864, // QUIC connection receive window "max_conn_client": 4096, // Max concurrent connections per client "disable_mtu_discovery": false, // Disable Path MTU Discovery (RFC 8899) - "ipv6_only": false // Only resolve domains to IPv6 address + "ipv6_only": false, // Only resolve domains to IPv6 address + "resolver": "1.1.1.1" // DNS resolver address } ``` @@ -255,6 +257,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 ```json5 { "server": "example.com:36712", // Server address + "protocol": "faketcp", // Blank or "udp" for UDP mode, "faketcp" for TCP "masquerade", see below for details "up_mbps": 10, // Max upload Mbps "down_mbps": 50, // Max download Mbps "socks5": { @@ -323,10 +326,22 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 "ca": "my.ca", // Custom CA file "recv_window_conn": 15728640, // QUIC stream receive window "recv_window": 67108864, // QUIC connection receive window - "disable_mtu_discovery": false // Disable Path MTU Discovery (RFC 8899) + "disable_mtu_discovery": false, // Disable Path MTU Discovery (RFC 8899) + "resolver": "1.1.1.1" // DNS resolver address } ``` +#### Fake TCP / TCP masquerade + +Certain networks may impose various restrictions on UDP traffic or block it altogether. Hysteria offers a "faketcp" mode +that allows servers and clients to communicate using a protocol that looks like TCP but does not actually go through the +system TCP stack. This tricks whatever middleboxes into thinking it's actually TCP traffic, rendering UDP-specific +restrictions useless. + +This mode is currently only supported on Linux (both client and server) and requires root privileges. + +If your server is behind a firewall, open the corresponding TCP port instead of UDP. + #### Transparent proxy TPROXY modes (`tproxy_tcp` & `tproxy_udp`) are only available on Linux. diff --git a/README.zh.md b/README.zh.md index 89706c7..5c6896e 100644 --- a/README.zh.md +++ b/README.zh.md @@ -141,6 +141,7 @@ Hysteria 是一个功能丰富的,专为恶劣网络环境进行优化的网 ```json5 { "listen": ":36712", // 监听地址 + "protocol": "faketcp", // 留空或 "udp" 为 UDP 模式,"faketcp" 为伪装 TCP 模式,详情见下 "acme": { "domains": [ "your.domain.com", @@ -171,7 +172,8 @@ Hysteria 是一个功能丰富的,专为恶劣网络环境进行优化的网 "recv_window_client": 67108864, // QUIC connection receive window "max_conn_client": 4096, // 单客户端最大活跃连接数 "disable_mtu_discovery": false, // 禁用 MTU 探测 (RFC 8899) - "ipv6_only": false // 强制把域名解析成 IPv6 地址 + "ipv6_only": false, // 强制把域名解析成 IPv6 地址 + "resolver": "1.1.1.1" // DNS 地址 } ``` @@ -240,6 +242,7 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 ```json5 { "server": "example.com:36712", // 服务器地址 + "protocol": "faketcp", // 留空或 "udp" 为 UDP 模式,"faketcp" 为伪装 TCP 模式,详情见下 "up_mbps": 10, // 最大上传速度 "down_mbps": 50, // 最大下载速度 "socks5": { @@ -308,10 +311,20 @@ hysteria_traffic_uplink_bytes_total{auth="aGFja2VyISE="} 37452 "ca": "my.ca", // 自定义 CA "recv_window_conn": 15728640, // QUIC stream receive window "recv_window": 67108864, // QUIC connection receive window - "disable_mtu_discovery": false // 禁用 MTU 探测 (RFC 8899) + "disable_mtu_discovery": false, // 禁用 MTU 探测 (RFC 8899) + "resolver": "1.1.1.1" // DNS 地址 } ``` +#### 伪装 TCP (faketcp 模式) + +某些网络可能对 UDP 流量施加各种限制,或者完全屏蔽。Hysteria 提供了一个 "faketcp" 模式,让服务端与客户端之间用看起来是 TCP 但实际不走 +系统 TCP 栈的方式通信。通过这种方式可以让防火墙、QoS 设备认为这是真的 TCP 连接,绕过对 UDP 的限制。 + +目前只在 Linux 上支持(客户端和服务器都是),并且需要 root 权限。 + +如果你的服务器有防火墙,请放行相应的 TCP 端口而不是 UDP。 + #### 透明代理 TPROXY 模式 (`tproxy_tcp` 和 `tproxy_udp`) 只在 Linux 下可用。 diff --git a/cmd/client.go b/cmd/client.go index dceef3d..abf5d14 100644 --- a/cmd/client.go +++ b/cmd/client.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "crypto/x509" "io" @@ -27,6 +28,16 @@ import ( func client(config *clientConfig) { logrus.WithField("config", config.String()).Info("Client configuration loaded") + // Resolver + if len(config.Resolver) > 0 { + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", config.Resolver) + }, + } + } // TLS tlsConfig := &tls.Config{ ServerName: config.ServerName, @@ -98,8 +109,8 @@ func client(config *clientConfig) { } } // Client - client, err := core.NewClient(config.Server, auth, tlsConfig, quicConfig, transport.DefaultTransport, - uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, + client, err := core.NewClient(config.Server, config.Protocol, auth, tlsConfig, quicConfig, + transport.DefaultTransport, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) }, obfuscator) diff --git a/cmd/config.go b/cmd/config.go index 4b4902f..24b538b 100644 --- a/cmd/config.go +++ b/cmd/config.go @@ -19,8 +19,9 @@ const ( ) type serverConfig struct { - Listen string `json:"listen"` - ACME struct { + Listen string `json:"listen"` + Protocol string `json:"protocol"` + ACME struct { Domains []string `json:"domains"` Email string `json:"email"` DisableHTTPChallenge bool `json:"disable_http"` @@ -47,6 +48,7 @@ type serverConfig struct { MaxConnClient int `json:"max_conn_client"` DisableMTUDiscovery bool `json:"disable_mtu_discovery"` IPv6Only bool `json:"ipv6_only"` + Resolver string `json:"resolver"` } func (c *serverConfig) Check() error { @@ -94,6 +96,7 @@ func (r *Relay) Check() error { type clientConfig struct { Server string `json:"server"` + Protocol string `json:"protocol"` UpMbps int `json:"up_mbps"` DownMbps int `json:"down_mbps"` // Optional below @@ -144,6 +147,7 @@ type clientConfig struct { ReceiveWindowConn uint64 `json:"recv_window_conn"` ReceiveWindow uint64 `json:"recv_window"` DisableMTUDiscovery bool `json:"disable_mtu_discovery"` + Resolver string `json:"resolver"` } func (c *clientConfig) Check() error { diff --git a/cmd/server.go b/cmd/server.go index f5d7a44..a441f97 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -1,6 +1,7 @@ package main import ( + "context" "crypto/tls" "github.com/lucas-clemente/quic-go" "github.com/lucas-clemente/quic-go/congestion" @@ -22,6 +23,16 @@ import ( func server(config *serverConfig) { logrus.WithField("config", config.String()).Info("Server configuration loaded") + // Resolver + if len(config.Resolver) > 0 { + net.DefaultResolver = &net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + d := net.Dialer{} + return d.DialContext(ctx, "udp", config.Resolver) + }, + } + } // Load TLS config var tlsConfig *tls.Config if len(config.ACME.Domains) > 0 { @@ -154,7 +165,7 @@ func server(config *serverConfig) { logrus.WithField("error", err).Fatal("Prometheus HTTP server error") }() } - server, err := core.NewServer(config.Listen, tlsConfig, quicConfig, transport.DefaultTransport, + server, err := core.NewServer(config.Listen, config.Protocol, tlsConfig, quicConfig, transport.DefaultTransport, uint64(config.UpMbps)*mbpsToBps, uint64(config.DownMbps)*mbpsToBps, func(refBPS uint64) congestion.CongestionControl { return hyCongestion.NewBrutalSender(congestion.ByteCount(refBPS)) diff --git a/go.mod b/go.mod index 6d08392..f25c7de 100644 --- a/go.mod +++ b/go.mod @@ -6,10 +6,12 @@ require ( github.com/LiamHaworth/go-tproxy v0.0.0-20190726054950-ef7efd7f24ed github.com/antonfisher/nested-logrus-formatter v1.3.1 github.com/caddyserver/certmagic v0.15.2 + github.com/coreos/go-iptables v0.6.0 github.com/cpuguy83/go-md2man/v2 v2.0.0 // indirect github.com/elazarl/goproxy v0.0.0-20210110162100-a92cc753f88e github.com/elazarl/goproxy/ext v0.0.0-20210110162100-a92cc753f88e github.com/eycorsican/go-tun2socks v1.16.11 + github.com/google/gopacket v1.1.19 github.com/hashicorp/golang-lru v0.5.4 github.com/lucas-clemente/quic-go v0.22.0 github.com/lunixbochs/struc v0.0.0-20200707160740-784aaebc1d40 diff --git a/go.sum b/go.sum index 7b677a3..483a4b8 100644 --- a/go.sum +++ b/go.sum @@ -31,6 +31,8 @@ github.com/cespare/xxhash/v2 v2.1.1/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XL github.com/cheekybits/genny v1.0.0 h1:uGGa4nei+j20rOSeDeP5Of12XVm7TGUd4dJA9RDitfE= github.com/cheekybits/genny v1.0.0/go.mod h1:+tQajlRqAUrPI7DOSpB0XAqZYtQakVtB7wXkRAgjxjQ= github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= +github.com/coreos/go-iptables v0.6.0 h1:is9qnZMPYjLd8LYqmm/qlE+wwEgJIkTYdhV3rfZo4jk= +github.com/coreos/go-iptables v0.6.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= github.com/coreos/go-systemd v0.0.0-20181012123002-c6f51f82210d/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/cpuguy83/go-md2man/v2 v2.0.0 h1:EoUDS0afbrsXAZ9YQ9jdu/mZ2sXgT1/2yyNng4PGlyM= @@ -95,6 +97,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.0.0/go.mod h1:odCYkC5MyYFN7vkCjXpyrEuKhc/BUO6wN/zVPAxq5ck= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= +github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= @@ -291,7 +295,9 @@ golang.org/x/lint v0.0.0-20180702182130-06c8688daad7/go.mod h1:UVdnD1Gm6xHRNCYTk golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/lint v0.0.0-20200302205851-738671d3881b/go.mod h1:3xt1FjdF8hUf6vQPIChWIBhFzV8gjjsPE/fR3IyQdNY= golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -381,6 +387,7 @@ golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgw golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200130002326-2f3ba24bd6e7/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/tools v0.0.0-20201224043029-2b0845dc783e/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.1.1 h1:wGiQel/hW0NnEkJUk8lbzkX2gFJU6PFxf1v5OlCfuOs= golang.org/x/tools v0.1.1/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= diff --git a/pkg/core/client.go b/pkg/core/client.go index 6e235b4..70ae934 100644 --- a/pkg/core/client.go +++ b/pkg/core/client.go @@ -26,6 +26,7 @@ type CongestionFactory func(refBPS uint64) congestion.CongestionControl type Client struct { transport transport2.Transport serverAddr string + protocol string sendBPS, recvBPS uint64 auth []byte congestionFactory CongestionFactory @@ -42,11 +43,13 @@ type Client struct { udpSessionMap map[uint32]chan *udpMessage } -func NewClient(serverAddr string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, transport transport2.Transport, - sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, obfuscator Obfuscator) (*Client, error) { +func NewClient(serverAddr string, protocol string, auth []byte, tlsConfig *tls.Config, quicConfig *quic.Config, + transport transport2.Transport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, + obfuscator Obfuscator) (*Client, error) { c := &Client{ transport: transport, serverAddr: serverAddr, + protocol: protocol, sendBPS: sendBPS, recvBPS: recvBPS, auth: auth, @@ -66,27 +69,40 @@ func (c *Client) connectToServer() error { if err != nil { return err } - udpConn, err := c.transport.QUICListenUDP(nil) - if err != nil { - return err - } - var qs quic.Session - if c.obfuscator != nil { - // Wrap PacketConn with obfuscator - qs, err = quic.Dial(&obfsUDPConn{ - Orig: udpConn, - Obfuscator: c.obfuscator, - }, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) + var pktConn net.PacketConn + if len(c.protocol) == 0 || c.protocol == "udp" { + udpConn, err := c.transport.QUICListenUDP(nil) if err != nil { - _ = udpConn.Close() return err } + if c.obfuscator != nil { + pktConn = &obfsUDPConn{ + Orig: udpConn, + Obfuscator: c.obfuscator, + } + } else { + pktConn = udpConn + } + } else if c.protocol == "faketcp" { + ftcpConn, err := c.transport.QUICDialFakeTCP(c.serverAddr) + if err != nil { + return err + } + if c.obfuscator != nil { + pktConn = &obfsPacketConn{ + Orig: ftcpConn, + Obfuscator: c.obfuscator, + } + } else { + pktConn = ftcpConn + } } else { - qs, err = quic.Dial(udpConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) - if err != nil { - _ = udpConn.Close() - return err - } + return fmt.Errorf("unsupported protocol: %s", c.protocol) + } + qs, err := quic.Dial(pktConn, serverUDPAddr, c.serverAddr, c.tlsConfig, c.quicConfig) + if err != nil { + _ = pktConn.Close() + return err } // Control stream ctx, ctxCancel := context.WithTimeout(context.Background(), protocolTimeout) diff --git a/pkg/core/obfs.go b/pkg/core/obfs.go index 45523c3..419a31a 100644 --- a/pkg/core/obfs.go +++ b/pkg/core/obfs.go @@ -17,10 +17,6 @@ type obfsUDPConn struct { Obfuscator Obfuscator } -func (c *obfsUDPConn) SyscallConn() (syscall.RawConn, error) { - return c.Orig.SyscallConn() -} - func (c *obfsUDPConn) ReadFrom(p []byte) (int, net.Addr, error) { buf := make([]byte, udpBufferSize) for { @@ -77,6 +73,63 @@ func (c *obfsUDPConn) SetWriteBuffer(bytes int) error { return c.Orig.SetWriteBuffer(bytes) } +func (c *obfsUDPConn) SyscallConn() (syscall.RawConn, error) { + return c.Orig.SyscallConn() +} + func (c *obfsUDPConn) File() (f *os.File, err error) { return c.Orig.File() } + +type obfsPacketConn struct { + Orig net.PacketConn + Obfuscator Obfuscator +} + +func (c *obfsPacketConn) ReadFrom(p []byte) (int, net.Addr, error) { + buf := make([]byte, udpBufferSize) + for { + n, addr, err := c.Orig.ReadFrom(buf) + if n <= 0 { + return 0, addr, err + } + newN := c.Obfuscator.Deobfuscate(buf[:n], p) + if newN > 0 { + // Valid packet + return newN, addr, err + } else if err != nil { + // Not valid and Orig.ReadFrom had some error + return 0, addr, err + } + } +} + +func (c *obfsPacketConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + np := c.Obfuscator.Obfuscate(p) + _, err = c.Orig.WriteTo(np, addr) + if err != nil { + return 0, err + } else { + return len(p), nil + } +} + +func (c *obfsPacketConn) Close() error { + return c.Orig.Close() +} + +func (c *obfsPacketConn) LocalAddr() net.Addr { + return c.Orig.LocalAddr() +} + +func (c *obfsPacketConn) SetDeadline(t time.Time) error { + return c.Orig.SetDeadline(t) +} + +func (c *obfsPacketConn) SetReadDeadline(t time.Time) error { + return c.Orig.SetReadDeadline(t) +} + +func (c *obfsPacketConn) SetWriteDeadline(t time.Time) error { + return c.Orig.SetWriteDeadline(t) +} diff --git a/pkg/core/server.go b/pkg/core/server.go index 089edda..79a1cb8 100644 --- a/pkg/core/server.go +++ b/pkg/core/server.go @@ -38,35 +38,48 @@ type Server struct { listener quic.Listener } -func NewServer(addr string, tlsConfig *tls.Config, quicConfig *quic.Config, transport transport2.Transport, +func NewServer(addr string, protocol string, tlsConfig *tls.Config, quicConfig *quic.Config, transport transport2.Transport, sendBPS uint64, recvBPS uint64, congestionFactory CongestionFactory, disableUDP bool, aclEngine *acl.Engine, obfuscator Obfuscator, authFunc AuthFunc, tcpRequestFunc TCPRequestFunc, tcpErrorFunc TCPErrorFunc, udpRequestFunc UDPRequestFunc, udpErrorFunc UDPErrorFunc, promRegistry *prometheus.Registry) (*Server, error) { - udpAddr, err := transport.QUICResolveUDPAddr(addr) - if err != nil { - return nil, err - } - udpConn, err := transport.QUICListenUDP(udpAddr) - if err != nil { - return nil, err - } - var listener quic.Listener - if obfuscator != nil { - // Wrap PacketConn with obfuscator - listener, err = quic.Listen(&obfsUDPConn{ - Orig: udpConn, - Obfuscator: obfuscator, - }, tlsConfig, quicConfig) + var pktConn net.PacketConn + if len(protocol) == 0 || protocol == "udp" { + udpAddr, err := transport.QUICResolveUDPAddr(addr) if err != nil { - _ = udpConn.Close() return nil, err } + udpConn, err := transport.QUICListenUDP(udpAddr) + if err != nil { + return nil, err + } + if obfuscator != nil { + pktConn = &obfsUDPConn{ + Orig: udpConn, + Obfuscator: obfuscator, + } + } else { + pktConn = udpConn + } + } else if protocol == "faketcp" { + ftcpConn, err := transport.QUICListenFakeTCP(addr) + if err != nil { + return nil, err + } + if obfuscator != nil { + pktConn = &obfsPacketConn{ + Orig: ftcpConn, + Obfuscator: obfuscator, + } + } else { + pktConn = ftcpConn + } } else { - listener, err = quic.Listen(udpConn, tlsConfig, quicConfig) - if err != nil { - _ = udpConn.Close() - return nil, err - } + return nil, fmt.Errorf("unsupported protocol: %s", protocol) + } + listener, err := quic.Listen(pktConn, tlsConfig, quicConfig) + if err != nil { + _ = pktConn.Close() + return nil, err } s := &Server{ listener: listener, diff --git a/pkg/faketcp/LICENSE b/pkg/faketcp/LICENSE new file mode 100644 index 0000000..79fbecb --- /dev/null +++ b/pkg/faketcp/LICENSE @@ -0,0 +1 @@ +Grabbed from https://github.com/xtaci/tcpraw with modifications \ No newline at end of file diff --git a/pkg/faketcp/tcp_linux.go b/pkg/faketcp/tcp_linux.go new file mode 100644 index 0000000..e16b478 --- /dev/null +++ b/pkg/faketcp/tcp_linux.go @@ -0,0 +1,608 @@ +//go:build linux +// +build linux + +package faketcp + +import ( + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "sync" + "sync/atomic" + "syscall" + "time" + + "github.com/coreos/go-iptables/iptables" + "github.com/google/gopacket" + "github.com/google/gopacket/layers" +) + +var ( + errOpNotImplemented = errors.New("operation not implemented") + errTimeout = errors.New("timeout") + expire = time.Minute +) + +// a message from NIC +type message struct { + bts []byte + addr net.Addr +} + +// a tcp flow information of a connection pair +type tcpFlow struct { + conn *net.TCPConn // the related system TCP connection of this flow + handle *net.IPConn // the handle to send packets + seq uint32 // TCP sequence number + ack uint32 // TCP acknowledge number + networkLayer gopacket.SerializableLayer // network layer header for tx + ts time.Time // last packet incoming time + buf gopacket.SerializeBuffer // a buffer for write + tcpHeader layers.TCP +} + +// TCPConn defines a TCP-packet oriented connection +type TCPConn struct { + die chan struct{} + dieOnce sync.Once + + // the main golang sockets + tcpconn *net.TCPConn // from net.Dial + listener *net.TCPListener // from net.Listen + + // handles + handles []*net.IPConn + + // packets captured from all related NICs will be delivered to this channel + chMessage chan message + + // all TCP flows + flowTable map[string]*tcpFlow + flowsLock sync.Mutex + + // iptables + iptables *iptables.IPTables + iprule []string + + ip6tables *iptables.IPTables + ip6rule []string + + // deadlines + readDeadline atomic.Value + writeDeadline atomic.Value + + // serialization + opts gopacket.SerializeOptions +} + +// lockflow locks the flow table and apply function `f` to the entry, and create one if not exist +func (conn *TCPConn) lockflow(addr net.Addr, f func(e *tcpFlow)) { + key := addr.String() + conn.flowsLock.Lock() + e := conn.flowTable[key] + if e == nil { // entry first visit + e = new(tcpFlow) + e.ts = time.Now() + e.buf = gopacket.NewSerializeBuffer() + } + f(e) + conn.flowTable[key] = e + conn.flowsLock.Unlock() +} + +// clean expired flows +func (conn *TCPConn) cleaner() { + ticker := time.NewTicker(time.Minute) + select { + case <-conn.die: + return + case <-ticker.C: + conn.flowsLock.Lock() + for k, v := range conn.flowTable { + if time.Now().Sub(v.ts) > expire { + if v.conn != nil { + setTTL(v.conn, 64) + v.conn.Close() + } + delete(conn.flowTable, k) + } + } + conn.flowsLock.Unlock() + } +} + +// captureFlow capture every inbound packets based on rules of BPF +func (conn *TCPConn) captureFlow(handle *net.IPConn, port int) { + buf := make([]byte, 2048) + opt := gopacket.DecodeOptions{NoCopy: true, Lazy: true} + for { + n, addr, err := handle.ReadFromIP(buf) + if err != nil { + return + } + + // try decoding TCP frame from buf[:n] + packet := gopacket.NewPacket(buf[:n], layers.LayerTypeTCP, opt) + transport := packet.TransportLayer() + tcp, ok := transport.(*layers.TCP) + if !ok { + continue + } + + // port filtering + if int(tcp.DstPort) != port { + continue + } + + // address building + var src net.TCPAddr + src.IP = addr.IP + src.Port = int(tcp.SrcPort) + + var orphan bool + // flow maintaince + conn.lockflow(&src, func(e *tcpFlow) { + if e.conn == nil { // make sure it's related to net.TCPConn + orphan = true // mark as orphan if it's not related net.TCPConn + } + + // to keep track of TCP header related to this source + e.ts = time.Now() + if tcp.ACK { + e.seq = tcp.Ack + } + if tcp.SYN { + e.ack = tcp.Seq + 1 + } + if tcp.PSH { + if e.ack == tcp.Seq { + e.ack = tcp.Seq + uint32(len(tcp.Payload)) + } + } + e.handle = handle + }) + + // push data if it's not orphan + if !orphan && tcp.PSH { + payload := make([]byte, len(tcp.Payload)) + copy(payload, tcp.Payload) + select { + case conn.chMessage <- message{payload, &src}: + case <-conn.die: + return + } + } + } +} + +// ReadFrom implements the PacketConn ReadFrom method. +func (conn *TCPConn) ReadFrom(p []byte) (n int, addr net.Addr, err error) { + var timer *time.Timer + var deadline <-chan time.Time + if d, ok := conn.readDeadline.Load().(time.Time); ok && !d.IsZero() { + timer = time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case <-deadline: + return 0, nil, errTimeout + case <-conn.die: + return 0, nil, io.EOF + case packet := <-conn.chMessage: + n = copy(p, packet.bts) + return n, packet.addr, nil + } +} + +// WriteTo implements the PacketConn WriteTo method. +func (conn *TCPConn) WriteTo(p []byte, addr net.Addr) (n int, err error) { + var deadline <-chan time.Time + if d, ok := conn.writeDeadline.Load().(time.Time); ok && !d.IsZero() { + timer := time.NewTimer(time.Until(d)) + defer timer.Stop() + deadline = timer.C + } + + select { + case <-deadline: + return 0, errTimeout + case <-conn.die: + return 0, io.EOF + default: + raddr, err := net.ResolveTCPAddr("tcp", addr.String()) + if err != nil { + return 0, err + } + + var lport int + if conn.tcpconn != nil { + lport = conn.tcpconn.LocalAddr().(*net.TCPAddr).Port + } else { + lport = conn.listener.Addr().(*net.TCPAddr).Port + } + + conn.lockflow(addr, func(e *tcpFlow) { + // if the flow doesn't have handle , assume this packet has lost, without notification + if e.handle == nil { + n = len(p) + return + } + + // build tcp header with local and remote port + e.tcpHeader.SrcPort = layers.TCPPort(lport) + e.tcpHeader.DstPort = layers.TCPPort(raddr.Port) + binary.Read(rand.Reader, binary.LittleEndian, &e.tcpHeader.Window) + e.tcpHeader.Window |= 0x8000 // make sure it's larger than 32768 + e.tcpHeader.Ack = e.ack + e.tcpHeader.Seq = e.seq + e.tcpHeader.PSH = true + e.tcpHeader.ACK = true + + // build IP header with src & dst ip for TCP checksum + if raddr.IP.To4() != nil { + ip := &layers.IPv4{ + Protocol: layers.IPProtocolTCP, + SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To4(), + DstIP: raddr.IP.To4(), + } + e.tcpHeader.SetNetworkLayerForChecksum(ip) + } else { + ip := &layers.IPv6{ + NextHeader: layers.IPProtocolTCP, + SrcIP: e.handle.LocalAddr().(*net.IPAddr).IP.To16(), + DstIP: raddr.IP.To16(), + } + e.tcpHeader.SetNetworkLayerForChecksum(ip) + } + + e.buf.Clear() + gopacket.SerializeLayers(e.buf, conn.opts, &e.tcpHeader, gopacket.Payload(p)) + if conn.tcpconn != nil { + _, err = e.handle.Write(e.buf.Bytes()) + } else { + _, err = e.handle.WriteToIP(e.buf.Bytes(), &net.IPAddr{IP: raddr.IP}) + } + // increase seq in flow + e.seq += uint32(len(p)) + n = len(p) + }) + } + return +} + +// Close closes the connection. +func (conn *TCPConn) Close() error { + var err error + conn.dieOnce.Do(func() { + // signal closing + close(conn.die) + + // close all established tcp connections + if conn.tcpconn != nil { // client + setTTL(conn.tcpconn, 64) + err = conn.tcpconn.Close() + } else if conn.listener != nil { + err = conn.listener.Close() // server + conn.flowsLock.Lock() + for k, v := range conn.flowTable { + if v.conn != nil { + setTTL(v.conn, 64) + v.conn.Close() + } + delete(conn.flowTable, k) + } + conn.flowsLock.Unlock() + } + + // close handles + for k := range conn.handles { + conn.handles[k].Close() + } + + // delete iptable + if conn.iptables != nil { + conn.iptables.Delete("filter", "OUTPUT", conn.iprule...) + } + if conn.ip6tables != nil { + conn.ip6tables.Delete("filter", "OUTPUT", conn.ip6rule...) + } + }) + return err +} + +// LocalAddr returns the local network address. +func (conn *TCPConn) LocalAddr() net.Addr { + if conn.tcpconn != nil { + return conn.tcpconn.LocalAddr() + } else if conn.listener != nil { + return conn.listener.Addr() + } + return nil +} + +// SetDeadline implements the Conn SetDeadline method. +func (conn *TCPConn) SetDeadline(t time.Time) error { + if err := conn.SetReadDeadline(t); err != nil { + return err + } + if err := conn.SetWriteDeadline(t); err != nil { + return err + } + return nil +} + +// SetReadDeadline implements the Conn SetReadDeadline method. +func (conn *TCPConn) SetReadDeadline(t time.Time) error { + conn.readDeadline.Store(t) + return nil +} + +// SetWriteDeadline implements the Conn SetWriteDeadline method. +func (conn *TCPConn) SetWriteDeadline(t time.Time) error { + conn.writeDeadline.Store(t) + return nil +} + +// SetDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. +func (conn *TCPConn) SetDSCP(dscp int) error { + for k := range conn.handles { + if err := setDSCP(conn.handles[k], dscp); err != nil { + return err + } + } + return nil +} + +// SetReadBuffer sets the size of the operating system's receive buffer associated with the connection. +func (conn *TCPConn) SetReadBuffer(bytes int) error { + var err error + for k := range conn.handles { + if err := conn.handles[k].SetReadBuffer(bytes); err != nil { + return err + } + } + return err +} + +// SetWriteBuffer sets the size of the operating system's transmit buffer associated with the connection. +func (conn *TCPConn) SetWriteBuffer(bytes int) error { + var err error + for k := range conn.handles { + if err := conn.handles[k].SetWriteBuffer(bytes); err != nil { + return err + } + } + return err +} + +// Dial connects to the remote TCP port, +// and returns a single packet-oriented connection +func Dial(network, address string) (*TCPConn, error) { + // remote address resolve + raddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + // AF_INET + handle, err := net.DialIP("ip:tcp", nil, &net.IPAddr{IP: raddr.IP}) + if err != nil { + return nil, err + } + + // create an established tcp connection + // will hack this tcp connection for packet transmission + tcpconn, err := net.DialTCP(network, nil, raddr) + if err != nil { + return nil, err + } + + // fields + conn := new(TCPConn) + conn.die = make(chan struct{}) + conn.flowTable = make(map[string]*tcpFlow) + conn.tcpconn = tcpconn + conn.chMessage = make(chan message) + conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn }) + conn.handles = append(conn.handles, handle) + conn.opts = gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + go conn.captureFlow(handle, tcpconn.LocalAddr().(*net.TCPAddr).Port) + go conn.cleaner() + + // iptables + err = setTTL(tcpconn, 1) + if err != nil { + return nil, err + } + + if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil { + rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"} + if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { + if !exists { + if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { + conn.iprule = rule + conn.iptables = ipt + } + } + } + } + if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil { + rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "-d", raddr.IP.String(), "--dport", fmt.Sprint(raddr.Port), "-j", "DROP"} + if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { + if !exists { + if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { + conn.ip6rule = rule + conn.ip6tables = ipt + } + } + } + } + + // discard everything + go io.Copy(ioutil.Discard, tcpconn) + + return conn, nil +} + +// Listen acts like net.ListenTCP, +// and returns a single packet-oriented connection +func Listen(network, address string) (*TCPConn, error) { + // fields + conn := new(TCPConn) + conn.flowTable = make(map[string]*tcpFlow) + conn.die = make(chan struct{}) + conn.chMessage = make(chan message) + conn.opts = gopacket.SerializeOptions{ + FixLengths: true, + ComputeChecksums: true, + } + + // resolve address + laddr, err := net.ResolveTCPAddr(network, address) + if err != nil { + return nil, err + } + + // AF_INET + ifaces, err := net.Interfaces() + if err != nil { + return nil, err + } + + if laddr.IP == nil || laddr.IP.IsUnspecified() { // if address is not specified, capture on all ifaces + var lasterr error + for _, iface := range ifaces { + if addrs, err := iface.Addrs(); err == nil { + for _, addr := range addrs { + if ipaddr, ok := addr.(*net.IPNet); ok { + if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: ipaddr.IP}); err == nil { + conn.handles = append(conn.handles, handle) + go conn.captureFlow(handle, laddr.Port) + } else { + lasterr = err + } + } + } + } + } + if len(conn.handles) == 0 { + return nil, lasterr + } + } else { + if handle, err := net.ListenIP("ip:tcp", &net.IPAddr{IP: laddr.IP}); err == nil { + conn.handles = append(conn.handles, handle) + go conn.captureFlow(handle, laddr.Port) + } else { + return nil, err + } + } + + // start listening + l, err := net.ListenTCP(network, laddr) + if err != nil { + return nil, err + } + + conn.listener = l + + // start cleaner + go conn.cleaner() + + // iptables drop packets marked with TTL = 1 + // TODO: what if iptables is not available, the next hop will send back ICMP Time Exceeded, + // is this still an acceptable behavior? + if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv4); err == nil { + rule := []string{"-m", "ttl", "--ttl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"} + if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { + if !exists { + if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { + conn.iprule = rule + conn.iptables = ipt + } + } + } + } + if ipt, err := iptables.NewWithProtocol(iptables.ProtocolIPv6); err == nil { + rule := []string{"-m", "hl", "--hl-eq", "1", "-p", "tcp", "--sport", fmt.Sprint(laddr.Port), "-j", "DROP"} + if exists, err := ipt.Exists("filter", "OUTPUT", rule...); err == nil { + if !exists { + if err = ipt.Append("filter", "OUTPUT", rule...); err == nil { + conn.ip6rule = rule + conn.ip6tables = ipt + } + } + } + } + + // discard everything in original connection + go func() { + for { + tcpconn, err := l.AcceptTCP() + if err != nil { + return + } + + // if we cannot set TTL = 1, the only thing reasonable is panic + if err := setTTL(tcpconn, 1); err != nil { + panic(err) + } + + // record net.Conn + conn.lockflow(tcpconn.RemoteAddr(), func(e *tcpFlow) { e.conn = tcpconn }) + + // discard everything + go io.Copy(ioutil.Discard, tcpconn) + } + }() + + return conn, nil +} + +// setTTL sets the Time-To-Live field on a given connection +func setTTL(c *net.TCPConn, ttl int) error { + raw, err := c.SyscallConn() + if err != nil { + return err + } + addr := c.LocalAddr().(*net.TCPAddr) + + if addr.IP.To4() == nil { + raw.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_UNICAST_HOPS, ttl) + }) + } else { + raw.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TTL, ttl) + }) + } + return err +} + +// setDSCP sets the 6bit DSCP field in IPv4 header, or 8bit Traffic Class in IPv6 header. +func setDSCP(c *net.IPConn, dscp int) error { + raw, err := c.SyscallConn() + if err != nil { + return err + } + addr := c.LocalAddr().(*net.IPAddr) + + if addr.IP.To4() == nil { + raw.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IPV6, syscall.IPV6_TCLASS, dscp) + }) + } else { + raw.Control(func(fd uintptr) { + err = syscall.SetsockoptInt(int(fd), syscall.IPPROTO_IP, syscall.IP_TOS, dscp<<2) + }) + } + return err +} diff --git a/pkg/faketcp/tcp_stub.go b/pkg/faketcp/tcp_stub.go new file mode 100644 index 0000000..9bc5507 --- /dev/null +++ b/pkg/faketcp/tcp_stub.go @@ -0,0 +1,21 @@ +//go:build !linux +// +build !linux + +package faketcp + +import ( + "errors" + "net" +) + +type TCPConn struct{ *net.UDPConn } + +// Dial connects to the remote TCP port, +// and returns a single packet-oriented connection +func Dial(network, address string) (*TCPConn, error) { + return nil, errors.New("faketcp is not supported on this platform") +} + +func Listen(network, address string) (*TCPConn, error) { + return nil, errors.New("faketcp is not supported on this platform") +} diff --git a/pkg/faketcp/tcp_test.go b/pkg/faketcp/tcp_test.go new file mode 100644 index 0000000..ea26c68 --- /dev/null +++ b/pkg/faketcp/tcp_test.go @@ -0,0 +1,196 @@ +//go:build linux +// +build linux + +package faketcp + +import ( + "log" + "net" + "net/http" + _ "net/http/pprof" + "testing" +) + +//const testPortStream = "127.0.0.1:3456" +//const testPortPacket = "127.0.0.1:3457" + +const testPortStream = "127.0.0.1:3456" +const portServerPacket = "[::]:3457" +const portRemotePacket = "127.0.0.1:3457" + +func init() { + startTCPServer() + startTCPRawServer() + go func() { + log.Println(http.ListenAndServe("0.0.0.0:6060", nil)) + }() +} + +func startTCPServer() net.Listener { + l, err := net.Listen("tcp", testPortStream) + if err != nil { + log.Panicln(err) + } + + go func() { + defer l.Close() + for { + conn, err := l.Accept() + if err != nil { + log.Println(err) + return + } + + go handleRequest(conn) + } + }() + return l +} + +func startTCPRawServer() *TCPConn { + conn, err := Listen("tcp", portServerPacket) + if err != nil { + log.Panicln(err) + } + err = conn.SetReadBuffer(1024 * 1024) + if err != nil { + log.Println(err) + } + err = conn.SetWriteBuffer(1024 * 1024) + if err != nil { + log.Println(err) + } + + go func() { + defer conn.Close() + buf := make([]byte, 1024) + for { + n, addr, err := conn.ReadFrom(buf) + if err != nil { + log.Println("server readfrom:", err) + return + } + //echo + n, err = conn.WriteTo(buf[:n], addr) + if err != nil { + log.Println("server writeTo:", err) + return + } + } + }() + return conn +} + +func handleRequest(conn net.Conn) { + defer conn.Close() + + for { + buf := make([]byte, 1024) + size, err := conn.Read(buf) + if err != nil { + log.Println("handleRequest:", err) + return + } + data := buf[:size] + conn.Write(data) + } +} + +func TestDialTCPStream(t *testing.T) { + conn, err := Dial("tcp", testPortStream) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + addr, err := net.ResolveTCPAddr("tcp", testPortStream) + if err != nil { + t.Fatal(err) + } + + n, err := conn.WriteTo([]byte("abc"), addr) + if err != nil { + t.Fatal(n, err) + } + + buf := make([]byte, 1024) + if n, addr, err := conn.ReadFrom(buf); err != nil { + t.Fatal(n, addr, err) + } else { + log.Println(string(buf[:n]), "from:", addr) + } +} + +func TestDialToTCPPacket(t *testing.T) { + conn, err := Dial("tcp", portRemotePacket) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + + addr, err := net.ResolveTCPAddr("tcp", portRemotePacket) + if err != nil { + t.Fatal(err) + } + + n, err := conn.WriteTo([]byte("abc"), addr) + if err != nil { + t.Fatal(n, err) + } + log.Println("written") + + buf := make([]byte, 1024) + log.Println("readfrom buf") + if n, addr, err := conn.ReadFrom(buf); err != nil { + log.Println(err) + t.Fatal(n, addr, err) + } else { + log.Println(string(buf[:n]), "from:", addr) + } + + log.Println("complete") +} + +func TestSettings(t *testing.T) { + conn, err := Dial("tcp", portRemotePacket) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + if err := conn.SetDSCP(46); err != nil { + log.Fatal("SetDSCP:", err) + } + if err := conn.SetReadBuffer(4096); err != nil { + log.Fatal("SetReaderBuffer:", err) + } + if err := conn.SetWriteBuffer(4096); err != nil { + log.Fatal("SetWriteBuffer:", err) + } +} + +func BenchmarkEcho(b *testing.B) { + conn, err := Dial("tcp", portRemotePacket) + if err != nil { + b.Fatal(err) + } + defer conn.Close() + + addr, err := net.ResolveTCPAddr("tcp", portRemotePacket) + if err != nil { + b.Fatal(err) + } + + buf := make([]byte, 1024) + b.ReportAllocs() + b.SetBytes(int64(len(buf))) + for i := 0; i < b.N; i++ { + n, err := conn.WriteTo(buf, addr) + if err != nil { + b.Fatal(n, err) + } + + if n, addr, err := conn.ReadFrom(buf); err != nil { + b.Fatal(n, addr, err) + } + } +} diff --git a/pkg/transport/transport.go b/pkg/transport/transport.go index fc48a6f..aba39f6 100644 --- a/pkg/transport/transport.go +++ b/pkg/transport/transport.go @@ -1,6 +1,7 @@ package transport import ( + "github.com/tobyxdd/hysteria/pkg/faketcp" "net" "time" ) @@ -8,6 +9,8 @@ import ( type Transport interface { QUICResolveUDPAddr(address string) (*net.UDPAddr, error) QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, error) + QUICListenFakeTCP(address string) (*faketcp.TCPConn, error) + QUICDialFakeTCP(address string) (*faketcp.TCPConn, error) LocalResolveIPAddr(address string) (*net.IPAddr, error) LocalResolveTCPAddr(address string) (*net.TCPAddr, error) @@ -40,6 +43,14 @@ func (t *defaultTransport) QUICListenUDP(laddr *net.UDPAddr) (*net.UDPConn, erro return net.ListenUDP("udp", laddr) } +func (t *defaultTransport) QUICListenFakeTCP(address string) (*faketcp.TCPConn, error) { + return faketcp.Listen("tcp", address) +} + +func (t *defaultTransport) QUICDialFakeTCP(address string) (*faketcp.TCPConn, error) { + return faketcp.Dial("tcp", address) +} + func (t *defaultTransport) LocalResolveIPAddr(address string) (*net.IPAddr, error) { return net.ResolveIPAddr("ip", address) }