diff --git a/app/cmd/client.go b/app/cmd/client.go index e65bf38..2f127bb 100644 --- a/app/cmd/client.go +++ b/app/cmd/client.go @@ -43,36 +43,46 @@ func initClientFlags() { } type clientConfig struct { - Server string `mapstructure:"server"` - Auth string `mapstructure:"auth"` - Obfs struct { - Type string `mapstructure:"type"` - Salamander struct { - Password string `mapstructure:"password"` - } `mapstructure:"salamander"` - } `mapstructure:"obfs"` - TLS struct { - SNI string `mapstructure:"sni"` - Insecure bool `mapstructure:"insecure"` - CA string `mapstructure:"ca"` - } `mapstructure:"tls"` - QUIC struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` - } `mapstructure:"quic"` - Bandwidth struct { - Up string `mapstructure:"up"` - Down string `mapstructure:"down"` - } `mapstructure:"bandwidth"` - FastOpen bool `mapstructure:"fastOpen"` - SOCKS5 *socks5Config `mapstructure:"socks5"` - HTTP *httpConfig `mapstructure:"http"` - Forwarding []forwardingEntry `mapstructure:"forwarding"` + Server string `mapstructure:"server"` + Auth string `mapstructure:"auth"` + Obfs clientConfigObfs `mapstructure:"obfs"` + TLS clientConfigTLS `mapstructure:"tls"` + QUIC clientConfigQUIC `mapstructure:"quic"` + Bandwidth clientConfigBandwidth `mapstructure:"bandwidth"` + FastOpen bool `mapstructure:"fastOpen"` + SOCKS5 *socks5Config `mapstructure:"socks5"` + HTTP *httpConfig `mapstructure:"http"` + Forwarding []forwardingEntry `mapstructure:"forwarding"` +} + +type clientConfigObfsSalamander struct { + Password string `mapstructure:"password"` +} + +type clientConfigObfs struct { + Type string `mapstructure:"type"` + Salamander clientConfigObfsSalamander `mapstructure:"salamander"` +} + +type clientConfigTLS struct { + SNI string `mapstructure:"sni"` + Insecure bool `mapstructure:"insecure"` + CA string `mapstructure:"ca"` +} + +type clientConfigQUIC struct { + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` +} + +type clientConfigBandwidth struct { + Up string `mapstructure:"up"` + Down string `mapstructure:"down"` } type socks5Config struct { diff --git a/app/cmd/client_test.go b/app/cmd/client_test.go index f5ad852..5f111fa 100644 --- a/app/cmd/client_test.go +++ b/app/cmd/client_test.go @@ -1,10 +1,11 @@ package cmd import ( - "reflect" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/spf13/viper" ) @@ -12,47 +13,25 @@ import ( func TestClientConfig(t *testing.T) { viper.SetConfigFile("client_test.yaml") err := viper.ReadInConfig() - if err != nil { - t.Fatal("failed to read client config", err) - } + assert.NoError(t, err) var config clientConfig - if err := viper.Unmarshal(&config); err != nil { - t.Fatal("failed to parse client config", err) - } - if !reflect.DeepEqual(config, clientConfig{ + err = viper.Unmarshal(&config) + assert.NoError(t, err) + assert.Equal(t, config, clientConfig{ Server: "example.com", Auth: "weak_ahh_password", - Obfs: struct { - Type string `mapstructure:"type"` - Salamander struct { - Password string `mapstructure:"password"` - } `mapstructure:"salamander"` - }{ + Obfs: clientConfigObfs{ Type: "salamander", - Salamander: struct { - Password string `mapstructure:"password"` - }{ + Salamander: clientConfigObfsSalamander{ Password: "cry_me_a_r1ver", }, }, - TLS: struct { - SNI string `mapstructure:"sni"` - Insecure bool `mapstructure:"insecure"` - CA string `mapstructure:"ca"` - }{ + TLS: clientConfigTLS{ SNI: "another.example.com", Insecure: true, CA: "custom_ca.crt", }, - QUIC: struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - KeepAlivePeriod time.Duration `mapstructure:"keepAlivePeriod"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` - }{ + QUIC: clientConfigQUIC{ InitStreamReceiveWindow: 1145141, MaxStreamReceiveWindow: 1145142, InitConnectionReceiveWindow: 1145143, @@ -61,10 +40,7 @@ func TestClientConfig(t *testing.T) { KeepAlivePeriod: 4 * time.Second, DisablePathMTUDiscovery: true, }, - Bandwidth: struct { - Up string `mapstructure:"up"` - Down string `mapstructure:"down"` - }{ + Bandwidth: clientConfigBandwidth{ Up: "200 mbps", Down: "1 gbps", }, @@ -94,9 +70,7 @@ func TestClientConfig(t *testing.T) { UDPTimeout: 50 * time.Second, }, }, - }) { - t.Fatal("parsed client config is not equal to expected") - } + }) } // TestClientConfigURI tests URI-related functions of clientConfig @@ -120,24 +94,13 @@ func TestClientConfigURI(t *testing.T) { config: &clientConfig{ Server: "noauth.com", Auth: "", - Obfs: struct { - Type string `mapstructure:"type"` - Salamander struct { - Password string `mapstructure:"password"` - } `mapstructure:"salamander"` - }{ + Obfs: clientConfigObfs{ Type: "salamander", - Salamander: struct { - Password string `mapstructure:"password"` - }{ + Salamander: clientConfigObfsSalamander{ Password: "66ccff", }, }, - TLS: struct { - SNI string `mapstructure:"sni"` - Insecure bool `mapstructure:"insecure"` - CA string `mapstructure:"ca"` - }{ + TLS: clientConfigTLS{ SNI: "crap.cc", Insecure: true, }, @@ -158,19 +121,13 @@ func TestClientConfigURI(t *testing.T) { t.Run(test.uri, func(t *testing.T) { // Test parseURI nc := &clientConfig{Server: test.uri} - if ok := nc.parseURI(); ok != test.uriOK { - t.Fatal("unexpected parseURI ok result") - } - if test.uriOK && !reflect.DeepEqual(nc, test.config) { - t.Fatal("unexpected parsed client config from URI") + assert.Equal(t, nc.parseURI(), test.uriOK) + if test.uriOK { + assert.Equal(t, nc, test.config) } // Test URI generation - if test.config == nil { - // config is nil if parseURI is expected to fail - return - } - if uri := test.config.URI(); uri != test.uri { - t.Fatalf("generated URI mismatch: %s != %s", uri, test.uri) + if test.config != nil { + assert.Equal(t, test.config.URI(), test.uri) } }) } diff --git a/app/cmd/server.go b/app/cmd/server.go index b7ddff4..6617d26 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -32,43 +32,24 @@ func init() { } type serverConfig struct { - Listen string `mapstructure:"listen"` - Obfs struct { - Type string `mapstructure:"type"` - Salamander struct { - Password string `mapstructure:"password"` - } `mapstructure:"salamander"` - } `mapstructure:"obfs"` - TLS *serverConfigTLS `mapstructure:"tls"` - ACME *serverConfigACME `mapstructure:"acme"` - QUIC struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` - } `mapstructure:"quic"` - Bandwidth struct { - Up string `mapstructure:"up"` - Down string `mapstructure:"down"` - } `mapstructure:"bandwidth"` - DisableUDP bool `mapstructure:"disableUDP"` - Auth struct { - Type string `mapstructure:"type"` - Password string `mapstructure:"password"` - } `mapstructure:"auth"` - Masquerade struct { - Type string `mapstructure:"type"` - File struct { - Dir string `mapstructure:"dir"` - } `mapstructure:"file"` - Proxy struct { - URL string `mapstructure:"url"` - RewriteHost bool `mapstructure:"rewriteHost"` - } `mapstructure:"proxy"` - } `mapstructure:"masquerade"` + Listen string `mapstructure:"listen"` + Obfs serverConfigObfs `mapstructure:"obfs"` + TLS *serverConfigTLS `mapstructure:"tls"` + ACME *serverConfigACME `mapstructure:"acme"` + QUIC serverConfigQUIC `mapstructure:"quic"` + Bandwidth serverConfigBandwidth `mapstructure:"bandwidth"` + DisableUDP bool `mapstructure:"disableUDP"` + Auth serverConfigAuth `mapstructure:"auth"` + Masquerade serverConfigMasquerade `mapstructure:"masquerade"` +} + +type serverConfigObfsSalamander struct { + Password string `mapstructure:"password"` +} + +type serverConfigObfs struct { + Type string `mapstructure:"type"` + Salamander serverConfigObfsSalamander `mapstructure:"salamander"` } type serverConfigTLS struct { @@ -87,6 +68,41 @@ type serverConfigACME struct { Dir string `mapstructure:"dir"` } +type serverConfigQUIC struct { + InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` + MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` + InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` + MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` + MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` + MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"` + DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` +} + +type serverConfigBandwidth struct { + Up string `mapstructure:"up"` + Down string `mapstructure:"down"` +} + +type serverConfigAuth struct { + Type string `mapstructure:"type"` + Password string `mapstructure:"password"` +} + +type serverConfigMasqueradeFile struct { + Dir string `mapstructure:"dir"` +} + +type serverConfigMasqueradeProxy struct { + URL string `mapstructure:"url"` + RewriteHost bool `mapstructure:"rewriteHost"` +} + +type serverConfigMasquerade struct { + Type string `mapstructure:"type"` + File serverConfigMasqueradeFile `mapstructure:"file"` + Proxy serverConfigMasqueradeProxy `mapstructure:"proxy"` +} + func (c *serverConfig) fillConn(hyConfig *server.Config) error { listenAddr := c.Listen if listenAddr == "" { @@ -348,8 +364,8 @@ func (l *serverLogger) TCPError(addr net.Addr, id, reqAddr string, err error) { } } -func (l *serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32) { - logger.Debug("UDP request", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID)) +func (l *serverLogger) UDPRequest(addr net.Addr, id string, sessionID uint32, reqAddr string) { + logger.Debug("UDP request", zap.String("addr", addr.String()), zap.String("id", id), zap.Uint32("sessionID", sessionID), zap.String("reqAddr", reqAddr)) } func (l *serverLogger) UDPError(addr net.Addr, id string, sessionID uint32, err error) { diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index f7f1791..72700f6 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -1,10 +1,11 @@ package cmd import ( - "reflect" "testing" "time" + "github.com/stretchr/testify/assert" + "github.com/spf13/viper" ) @@ -12,25 +13,15 @@ import ( func TestServerConfig(t *testing.T) { viper.SetConfigFile("server_test.yaml") err := viper.ReadInConfig() - if err != nil { - t.Fatal("failed to read server config", err) - } + assert.NoError(t, err) var config serverConfig - if err := viper.Unmarshal(&config); err != nil { - t.Fatal("failed to parse server config", err) - } - if !reflect.DeepEqual(config, serverConfig{ + err = viper.Unmarshal(&config) + assert.NoError(t, err) + assert.Equal(t, config, serverConfig{ Listen: ":8443", - Obfs: struct { - Type string `mapstructure:"type"` - Salamander struct { - Password string `mapstructure:"password"` - } `mapstructure:"salamander"` - }{ + Obfs: serverConfigObfs{ Type: "salamander", - Salamander: struct { - Password string `mapstructure:"password"` - }{ + Salamander: serverConfigObfsSalamander{ Password: "cry_me_a_r1ver", }, }, @@ -51,15 +42,7 @@ func TestServerConfig(t *testing.T) { AltTLSALPNPort: 9443, Dir: "random_dir", }, - QUIC: struct { - InitStreamReceiveWindow uint64 `mapstructure:"initStreamReceiveWindow"` - MaxStreamReceiveWindow uint64 `mapstructure:"maxStreamReceiveWindow"` - InitConnectionReceiveWindow uint64 `mapstructure:"initConnReceiveWindow"` - MaxConnectionReceiveWindow uint64 `mapstructure:"maxConnReceiveWindow"` - MaxIdleTimeout time.Duration `mapstructure:"maxIdleTimeout"` - MaxIncomingStreams int64 `mapstructure:"maxIncomingStreams"` - DisablePathMTUDiscovery bool `mapstructure:"disablePathMTUDiscovery"` - }{ + QUIC: serverConfigQUIC{ InitStreamReceiveWindow: 77881, MaxStreamReceiveWindow: 77882, InitConnectionReceiveWindow: 77883, @@ -68,46 +51,24 @@ func TestServerConfig(t *testing.T) { MaxIncomingStreams: 256, DisablePathMTUDiscovery: true, }, - Bandwidth: struct { - Up string `mapstructure:"up"` - Down string `mapstructure:"down"` - }{ + Bandwidth: serverConfigBandwidth{ Up: "500 mbps", Down: "100 mbps", }, DisableUDP: true, - Auth: struct { - Type string `mapstructure:"type"` - Password string `mapstructure:"password"` - }{ + Auth: serverConfigAuth{ Type: "password", Password: "goofy_ahh_password", }, - Masquerade: struct { - Type string `mapstructure:"type"` - File struct { - Dir string `mapstructure:"dir"` - } `mapstructure:"file"` - Proxy struct { - URL string `mapstructure:"url"` - RewriteHost bool `mapstructure:"rewriteHost"` - } `mapstructure:"proxy"` - }{ + Masquerade: serverConfigMasquerade{ Type: "proxy", - File: struct { - Dir string `mapstructure:"dir"` - }{ + File: serverConfigMasqueradeFile{ Dir: "/www/masq", }, - Proxy: struct { - URL string `mapstructure:"url"` - RewriteHost bool `mapstructure:"rewriteHost"` - }{ + Proxy: serverConfigMasqueradeProxy{ URL: "https://some.site.net", RewriteHost: true, }, }, - }) { - t.Fatal("parsed server config is not equal to expected") - } + }) } diff --git a/app/internal/forwarding/tcp_test.go b/app/internal/forwarding/tcp_test.go index 42710e3..91ab0a1 100644 --- a/app/internal/forwarding/tcp_test.go +++ b/app/internal/forwarding/tcp_test.go @@ -1,49 +1,39 @@ package forwarding import ( - "bytes" "crypto/rand" "net" "testing" + "github.com/stretchr/testify/assert" + "github.com/apernet/hysteria/app/internal/utils_test" ) func TestTCPTunnel(t *testing.T) { // Start the tunnel + l, err := net.Listen("tcp", "127.0.0.1:34567") + assert.NoError(t, err) + defer l.Close() tunnel := &TCPTunnel{ HyClient: &utils_test.MockEchoHyClient{}, - Remote: "whatever", } - l, err := net.Listen("tcp", "127.0.0.1:34567") - if err != nil { - t.Fatal(err) - } - defer l.Close() go tunnel.Serve(l) for i := 0; i < 10; i++ { conn, err := net.Dial("tcp", "127.0.0.1:34567") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) data := make([]byte, 1024) _, _ = rand.Read(data) _, err = conn.Write(data) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + recv := make([]byte, 1024) _, err = conn.Read(recv) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(data, recv) { - t.Fatalf("connection %d: data mismatch", i) - } + assert.NoError(t, err) + assert.Equal(t, data, recv) _ = conn.Close() } } diff --git a/app/internal/forwarding/udp_test.go b/app/internal/forwarding/udp_test.go index 006dab8..feb0c20 100644 --- a/app/internal/forwarding/udp_test.go +++ b/app/internal/forwarding/udp_test.go @@ -1,49 +1,39 @@ package forwarding import ( - "bytes" "crypto/rand" "net" "testing" + "github.com/stretchr/testify/assert" + "github.com/apernet/hysteria/app/internal/utils_test" ) func TestUDPTunnel(t *testing.T) { // Start the tunnel + l, err := net.ListenPacket("udp", "127.0.0.1:34567") + assert.NoError(t, err) + defer l.Close() tunnel := &UDPTunnel{ HyClient: &utils_test.MockEchoHyClient{}, - Remote: "whatever", } - l, err := net.ListenPacket("udp", "127.0.0.1:34567") - if err != nil { - t.Fatal(err) - } - defer l.Close() go tunnel.Serve(l) for i := 0; i < 10; i++ { conn, err := net.Dial("udp", "127.0.0.1:34567") - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) data := make([]byte, 1024) _, _ = rand.Read(data) _, err = conn.Write(data) - if err != nil { - t.Fatal(err) - } + assert.NoError(t, err) + recv := make([]byte, 1024) _, err = conn.Read(recv) - if err != nil { - t.Fatal(err) - } - - if !bytes.Equal(data, recv) { - t.Fatalf("connection %d: data mismatch", i) - } + assert.NoError(t, err) + assert.Equal(t, data, recv) _ = conn.Close() } } diff --git a/app/internal/http/server_test.go b/app/internal/http/server_test.go index fbd24c5..960bef7 100644 --- a/app/internal/http/server_test.go +++ b/app/internal/http/server_test.go @@ -5,8 +5,11 @@ import ( "net" "net/http" "os/exec" + "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/apernet/hysteria/core/client" ) @@ -22,7 +25,6 @@ func (c *mockHyClient) TCP(addr string) (net.Conn, error) { } func (c *mockHyClient) UDP() (client.HyUDPConn, error) { - // Not implemented return nil, errors.New("not implemented") } @@ -32,14 +34,12 @@ func (c *mockHyClient) Close() error { func TestServer(t *testing.T) { // Start the server + l, err := net.Listen("tcp", "127.0.0.1:18080") + assert.NoError(t, err) + defer l.Close() s := &Server{ HyClient: &mockHyClient{}, } - l, err := net.Listen("tcp", "127.0.0.1:18080") - if err != nil { - t.Fatal(err) - } - defer l.Close() go s.Serve(l) // Start a test HTTP & HTTPS server @@ -51,8 +51,9 @@ func TestServer(t *testing.T) { // Run the Python test script cmd := exec.Command("python", "server_test.py") + // Suppress HTTPS warning text from Python + cmd.Env = append(cmd.Env, "PYTHONWARNINGS=ignore:Unverified HTTPS request") out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("Failed to run test script: %v\n%s", err, out) - } + assert.NoError(t, err) + assert.Equal(t, "OK", strings.TrimSpace(string(out))) } diff --git a/app/internal/http/server_test.py b/app/internal/http/server_test.py index 4b9dee5..46c638f 100644 --- a/app/internal/http/server_test.py +++ b/app/internal/http/server_test.py @@ -1,24 +1,24 @@ import requests proxies = { - 'http': 'http://127.0.0.1:18080', - 'https': 'http://127.0.0.1:18080', + "http": "http://127.0.0.1:18080", + "https": "http://127.0.0.1:18080", } def test_http(it): for i in range(it): - r = requests.get('http://127.0.0.1:18081', proxies=proxies) - assert r.status_code == 200 and r.text == 'control is an illusion' + r = requests.get("http://127.0.0.1:18081", proxies=proxies) + assert r.status_code == 200 and r.text == "control is an illusion" def test_https(it): for i in range(it): - r = requests.get('https://127.0.0.1:18082', - proxies=proxies, verify=False) - assert r.status_code == 200 and r.text == 'control is an illusion' + r = requests.get("https://127.0.0.1:18082", proxies=proxies, verify=False) + assert r.status_code == 200 and r.text == "control is an illusion" -if __name__ == '__main__': +if __name__ == "__main__": test_http(10) test_https(10) + print("OK") diff --git a/app/internal/socks5/server_test.go b/app/internal/socks5/server_test.go index de85204..7aec82a 100644 --- a/app/internal/socks5/server_test.go +++ b/app/internal/socks5/server_test.go @@ -3,27 +3,27 @@ package socks5 import ( "net" "os/exec" + "strings" "testing" + "github.com/stretchr/testify/assert" + "github.com/apernet/hysteria/app/internal/utils_test" ) func TestServer(t *testing.T) { // Start the server + l, err := net.Listen("tcp", "127.0.0.1:11080") + assert.NoError(t, err) + defer l.Close() s := &Server{ HyClient: &utils_test.MockEchoHyClient{}, } - l, err := net.Listen("tcp", "127.0.0.1:11080") - if err != nil { - t.Fatal(err) - } - defer l.Close() go s.Serve(l) // Run the Python test script cmd := exec.Command("python", "server_test.py") out, err := cmd.CombinedOutput() - if err != nil { - t.Fatalf("Failed to run test script: %v\n%s", err, out) - } + assert.NoError(t, err) + assert.Equal(t, "OK", strings.TrimSpace(string(out))) } diff --git a/app/internal/socks5/server_test.py b/app/internal/socks5/server_test.py index 5dae4d2..39f98bc 100644 --- a/app/internal/socks5/server_test.py +++ b/app/internal/socks5/server_test.py @@ -54,3 +54,4 @@ if __name__ == "__main__": test_tcp(1024, 1024, 10, domain=True) test_udp(1024, 1024, 10, domain=False) test_udp(1024, 1024, 10, domain=True) + print("OK")