diff --git a/app/cmd/server.go b/app/cmd/server.go index b45fb15..3da748d 100644 --- a/app/cmd/server.go +++ b/app/cmd/server.go @@ -83,8 +83,9 @@ type serverConfigObfs struct { } type serverConfigTLS struct { - Cert string `mapstructure:"cert"` - Key string `mapstructure:"key"` + Cert string `mapstructure:"cert"` + Key string `mapstructure:"key"` + SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict" } type serverConfigACME struct { @@ -291,30 +292,45 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error { return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")} } if c.TLS != nil { + // SNI guard + var sniGuard utils.SNIGuardFunc + switch strings.ToLower(c.TLS.SNIGuard) { + case "", "dns-san": + sniGuard = utils.SNIGuardDNSSAN + case "strict": + sniGuard = utils.SNIGuardStrict + case "disable": + sniGuard = nil + default: + return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")} + } // Local TLS cert if c.TLS.Cert == "" || c.TLS.Key == "" { return configError{Field: "tls", Err: errors.New("empty cert or key path")} } + certLoader := &utils.LocalCertificateLoader{ + CertFile: c.TLS.Cert, + KeyFile: c.TLS.Key, + SNIGuard: sniGuard, + } // Try loading the cert-key pair here to catch errors early // (e.g. invalid files or insufficient permissions) - certPEMBlock, err := os.ReadFile(c.TLS.Cert) + err := certLoader.InitializeCache() if err != nil { - return configError{Field: "tls.cert", Err: err} - } - keyPEMBlock, err := os.ReadFile(c.TLS.Key) - if err != nil { - return configError{Field: "tls.key", Err: err} - } - _, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock) - if err != nil { - return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)} + var pathErr *os.PathError + if errors.As(err, &pathErr) { + if pathErr.Path == c.TLS.Cert { + return configError{Field: "tls.cert", Err: pathErr} + } + if pathErr.Path == c.TLS.Key { + return configError{Field: "tls.key", Err: pathErr} + } + } + return configError{Field: "tls", Err: err} } // Use GetCertificate instead of Certificates so that // users can update the cert without restarting the server. - hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key) - return &cert, err - } + hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate } else { // ACME dataDir := c.ACME.Dir diff --git a/app/cmd/server_test.go b/app/cmd/server_test.go index bb2d12a..f35edfb 100644 --- a/app/cmd/server_test.go +++ b/app/cmd/server_test.go @@ -26,8 +26,9 @@ func TestServerConfig(t *testing.T) { }, }, TLS: &serverConfigTLS{ - Cert: "some.crt", - Key: "some.key", + Cert: "some.crt", + Key: "some.key", + SNIGuard: "strict", }, ACME: &serverConfigACME{ Domains: []string{ diff --git a/app/cmd/server_test.yaml b/app/cmd/server_test.yaml index ff0bf52..b7d1a3e 100644 --- a/app/cmd/server_test.yaml +++ b/app/cmd/server_test.yaml @@ -8,6 +8,7 @@ obfs: tls: cert: some.crt key: some.key + sniGuard: strict acme: domains: diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go new file mode 100644 index 0000000..fb41a3c --- /dev/null +++ b/app/internal/utils/certloader.go @@ -0,0 +1,198 @@ +package utils + +import ( + "crypto/tls" + "crypto/x509" + "fmt" + "net" + "os" + "strings" + "sync" + "sync/atomic" + "time" +) + +type LocalCertificateLoader struct { + CertFile string + KeyFile string + SNIGuard SNIGuardFunc + + lock sync.Mutex + cache atomic.Pointer[localCertificateCache] +} + +type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error + +// localCertificateCache holds the certificate and its mod times. +// this struct is designed to be read-only. +// +// to update the cache, use LocalCertificateLoader.makeCache and +// update the LocalCertificateLoader.cache field. +type localCertificateCache struct { + certificate *tls.Certificate + certModTime time.Time + keyModTime time.Time +} + +func (l *LocalCertificateLoader) InitializeCache() error { + l.lock.Lock() + defer l.lock.Unlock() + + cache, err := l.makeCache() + if err != nil { + return err + } + + l.cache.Store(cache) + return nil +} + +func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + cert, err := l.getCertificateWithCache() + if err != nil { + return nil, err + } + + if l.SNIGuard == nil { + return cert, nil + } + err = l.SNIGuard(info, cert) + if err != nil { + return nil, err + } + + return cert, nil +} + +func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { + fi, err := os.Stat(l.CertFile) + if err != nil { + err = fmt.Errorf("failed to stat certificate file: %w", err) + return + } + certModTime = fi.ModTime() + + fi, err = os.Stat(l.KeyFile) + if err != nil { + err = fmt.Errorf("failed to stat key file: %w", err) + return + } + keyModTime = fi.ModTime() + return +} + +func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) { + c := &localCertificateCache{} + + c.certModTime, c.keyModTime, err = l.checkModTime() + if err != nil { + return + } + + cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile) + if err != nil { + return + } + c.certificate = &cert + if c.certificate.Leaf == nil { + // certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23 + c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } + } + + cache = c + return +} + +func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { + cache := l.cache.Load() + + certModTime, keyModTime, terr := l.checkModTime() + if terr != nil { + if cache != nil { + // use cache when file is temporarily unavailable + return cache.certificate, nil + } + return nil, terr + } + + if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) { + // cache is up-to-date + return cache.certificate, nil + } + + if cache != nil { + if !l.lock.TryLock() { + // another goroutine is updating the cache + return cache.certificate, nil + } + } else { + l.lock.Lock() + } + defer l.lock.Unlock() + + if l.cache.Load() != cache { + // another goroutine updated the cache + return l.cache.Load().certificate, nil + } + + newCache, err := l.makeCache() + if err != nil { + if cache != nil { + // use cache when loading failed + return cache.certificate, nil + } + return nil, err + } + + l.cache.Store(newCache) + return newCache.certificate, nil +} + +// getNameFromClientHello returns a normalized form of hello.ServerName. +// If hello.ServerName is empty (i.e. client did not use SNI), then the +// associated connection's local address is used to extract an IP address. +// +// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838 +func getNameFromClientHello(hello *tls.ClientHelloInfo) string { + normalizedName := func(serverName string) string { + return strings.ToLower(strings.TrimSpace(serverName)) + } + localIPFromConn := func(c net.Conn) string { + if c == nil { + return "" + } + localAddr := c.LocalAddr().String() + ip, _, err := net.SplitHostPort(localAddr) + if err != nil { + ip = localAddr + } + if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 { + ip = ip[:scopeIDStart] + } + return ip + } + + if name := normalizedName(hello.ServerName); name != "" { + return name + } + return localIPFromConn(hello.Conn) +} + +func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + if len(cert.Leaf.DNSNames) == 0 { + return nil + } + return SNIGuardStrict(info, cert) +} + +func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error { + hostname := getNameFromClientHello(info) + err := cert.Leaf.VerifyHostname(hostname) + if err != nil { + return fmt.Errorf("sni guard: %w", err) + } + return nil +} diff --git a/app/internal/utils/certloader_test.go b/app/internal/utils/certloader_test.go new file mode 100644 index 0000000..3a8e26b --- /dev/null +++ b/app/internal/utils/certloader_test.go @@ -0,0 +1,139 @@ +package utils + +import ( + "crypto/tls" + "log" + "net/http" + "os" + "os/exec" + "strings" + "testing" + + "github.com/stretchr/testify/assert" +) + +const ( + testListen = "127.82.39.147:12947" + testCAFile = "./testcerts/ca" + testCertFile = "./testcerts/cert" + testKeyFile = "./testcerts/key" +) + +func TestCertificateLoaderPathError(t *testing.T) { + assert.NoError(t, os.RemoveAll(testCertFile)) + assert.NoError(t, os.RemoveAll(testKeyFile)) + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + err := loader.InitializeCache() + var pathErr *os.PathError + assert.ErrorAs(t, err, &pathErr) +} + +func TestCertificateLoaderFullChain(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.Error(t, runTestTLSClient("unmatched-sni.example.com")) + assert.Error(t, runTestTLSClient("")) + assert.NoError(t, runTestTLSClient("example.com")) +} + +func TestCertificateLoaderNoSAN(t *testing.T) { + assert.NoError(t, generateTestCertificate(nil, "selfsign")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardDNSSAN, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("")) +} + +func TestCertificateLoaderReplaceCertificate(t *testing.T) { + assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain")) + + loader := LocalCertificateLoader{ + CertFile: testCertFile, + KeyFile: testKeyFile, + SNIGuard: SNIGuardStrict, + } + assert.NoError(t, loader.InitializeCache()) + + lis, err := tls.Listen("tcp", testListen, &tls.Config{ + GetCertificate: loader.GetCertificate, + }) + assert.NoError(t, err) + defer lis.Close() + go http.Serve(lis, nil) + + assert.NoError(t, runTestTLSClient("example.com")) + assert.Error(t, runTestTLSClient("2.example.com")) + + assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain")) + + assert.Error(t, runTestTLSClient("example.com")) + assert.NoError(t, runTestTLSClient("2.example.com")) +} + +func generateTestCertificate(dnssan []string, certType string) error { + args := []string{ + "certloader_test_gencert.py", + "--ca", testCAFile, + "--cert", testCertFile, + "--key", testKeyFile, + "--type", certType, + } + if len(dnssan) > 0 { + args = append(args, "--dnssan", strings.Join(dnssan, ",")) + } + cmd := exec.Command("python", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to generate test certificate: %s", out) + return err + } + return nil +} + +func runTestTLSClient(sni string) error { + args := []string{ + "certloader_test_tlsclient.py", + "--server", testListen, + "--ca", testCAFile, + } + if sni != "" { + args = append(args, "--sni", sni) + } + cmd := exec.Command("python", args...) + out, err := cmd.CombinedOutput() + if err != nil { + log.Printf("Failed to run test TLS client: %s", out) + return err + } + return nil +} diff --git a/app/internal/utils/certloader_test_gencert.py b/app/internal/utils/certloader_test_gencert.py new file mode 100644 index 0000000..d4d5695 --- /dev/null +++ b/app/internal/utils/certloader_test_gencert.py @@ -0,0 +1,134 @@ +import argparse +import datetime +from cryptography import x509 +from cryptography.x509.oid import NameOID +from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption + + +def create_key(): + return ec.generate_private_key(ec.SECP256R1()) + + +def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None): + serial_number = x509.random_serial_number() + not_valid_before = datetime.datetime.now(datetime.UTC) + not_valid_after = not_valid_before + datetime.timedelta(days=365) + + subject_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')), + ]) + issuer_name = x509.Name([ + x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')), + x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')), + x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')), + ]) + builder = x509.CertificateBuilder() + builder = builder.subject_name(subject_name) + builder = builder.issuer_name(issuer_name) + builder = builder.public_key(public_key) + builder = builder.serial_number(serial_number) + builder = builder.not_valid_before(not_valid_before) + builder = builder.not_valid_after(not_valid_after) + if cert_type == 'root': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=None), critical=True + ) + elif cert_type == 'intermediate': + builder = builder.add_extension( + x509.BasicConstraints(ca=True, path_length=0), critical=True + ) + elif cert_type == 'leaf': + builder = builder.add_extension( + x509.BasicConstraints(ca=False, path_length=None), critical=True + ) + else: + raise ValueError(f'Invalid cert_type: {cert_type}') + if dns_san: + builder = builder.add_extension( + x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]), + critical=False + ) + return builder.sign(private_key=private_key, algorithm=hashes.SHA256()) + + +def main(): + parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.') + parser.add_argument('--ca', required=True, + help='Path to write the X509 CA certificate in PEM format') + parser.add_argument('--cert', required=True, + help='Path to write the X509 certificate in PEM format') + parser.add_argument('--key', required=True, + help='Path to write the private key in PEM format') + parser.add_argument('--dnssan', required=False, default=None, + help='Comma-separated list of DNS SANs') + parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'], + help='Type of certificate to generate') + + args = parser.parse_args() + + key = create_key() + public_key = key.public_key() + + if args.type == 'selfsign': + subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"} + cert = create_certificate( + cert_type='root', + subject=subject, + issuer=subject, + private_key=key, + public_key=public_key, + dns_san=args.dnssan) + with open(args.ca, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + elif args.type == 'fullchain': + ca_key = create_key() + ca_public_key = ca_key.public_key() + ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"} + ca_cert = create_certificate( + cert_type='root', + subject=ca_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=ca_public_key) + + intermediate_key = create_key() + intermediate_public_key = intermediate_key.public_key() + intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"} + intermediate_cert = create_certificate( + cert_type='intermediate', + subject=intermediate_subject, + issuer=ca_subject, + private_key=ca_key, + public_key=intermediate_public_key) + + leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"} + cert = create_certificate( + cert_type='leaf', + subject=leaf_subject, + issuer=intermediate_subject, + private_key=intermediate_key, + public_key=public_key, + dns_san=args.dnssan) + + with open(args.ca, 'wb') as f: + f.write(ca_cert.public_bytes(Encoding.PEM)) + with open(args.cert, 'wb') as f: + f.write(cert.public_bytes(Encoding.PEM)) + f.write(intermediate_cert.public_bytes(Encoding.PEM)) + with open(args.key, 'wb') as f: + f.write( + key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption())) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/certloader_test_tlsclient.py b/app/internal/utils/certloader_test_tlsclient.py new file mode 100644 index 0000000..3b7efd6 --- /dev/null +++ b/app/internal/utils/certloader_test_tlsclient.py @@ -0,0 +1,60 @@ +import argparse +import ssl +import socket +import sys + + +def check_tls(server, ca_cert, sni, alpn): + try: + host, port = server.split(":") + port = int(port) + + if ca_cert: + context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert) + context.check_hostname = sni is not None + context.verify_mode = ssl.CERT_REQUIRED + else: + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if alpn: + context.set_alpn_protocols([p for p in alpn.split(",")]) + + with socket.create_connection((host, port)) as sock: + with context.wrap_socket(sock, server_hostname=sni) as ssock: + # Verify handshake and certificate + print(f'Connected to {ssock.version()} using {ssock.cipher()}') + print(f'Server certificate validated and details: {ssock.getpeercert()}') + print("OK") + return 0 + except Exception as e: + print(f"Error: {e}") + return 1 + + +def main(): + parser = argparse.ArgumentParser(description="Test TLS Server") + parser.add_argument("--server", required=True, + help="Server address to test (e.g., 127.1.2.3:8443)") + parser.add_argument("--ca", required=False, default=None, + help="CA certificate file used to validate the server certificate" + "Omit to use insecure connection") + parser.add_argument("--sni", required=False, default=None, + help="SNI to send in ClientHello") + parser.add_argument("--alpn", required=False, default='h2', + help="ALPN to send in ClientHello") + + args = parser.parse_args() + + exit_status = check_tls( + server=args.server, + ca_cert=args.ca, + sni=args.sni, + alpn=args.alpn) + + sys.exit(exit_status) + + +if __name__ == "__main__": + main() diff --git a/app/internal/utils/testcerts/.gitignore b/app/internal/utils/testcerts/.gitignore new file mode 100644 index 0000000..082821a --- /dev/null +++ b/app/internal/utils/testcerts/.gitignore @@ -0,0 +1,3 @@ +# This directory is used for certificate generation in certloader_test.go +/* +!/.gitignore diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..44ee651 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,11 @@ +blinker==1.8.2 +cffi==1.17.0 +click==8.1.7 +cryptography==43.0.0 +Flask==3.0.3 +itsdangerous==2.2.0 +Jinja2==3.1.4 +MarkupSafe==2.1.5 +pycparser==2.22 +PySocks==1.7.1 +Werkzeug==3.0.4