From fd2d20a46a801e18289a99fe02625f6835c42969 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 00:27:57 +0800 Subject: [PATCH 1/7] feat: local cert loader & sni guard --- app/cmd/server.go | 48 +++++--- app/internal/utils/certloader.go | 189 +++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 16 deletions(-) create mode 100644 app/internal/utils/certloader.go diff --git a/app/cmd/server.go b/app/cmd/server.go index b45fb15..3d37c09 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -83,8 +83,9 @@ type serverConfigObfs struct { } type serverConfigTLS struct { - Cert string `mapstructure:"cert"` - Key string `mapstructure:"key"` + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` + SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict" } type serverConfigACME struct { @@ -290,31 +291,46 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error { if c.TLS != nil && c.ACME != nil { return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} } + // SNI guard + var sniGuard utils.SNIGuardFunc + switch strings.ToLower(c.TLS.SNIGuard) { + case "", "dns-san": + sniGuard = utils.SNIGuardDNSSAN + case "strict": + sniGuard = utils.SNIGuardStrict + case "disable": + sniGuard = nil + default: + return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")} + } if c.TLS != nil { // Local TLS cert if c.TLS.Cert == "" || c.TLS.Key == "" { return configError{Field: "tls", Err: errors.New("empty cert or key path")} } + certLoader := &utils.LocalCertificateLoader{ + CertFile: c.TLS.Cert, + KeyFile: c.TLS.Key, + SNIGuard: sniGuard, + } // Try loading the cert-key pair here to catch errors early // (e.g. invalid files or insufficient permissions) - certPEMBlock, err := os.ReadFile(c.TLS.Cert) + err := certLoader.InitializeCache() if err != nil { - return configError{Field: "tls.cert", Err: err} - } - keyPEMBlock, err := os.ReadFile(c.TLS.Key) - if err != nil { - return configError{Field: "tls.key", Err: err} - } - _, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock) - if err != nil { - return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)} + var pathErr *os.PathError + if errors.As(err, &pathErr) { + if pathErr.Path == c.TLS.Cert { + return configError{Field: "tls.cert", Err: pathErr} + } + if pathErr.Path == c.TLS.Key { + return configError{Field: "tls.key", Err: pathErr} + } + } + return configError{Field: "tls", Err: err} } // Use GetCertificate instead of Certificates so that // users can update the cert without restarting the server. - hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key) - return &cert, err - } + hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate } else { // ACME dataDir := c.ACME.Dir diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go new file mode 100644 index 0000000..390548e --- /dev/null +++ b/app/internal/utils/certloader.go @@ -0,0 +1,189 @@ +package utils + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + "strings" + "sync" + "time" +) + +type LocalCertificateLoader struct { + CertFile string + KeyFile string + SNIGuard SNIGuardFunc + + lock sync.RWMutex + cache *localCertificateCache +} + +type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error + +// localCertificateCache holds the certificate and its mod times. +// this struct is designed to be read-only. +// +// to update the cache, use LocalCertificateLoader.makeCache and +// update the LocalCertificateLoader.cache field. +type localCertificateCache struct { + certificate *tls.Certificate + certModTime time.Time + keyModTime time.Time +} + +func (l *LocalCertificateLoader) InitializeCache() error { + cache, err := l.makeCache() + if err != nil { + return err + } + + l.lock.Lock() + defer l.lock.Unlock() + l.cache = cache + return nil +} + +func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := l.getCertificateWithCache() + if err != nil { + return nil, err + } + + if l.SNIGuard == nil { + return cert, nil + } + err = l.SNIGuard(info, cert) + if err != nil { + return nil, err + } + + return cert, nil +} + +func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { + if fi, ferr := os.Stat(l.CertFile); ferr != nil { + err = fmt.Errorf("failed to stat certificate file: %w", ferr) + return + } else { + certModTime = fi.ModTime() + } + if fi, ferr := os.Stat(l.KeyFile); ferr != nil { + err = fmt.Errorf("failed to stat key file: %w", ferr) + return + } else { + keyModTime = fi.ModTime() + } + return +} + +func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) { + c := &localCertificateCache{} + + c.certModTime, c.keyModTime, err = l.checkModTime() + if err != nil { + return + } + + cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile) + if err != nil { + return + } + c.certificate = &cert + c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } + + cache = c + return +} + +func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { + l.lock.RLock() + cache := l.cache + l.lock.RUnlock() + + certModTime, keyModTime, terr := l.checkModTime() + if terr != nil { + if cache != nil { + // use cache when file is temporarily unavailable + return cache.certificate, nil + } + return nil, terr + } + + if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) { + // cache is up-to-date + return cache.certificate, nil + } + + if cache != nil { + if !l.lock.TryLock() { + // another goroutine is updating the cache + return cache.certificate, nil + } + } else { + l.lock.Lock() + } + defer l.lock.Unlock() + + newCache, err := l.makeCache() + if err != nil { + if cache != nil { + // use cache when loading failed + return cache.certificate, nil + } + return nil, err + } + + l.cache = newCache + return newCache.certificate, nil +} + +// getNameFromClientHello returns a normalized form of hello.ServerName. +// If hello.ServerName is empty (i.e. client did not use SNI), then the +// associated connection's local address is used to extract an IP address. +// +// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838 +func getNameFromClientHello(hello *tls.ClientHelloInfo) string { + normalizedName := func(serverName string) string { + return strings.ToLower(strings.TrimSpace(serverName)) + } + localIPFromConn := func(c net.Conn) string { + if c == nil { + return "" + } + localAddr := c.LocalAddr().String() + ip, _, err := net.SplitHostPort(localAddr) + if err != nil { + ip = localAddr + } + if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 { + ip = ip[:scopeIDStart] + } + return ip + } + + if name := normalizedName(hello.ServerName); name != "" { + return name + } + return localIPFromConn(hello.Conn) +} + +func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + if len(cert.Leaf.DNSNames) == 0 { + return nil + } + return SNIGuardStrict(info, cert) +} + +func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + hostname := getNameFromClientHello(info) + err := cert.Leaf.VerifyHostname(hostname) + if err != nil { + return fmt.Errorf("sni guard: %w", err) + } + return nil +} From 57a48a674be1e731dcafb6580da1bb406106c270 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 10:37:08 +0800 Subject: [PATCH 2/7] chore: replace rwlock with atomic pointer --- app/internal/utils/certloader.go | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go index 390548e..6e4a7be 100644 --- a/app/internal/utils/certloader.go +++ b/app/internal/utils/certloader.go @@ -8,6 +8,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" ) @@ -16,8 +17,8 @@ type LocalCertificateLoader struct { KeyFile string SNIGuard SNIGuardFunc - lock sync.RWMutex - cache *localCertificateCache + lock sync.Mutex + cache atomic.Pointer[localCertificateCache] } type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error @@ -34,14 +35,15 @@ type localCertificateCache struct { } func (l *LocalCertificateLoader) InitializeCache() error { + l.lock.Lock() + defer l.lock.Unlock() + cache, err := l.makeCache() if err != nil { return err } - l.lock.Lock() - defer l.lock.Unlock() - l.cache = cache + l.cache.Store(cache) return nil } @@ -63,18 +65,19 @@ func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls } func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { - if fi, ferr := os.Stat(l.CertFile); ferr != nil { - err = fmt.Errorf("failed to stat certificate file: %w", ferr) + fi, err := os.Stat(l.CertFile) + if err != nil { + err = fmt.Errorf("failed to stat certificate file: %w", err) return - } else { - certModTime = fi.ModTime() } - if fi, ferr := os.Stat(l.KeyFile); ferr != nil { - err = fmt.Errorf("failed to stat key file: %w", ferr) + certModTime = fi.ModTime() + + fi, err = os.Stat(l.KeyFile) + if err != nil { + err = fmt.Errorf("failed to stat key file: %w", err) return - } else { - keyModTime = fi.ModTime() } + keyModTime = fi.ModTime() return } @@ -101,9 +104,7 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err } func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { - l.lock.RLock() - cache := l.cache - l.lock.RUnlock() + cache := l.cache.Load() certModTime, keyModTime, terr := l.checkModTime() if terr != nil { @@ -129,6 +130,11 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er } defer l.lock.Unlock() + if l.cache.Load() != cache { + // another goroutine updated the cache + return l.cache.Load().certificate, nil + } + newCache, err := l.makeCache() if err != nil { if cache != nil { @@ -138,7 +144,7 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er return nil, err } - l.cache = newCache + l.cache.Store(newCache) return newCache.certificate, nil } From 45893b5d1e5c69b1519051b24a46425b519f31a0 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 13:40:42 +0800 Subject: [PATCH 3/7] test: update server_test for sniGuard --- app/cmd/server_test.go | 5 +++-- app/cmd/server_test.yaml | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index bb2d12a..f35edfb 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -26,8 +26,9 @@ func TestServerConfig(t *testing.T) { }, }, TLS: &serverConfigTLS{ - Cert: "some.crt", - Key: "some.key", + Cert: "some.crt", + Key: "some.key", + SNIGuard: "strict", }, ACME: &serverConfigACME{ Domains: []string{ diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index ff0bf52..b7d1a3e 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -8,6 +8,7 @@ obfs: tls: cert: some.crt key: some.key + sniGuard: strict acme: domains: From bcf830c29aa2b73a04f0108cd22ae420e095de4c Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 13:41:46 +0800 Subject: [PATCH 4/7] chore: only init cert.Leaf when not populated since Go 1.23, cert.Leaf will be populated after loaded. see doc of tls.LoadX509KeyPair for details --- app/internal/utils/certloader.go | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go index 6e4a7be..fb41a3c 100644 --- a/app/internal/utils/certloader.go +++ b/app/internal/utils/certloader.go @@ -94,9 +94,12 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err return } c.certificate = &cert - c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) - if err != nil { - return + if c.certificate.Leaf == nil { + // certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23 + c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } } cache = c From 667b08ec3e49e1a19a30046293bdf576768e8979 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 17:25:31 +0800 Subject: [PATCH 5/7] test: add tests for certloader --- app/internal/utils/certloader_test.go | 139 ++++++++++++++++++ app/internal/utils/certloader_test_gencert.py | 134 +++++++++++++++++ .../utils/certloader_test_tlsclient.py | 60 ++++++++ app/internal/utils/testcerts/.gitignore | 3 + 4 files changed, 336 insertions(+) create mode 100644 app/internal/utils/certloader_test.go create mode 100644 app/internal/utils/certloader_test_gencert.py create mode 100644 app/internal/utils/certloader_test_tlsclient.py create mode 100644 app/internal/utils/testcerts/.gitignore diff --git a/app/internal/utils/certloader_test.go b/app/internal/utils/certloader_test.go new file mode 100644 index 0000000..7c5875c --- /dev/null +++ b/app/internal/utils/certloader_test.go @@ -0,0 +1,139 @@ +package utils + +import ( + "crypto/tls" + "log" + "net/http" + "os" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + testListen = "127.82.39.147:12947" + testCAFile = "./testcerts/ca" + testCertFile = "./testcerts/cert" + testKeyFile = "./testcerts/key" +) + +func TestCertificateLoaderPathError(t *testing.T) { + assert.NoError(t, os.RemoveAll(testCertFile)) + assert.NoError(t, os.RemoveAll(testKeyFile)) + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + err := loader.InitializeCache() + var pathErr *os.PathError + assert.ErrorAs(t, err, &pathErr) +} + +func TestCertificateLoaderFullChain(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.Error(t, runTestTLSClient("unmatched-sni.example.com")) + assert.Error(t, runTestTLSClient("")) + assert.NoError(t, runTestTLSClient("example.com")) +} + +func TestCertificateLoaderNoSAN(t *testing.T) { + assert.NoError(t, generateTestCertificate(nil, "selfsign")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardDNSSAN, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("")) +} + +func TestCertificateLoaderReplaceCertificate(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("example.com")) + assert.Error(t, runTestTLSClient("2.example.com")) + + assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain")) + + assert.Error(t, runTestTLSClient("example.com")) + assert.NoError(t, runTestTLSClient("2.example.com")) +} + +func generateTestCertificate(dnssan []string, certType string) error { + args := []string{ + "certloader_test_gencert.py", + "--ca", testCAFile, + "--cert", testCertFile, + "--key", testKeyFile, + "--type", certType, + } + if len(dnssan) > 0 { + args = append(args, "--dnssan", strings.Join(dnssan, ",")) + } + cmd := exec.Command("python3", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to generate test certificate: %s", out) + return err + } + return nil +} + +func runTestTLSClient(sni string) error { + args := []string{ + "certloader_test_tlsclient.py", + "--server", testListen, + "--ca", testCAFile, + } + if sni != "" { + args = append(args, "--sni", sni) + } + cmd := exec.Command("python3", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to run test TLS client: %s", out) + return err + } + return nil +} diff --git a/app/internal/utils/certloader_test_gencert.py b/app/internal/utils/certloader_test_gencert.py new file mode 100644 index 0000000..d4d5695 --- /dev/null +++ b/app/internal/utils/certloader_test_gencert.py @@ -0,0 +1,134 @@ +import argparse +import datetime +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption + + +def create_key(): + return ec.generate_private_key(ec.SECP256R1()) + + +def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None): + serial_number = x509.random_serial_number() + not_valid_before = datetime.datetime.now(datetime.UTC) + not_valid_after = not_valid_before + datetime.timedelta(days=365) + + subject_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')), + ]) + issuer_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')), + ]) + builder = x509.CertificateBuilder() + builder = builder.subject_name(subject_name) + builder = builder.issuer_name(issuer_name) + builder = builder.public_key(public_key) + builder = builder.serial_number(serial_number) + builder = builder.not_valid_before(not_valid_before) + builder = builder.not_valid_after(not_valid_after) + if cert_type == 'root': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + elif cert_type == 'intermediate': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=0), critical=True + ) + elif cert_type == 'leaf': + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True + ) + else: + raise ValueError(f'Invalid cert_type: {cert_type}') + if dns_san: + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]), + critical=False + ) + return builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + + +def main(): + parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.') + parser.add_argument('--ca', required=True, + help='Path to write the X509 CA certificate in PEM format') + parser.add_argument('--cert', required=True, + help='Path to write the X509 certificate in PEM format') + parser.add_argument('--key', required=True, + help='Path to write the private key in PEM format') + parser.add_argument('--dnssan', required=False, default=None, + help='Comma-separated list of DNS SANs') + parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'], + help='Type of certificate to generate') + + args = parser.parse_args() + + key = create_key() + public_key = key.public_key() + + if args.type == 'selfsign': + subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"} + cert = create_certificate( + cert_type='root', + subject=subject, + issuer=subject, + private_key=key, + public_key=public_key, + dns_san=args.dnssan) + with open(args.ca, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + elif args.type == 'fullchain': + ca_key = create_key() + ca_public_key = ca_key.public_key() + ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"} + ca_cert = create_certificate( + cert_type='root', + subject=ca_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=ca_public_key) + + intermediate_key = create_key() + intermediate_public_key = intermediate_key.public_key() + intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"} + intermediate_cert = create_certificate( + cert_type='intermediate', + subject=intermediate_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=intermediate_public_key) + + leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"} + cert = create_certificate( + cert_type='leaf', + subject=leaf_subject, + issuer=intermediate_subject, + private_key=intermediate_key, + public_key=public_key, + dns_san=args.dnssan) + + with open(args.ca, 'wb') as f: + f.write(ca_cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + f.write(intermediate_cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/certloader_test_tlsclient.py b/app/internal/utils/certloader_test_tlsclient.py new file mode 100644 index 0000000..3b7efd6 --- /dev/null +++ b/app/internal/utils/certloader_test_tlsclient.py @@ -0,0 +1,60 @@ +import argparse +import ssl +import socket +import sys + + +def check_tls(server, ca_cert, sni, alpn): + try: + host, port = server.split(":") + port = int(port) + + if ca_cert: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert) + context.check_hostname = sni is not None + context.verify_mode = ssl.CERT_REQUIRED + else: + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if alpn: + context.set_alpn_protocols([p for p in alpn.split(",")]) + + with socket.create_connection((host, port)) as sock: + with context.wrap_socket(sock, server_hostname=sni) as ssock: + # Verify handshake and certificate + print(f'Connected to {ssock.version()} using {ssock.cipher()}') + print(f'Server certificate validated and details: {ssock.getpeercert()}') + print("OK") + return 0 + except Exception as e: + print(f"Error: {e}") + return 1 + + +def main(): + parser = argparse.ArgumentParser(description="Test TLS Server") + parser.add_argument("--server", required=True, + help="Server address to test (e.g., 127.1.2.3:8443)") + parser.add_argument("--ca", required=False, default=None, + help="CA certificate file used to validate the server certificate" + "Omit to use insecure connection") + parser.add_argument("--sni", required=False, default=None, + help="SNI to send in ClientHello") + parser.add_argument("--alpn", required=False, default='h2', + help="ALPN to send in ClientHello") + + args = parser.parse_args() + + exit_status = check_tls( + server=args.server, + ca_cert=args.ca, + sni=args.sni, + alpn=args.alpn) + + sys.exit(exit_status) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/testcerts/.gitignore b/app/internal/utils/testcerts/.gitignore new file mode 100644 index 0000000..082821a --- /dev/null +++ b/app/internal/utils/testcerts/.gitignore @@ -0,0 +1,3 @@ +# This directory is used for certificate generation in certloader_test.go +/* +!/.gitignore From 4ed3f21d7293dd6ac3bf3a9dc4c1fa0e28247379 Mon Sep 17 00:00:00 2001 From: Toby Date: Sat, 24 Aug 2024 17:07:45 -0700 Subject: [PATCH 6/7] fix: crash when the tls option is not used & change from python3 to python --- app/cmd/server.go | 24 ++++++++++++------------ app/internal/utils/certloader_test.go | 4 ++-- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/app/cmd/server.go b/app/cmd/server.go index 3d37c09..3da748d 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -291,19 +291,19 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error { if c.TLS != nil && c.ACME != nil { return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} } - // SNI guard - var sniGuard utils.SNIGuardFunc - switch strings.ToLower(c.TLS.SNIGuard) { - case "", "dns-san": - sniGuard = utils.SNIGuardDNSSAN - case "strict": - sniGuard = utils.SNIGuardStrict - case "disable": - sniGuard = nil - default: - return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")} - } if c.TLS != nil { + // SNI guard + var sniGuard utils.SNIGuardFunc + switch strings.ToLower(c.TLS.SNIGuard) { + case "", "dns-san": + sniGuard = utils.SNIGuardDNSSAN + case "strict": + sniGuard = utils.SNIGuardStrict + case "disable": + sniGuard = nil + default: + return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")} + } // Local TLS cert if c.TLS.Cert == "" || c.TLS.Key == "" { return configError{Field: "tls", Err: errors.New("empty cert or key path")} diff --git a/app/internal/utils/certloader_test.go b/app/internal/utils/certloader_test.go index 7c5875c..3a8e26b 100644 --- a/app/internal/utils/certloader_test.go +++ b/app/internal/utils/certloader_test.go @@ -111,7 +111,7 @@ func generateTestCertificate(dnssan []string, certType string) error { if len(dnssan) > 0 { args = append(args, "--dnssan", strings.Join(dnssan, ",")) } - cmd := exec.Command("python3", args...) + cmd := exec.Command("python", args...) out, err := cmd.CombinedOutput() if err != nil { log.Printf("Failed to generate test certificate: %s", out) @@ -129,7 +129,7 @@ func runTestTLSClient(sni string) error { if sni != "" { args = append(args, "--sni", sni) } - cmd := exec.Command("python3", args...) + cmd := exec.Command("python", args...) out, err := cmd.CombinedOutput() if err != nil { log.Printf("Failed to run test TLS client: %s", out) From d4b9c5a822d17d3e2ae59a7696e4ffbab0c67dbc Mon Sep 17 00:00:00 2001 From: Haruue Date: Sun, 25 Aug 2024 13:36:45 +0800 Subject: [PATCH 7/7] test: add requirements.txt for ut scripts --- requirements.txt | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 requirements.txt diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..44ee651 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +blinker==1.8.2 +cffi==1.17.0 +click==8.1.7 +cryptography==43.0.0 +Flask==3.0.3 +itsdangerous==2.2.0 +Jinja2==3.1.4 +MarkupSafe==2.1.5 +pycparser==2.22 +PySocks==1.7.1 +Werkzeug==3.0.4