mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-06-08 05:19:53 +00:00
Merge pull request #1191 from apernet/wip-sni-guard
feat: local cert loader & sni guard
This commit is contained in:
commit
21ea2a024a
@ -83,8 +83,9 @@ type serverConfigObfs struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type serverConfigTLS struct {
|
type serverConfigTLS struct {
|
||||||
Cert string `mapstructure:"cert"`
|
Cert string `mapstructure:"cert"`
|
||||||
Key string `mapstructure:"key"`
|
Key string `mapstructure:"key"`
|
||||||
|
SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict"
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverConfigACME struct {
|
type serverConfigACME struct {
|
||||||
@ -291,30 +292,45 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
|
|||||||
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
||||||
}
|
}
|
||||||
if c.TLS != nil {
|
if c.TLS != nil {
|
||||||
|
// SNI guard
|
||||||
|
var sniGuard utils.SNIGuardFunc
|
||||||
|
switch strings.ToLower(c.TLS.SNIGuard) {
|
||||||
|
case "", "dns-san":
|
||||||
|
sniGuard = utils.SNIGuardDNSSAN
|
||||||
|
case "strict":
|
||||||
|
sniGuard = utils.SNIGuardStrict
|
||||||
|
case "disable":
|
||||||
|
sniGuard = nil
|
||||||
|
default:
|
||||||
|
return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")}
|
||||||
|
}
|
||||||
// Local TLS cert
|
// Local TLS cert
|
||||||
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
||||||
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
||||||
}
|
}
|
||||||
|
certLoader := &utils.LocalCertificateLoader{
|
||||||
|
CertFile: c.TLS.Cert,
|
||||||
|
KeyFile: c.TLS.Key,
|
||||||
|
SNIGuard: sniGuard,
|
||||||
|
}
|
||||||
// Try loading the cert-key pair here to catch errors early
|
// Try loading the cert-key pair here to catch errors early
|
||||||
// (e.g. invalid files or insufficient permissions)
|
// (e.g. invalid files or insufficient permissions)
|
||||||
certPEMBlock, err := os.ReadFile(c.TLS.Cert)
|
err := certLoader.InitializeCache()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return configError{Field: "tls.cert", Err: err}
|
var pathErr *os.PathError
|
||||||
}
|
if errors.As(err, &pathErr) {
|
||||||
keyPEMBlock, err := os.ReadFile(c.TLS.Key)
|
if pathErr.Path == c.TLS.Cert {
|
||||||
if err != nil {
|
return configError{Field: "tls.cert", Err: pathErr}
|
||||||
return configError{Field: "tls.key", Err: err}
|
}
|
||||||
}
|
if pathErr.Path == c.TLS.Key {
|
||||||
_, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
return configError{Field: "tls.key", Err: pathErr}
|
||||||
if err != nil {
|
}
|
||||||
return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)}
|
}
|
||||||
|
return configError{Field: "tls", Err: err}
|
||||||
}
|
}
|
||||||
// Use GetCertificate instead of Certificates so that
|
// Use GetCertificate instead of Certificates so that
|
||||||
// users can update the cert without restarting the server.
|
// users can update the cert without restarting the server.
|
||||||
hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate
|
||||||
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
|
|
||||||
return &cert, err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// ACME
|
// ACME
|
||||||
dataDir := c.ACME.Dir
|
dataDir := c.ACME.Dir
|
||||||
|
@ -26,8 +26,9 @@ func TestServerConfig(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
TLS: &serverConfigTLS{
|
TLS: &serverConfigTLS{
|
||||||
Cert: "some.crt",
|
Cert: "some.crt",
|
||||||
Key: "some.key",
|
Key: "some.key",
|
||||||
|
SNIGuard: "strict",
|
||||||
},
|
},
|
||||||
ACME: &serverConfigACME{
|
ACME: &serverConfigACME{
|
||||||
Domains: []string{
|
Domains: []string{
|
||||||
|
@ -8,6 +8,7 @@ obfs:
|
|||||||
tls:
|
tls:
|
||||||
cert: some.crt
|
cert: some.crt
|
||||||
key: some.key
|
key: some.key
|
||||||
|
sniGuard: strict
|
||||||
|
|
||||||
acme:
|
acme:
|
||||||
domains:
|
domains:
|
||||||
|
198
app/internal/utils/certloader.go
Normal file
198
app/internal/utils/certloader.go
Normal file
@ -0,0 +1,198 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LocalCertificateLoader struct {
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
SNIGuard SNIGuardFunc
|
||||||
|
|
||||||
|
lock sync.Mutex
|
||||||
|
cache atomic.Pointer[localCertificateCache]
|
||||||
|
}
|
||||||
|
|
||||||
|
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
|
||||||
|
|
||||||
|
// localCertificateCache holds the certificate and its mod times.
|
||||||
|
// this struct is designed to be read-only.
|
||||||
|
//
|
||||||
|
// to update the cache, use LocalCertificateLoader.makeCache and
|
||||||
|
// update the LocalCertificateLoader.cache field.
|
||||||
|
type localCertificateCache struct {
|
||||||
|
certificate *tls.Certificate
|
||||||
|
certModTime time.Time
|
||||||
|
keyModTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) InitializeCache() error {
|
||||||
|
l.lock.Lock()
|
||||||
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
|
cache, err := l.makeCache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.cache.Store(cache)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
cert, err := l.getCertificateWithCache()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.SNIGuard == nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
err = l.SNIGuard(info, cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
|
||||||
|
fi, err := os.Stat(l.CertFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to stat certificate file: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
certModTime = fi.ModTime()
|
||||||
|
|
||||||
|
fi, err = os.Stat(l.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("failed to stat key file: %w", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
keyModTime = fi.ModTime()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
|
||||||
|
c := &localCertificateCache{}
|
||||||
|
|
||||||
|
c.certModTime, c.keyModTime, err = l.checkModTime()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.certificate = &cert
|
||||||
|
if c.certificate.Leaf == nil {
|
||||||
|
// certificate.Leaf was left nil by tls.LoadX509KeyPair before Go 1.23
|
||||||
|
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cache = c
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
|
||||||
|
cache := l.cache.Load()
|
||||||
|
|
||||||
|
certModTime, keyModTime, terr := l.checkModTime()
|
||||||
|
if terr != nil {
|
||||||
|
if cache != nil {
|
||||||
|
// use cache when file is temporarily unavailable
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
return nil, terr
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
|
||||||
|
// cache is up-to-date
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache != nil {
|
||||||
|
if !l.lock.TryLock() {
|
||||||
|
// another goroutine is updating the cache
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
l.lock.Lock()
|
||||||
|
}
|
||||||
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
|
if l.cache.Load() != cache {
|
||||||
|
// another goroutine updated the cache
|
||||||
|
return l.cache.Load().certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
newCache, err := l.makeCache()
|
||||||
|
if err != nil {
|
||||||
|
if cache != nil {
|
||||||
|
// use cache when loading failed
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.cache.Store(newCache)
|
||||||
|
return newCache.certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNameFromClientHello returns a normalized form of hello.ServerName.
|
||||||
|
// If hello.ServerName is empty (i.e. client did not use SNI), then the
|
||||||
|
// associated connection's local address is used to extract an IP address.
|
||||||
|
//
|
||||||
|
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
|
||||||
|
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
|
||||||
|
normalizedName := func(serverName string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(serverName))
|
||||||
|
}
|
||||||
|
localIPFromConn := func(c net.Conn) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
localAddr := c.LocalAddr().String()
|
||||||
|
ip, _, err := net.SplitHostPort(localAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = localAddr
|
||||||
|
}
|
||||||
|
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
|
||||||
|
ip = ip[:scopeIDStart]
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
if name := normalizedName(hello.ServerName); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return localIPFromConn(hello.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||||
|
if len(cert.Leaf.DNSNames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return SNIGuardStrict(info, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||||
|
hostname := getNameFromClientHello(info)
|
||||||
|
err := cert.Leaf.VerifyHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sni guard: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
139
app/internal/utils/certloader_test.go
Normal file
139
app/internal/utils/certloader_test.go
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"log"
|
||||||
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testListen = "127.82.39.147:12947"
|
||||||
|
testCAFile = "./testcerts/ca"
|
||||||
|
testCertFile = "./testcerts/cert"
|
||||||
|
testKeyFile = "./testcerts/key"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestCertificateLoaderPathError(t *testing.T) {
|
||||||
|
assert.NoError(t, os.RemoveAll(testCertFile))
|
||||||
|
assert.NoError(t, os.RemoveAll(testKeyFile))
|
||||||
|
loader := LocalCertificateLoader{
|
||||||
|
CertFile: testCertFile,
|
||||||
|
KeyFile: testKeyFile,
|
||||||
|
SNIGuard: SNIGuardStrict,
|
||||||
|
}
|
||||||
|
err := loader.InitializeCache()
|
||||||
|
var pathErr *os.PathError
|
||||||
|
assert.ErrorAs(t, err, &pathErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertificateLoaderFullChain(t *testing.T) {
|
||||||
|
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
|
||||||
|
|
||||||
|
loader := LocalCertificateLoader{
|
||||||
|
CertFile: testCertFile,
|
||||||
|
KeyFile: testKeyFile,
|
||||||
|
SNIGuard: SNIGuardStrict,
|
||||||
|
}
|
||||||
|
assert.NoError(t, loader.InitializeCache())
|
||||||
|
|
||||||
|
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||||
|
GetCertificate: loader.GetCertificate,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer lis.Close()
|
||||||
|
go http.Serve(lis, nil)
|
||||||
|
|
||||||
|
assert.Error(t, runTestTLSClient("unmatched-sni.example.com"))
|
||||||
|
assert.Error(t, runTestTLSClient(""))
|
||||||
|
assert.NoError(t, runTestTLSClient("example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertificateLoaderNoSAN(t *testing.T) {
|
||||||
|
assert.NoError(t, generateTestCertificate(nil, "selfsign"))
|
||||||
|
|
||||||
|
loader := LocalCertificateLoader{
|
||||||
|
CertFile: testCertFile,
|
||||||
|
KeyFile: testKeyFile,
|
||||||
|
SNIGuard: SNIGuardDNSSAN,
|
||||||
|
}
|
||||||
|
assert.NoError(t, loader.InitializeCache())
|
||||||
|
|
||||||
|
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||||
|
GetCertificate: loader.GetCertificate,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer lis.Close()
|
||||||
|
go http.Serve(lis, nil)
|
||||||
|
|
||||||
|
assert.NoError(t, runTestTLSClient(""))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCertificateLoaderReplaceCertificate(t *testing.T) {
|
||||||
|
assert.NoError(t, generateTestCertificate([]string{"example.com"}, "fullchain"))
|
||||||
|
|
||||||
|
loader := LocalCertificateLoader{
|
||||||
|
CertFile: testCertFile,
|
||||||
|
KeyFile: testKeyFile,
|
||||||
|
SNIGuard: SNIGuardStrict,
|
||||||
|
}
|
||||||
|
assert.NoError(t, loader.InitializeCache())
|
||||||
|
|
||||||
|
lis, err := tls.Listen("tcp", testListen, &tls.Config{
|
||||||
|
GetCertificate: loader.GetCertificate,
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
defer lis.Close()
|
||||||
|
go http.Serve(lis, nil)
|
||||||
|
|
||||||
|
assert.NoError(t, runTestTLSClient("example.com"))
|
||||||
|
assert.Error(t, runTestTLSClient("2.example.com"))
|
||||||
|
|
||||||
|
assert.NoError(t, generateTestCertificate([]string{"2.example.com"}, "fullchain"))
|
||||||
|
|
||||||
|
assert.Error(t, runTestTLSClient("example.com"))
|
||||||
|
assert.NoError(t, runTestTLSClient("2.example.com"))
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestCertificate(dnssan []string, certType string) error {
|
||||||
|
args := []string{
|
||||||
|
"certloader_test_gencert.py",
|
||||||
|
"--ca", testCAFile,
|
||||||
|
"--cert", testCertFile,
|
||||||
|
"--key", testKeyFile,
|
||||||
|
"--type", certType,
|
||||||
|
}
|
||||||
|
if len(dnssan) > 0 {
|
||||||
|
args = append(args, "--dnssan", strings.Join(dnssan, ","))
|
||||||
|
}
|
||||||
|
cmd := exec.Command("python", args...)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to generate test certificate: %s", out)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func runTestTLSClient(sni string) error {
|
||||||
|
args := []string{
|
||||||
|
"certloader_test_tlsclient.py",
|
||||||
|
"--server", testListen,
|
||||||
|
"--ca", testCAFile,
|
||||||
|
}
|
||||||
|
if sni != "" {
|
||||||
|
args = append(args, "--sni", sni)
|
||||||
|
}
|
||||||
|
cmd := exec.Command("python", args...)
|
||||||
|
out, err := cmd.CombinedOutput()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to run test TLS client: %s", out)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
134
app/internal/utils/certloader_test_gencert.py
Normal file
134
app/internal/utils/certloader_test_gencert.py
Normal file
@ -0,0 +1,134 @@
|
|||||||
|
import argparse
|
||||||
|
import datetime
|
||||||
|
from cryptography import x509
|
||||||
|
from cryptography.x509.oid import NameOID
|
||||||
|
from cryptography.hazmat.primitives import hashes
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import ec
|
||||||
|
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption
|
||||||
|
|
||||||
|
|
||||||
|
def create_key():
|
||||||
|
return ec.generate_private_key(ec.SECP256R1())
|
||||||
|
|
||||||
|
|
||||||
|
def create_certificate(cert_type, subject, issuer, private_key, public_key, dns_san=None):
|
||||||
|
serial_number = x509.random_serial_number()
|
||||||
|
not_valid_before = datetime.datetime.now(datetime.UTC)
|
||||||
|
not_valid_after = not_valid_before + datetime.timedelta(days=365)
|
||||||
|
|
||||||
|
subject_name = x509.Name([
|
||||||
|
x509.NameAttribute(NameOID.COUNTRY_NAME, subject.get('C', 'ZZ')),
|
||||||
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, subject.get('O', 'No Organization')),
|
||||||
|
x509.NameAttribute(NameOID.COMMON_NAME, subject.get('CN', 'No CommonName')),
|
||||||
|
])
|
||||||
|
issuer_name = x509.Name([
|
||||||
|
x509.NameAttribute(NameOID.COUNTRY_NAME, issuer.get('C', 'ZZ')),
|
||||||
|
x509.NameAttribute(NameOID.ORGANIZATION_NAME, issuer.get('O', 'No Organization')),
|
||||||
|
x509.NameAttribute(NameOID.COMMON_NAME, issuer.get('CN', 'No CommonName')),
|
||||||
|
])
|
||||||
|
builder = x509.CertificateBuilder()
|
||||||
|
builder = builder.subject_name(subject_name)
|
||||||
|
builder = builder.issuer_name(issuer_name)
|
||||||
|
builder = builder.public_key(public_key)
|
||||||
|
builder = builder.serial_number(serial_number)
|
||||||
|
builder = builder.not_valid_before(not_valid_before)
|
||||||
|
builder = builder.not_valid_after(not_valid_after)
|
||||||
|
if cert_type == 'root':
|
||||||
|
builder = builder.add_extension(
|
||||||
|
x509.BasicConstraints(ca=True, path_length=None), critical=True
|
||||||
|
)
|
||||||
|
elif cert_type == 'intermediate':
|
||||||
|
builder = builder.add_extension(
|
||||||
|
x509.BasicConstraints(ca=True, path_length=0), critical=True
|
||||||
|
)
|
||||||
|
elif cert_type == 'leaf':
|
||||||
|
builder = builder.add_extension(
|
||||||
|
x509.BasicConstraints(ca=False, path_length=None), critical=True
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f'Invalid cert_type: {cert_type}')
|
||||||
|
if dns_san:
|
||||||
|
builder = builder.add_extension(
|
||||||
|
x509.SubjectAlternativeName([x509.DNSName(d) for d in dns_san.split(',')]),
|
||||||
|
critical=False
|
||||||
|
)
|
||||||
|
return builder.sign(private_key=private_key, algorithm=hashes.SHA256())
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description='Generate HTTPS server certificate.')
|
||||||
|
parser.add_argument('--ca', required=True,
|
||||||
|
help='Path to write the X509 CA certificate in PEM format')
|
||||||
|
parser.add_argument('--cert', required=True,
|
||||||
|
help='Path to write the X509 certificate in PEM format')
|
||||||
|
parser.add_argument('--key', required=True,
|
||||||
|
help='Path to write the private key in PEM format')
|
||||||
|
parser.add_argument('--dnssan', required=False, default=None,
|
||||||
|
help='Comma-separated list of DNS SANs')
|
||||||
|
parser.add_argument('--type', required=True, choices=['selfsign', 'fullchain'],
|
||||||
|
help='Type of certificate to generate')
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
key = create_key()
|
||||||
|
public_key = key.public_key()
|
||||||
|
|
||||||
|
if args.type == 'selfsign':
|
||||||
|
subject = {"C": "ZZ", "O": "Certificate", "CN": "Certificate"}
|
||||||
|
cert = create_certificate(
|
||||||
|
cert_type='root',
|
||||||
|
subject=subject,
|
||||||
|
issuer=subject,
|
||||||
|
private_key=key,
|
||||||
|
public_key=public_key,
|
||||||
|
dns_san=args.dnssan)
|
||||||
|
with open(args.ca, 'wb') as f:
|
||||||
|
f.write(cert.public_bytes(Encoding.PEM))
|
||||||
|
with open(args.cert, 'wb') as f:
|
||||||
|
f.write(cert.public_bytes(Encoding.PEM))
|
||||||
|
with open(args.key, 'wb') as f:
|
||||||
|
f.write(
|
||||||
|
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
|
||||||
|
|
||||||
|
elif args.type == 'fullchain':
|
||||||
|
ca_key = create_key()
|
||||||
|
ca_public_key = ca_key.public_key()
|
||||||
|
ca_subject = {"C": "ZZ", "O": "Root CA", "CN": "Root CA"}
|
||||||
|
ca_cert = create_certificate(
|
||||||
|
cert_type='root',
|
||||||
|
subject=ca_subject,
|
||||||
|
issuer=ca_subject,
|
||||||
|
private_key=ca_key,
|
||||||
|
public_key=ca_public_key)
|
||||||
|
|
||||||
|
intermediate_key = create_key()
|
||||||
|
intermediate_public_key = intermediate_key.public_key()
|
||||||
|
intermediate_subject = {"C": "ZZ", "O": "Intermediate CA", "CN": "Intermediate CA"}
|
||||||
|
intermediate_cert = create_certificate(
|
||||||
|
cert_type='intermediate',
|
||||||
|
subject=intermediate_subject,
|
||||||
|
issuer=ca_subject,
|
||||||
|
private_key=ca_key,
|
||||||
|
public_key=intermediate_public_key)
|
||||||
|
|
||||||
|
leaf_subject = {"C": "ZZ", "O": "Leaf Certificate", "CN": "Leaf Certificate"}
|
||||||
|
cert = create_certificate(
|
||||||
|
cert_type='leaf',
|
||||||
|
subject=leaf_subject,
|
||||||
|
issuer=intermediate_subject,
|
||||||
|
private_key=intermediate_key,
|
||||||
|
public_key=public_key,
|
||||||
|
dns_san=args.dnssan)
|
||||||
|
|
||||||
|
with open(args.ca, 'wb') as f:
|
||||||
|
f.write(ca_cert.public_bytes(Encoding.PEM))
|
||||||
|
with open(args.cert, 'wb') as f:
|
||||||
|
f.write(cert.public_bytes(Encoding.PEM))
|
||||||
|
f.write(intermediate_cert.public_bytes(Encoding.PEM))
|
||||||
|
with open(args.key, 'wb') as f:
|
||||||
|
f.write(
|
||||||
|
key.private_bytes(Encoding.PEM, PrivateFormat.TraditionalOpenSSL, NoEncryption()))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
60
app/internal/utils/certloader_test_tlsclient.py
Normal file
60
app/internal/utils/certloader_test_tlsclient.py
Normal file
@ -0,0 +1,60 @@
|
|||||||
|
import argparse
|
||||||
|
import ssl
|
||||||
|
import socket
|
||||||
|
import sys
|
||||||
|
|
||||||
|
|
||||||
|
def check_tls(server, ca_cert, sni, alpn):
|
||||||
|
try:
|
||||||
|
host, port = server.split(":")
|
||||||
|
port = int(port)
|
||||||
|
|
||||||
|
if ca_cert:
|
||||||
|
context = ssl.create_default_context(ssl.Purpose.SERVER_AUTH, cafile=ca_cert)
|
||||||
|
context.check_hostname = sni is not None
|
||||||
|
context.verify_mode = ssl.CERT_REQUIRED
|
||||||
|
else:
|
||||||
|
context = ssl.create_default_context()
|
||||||
|
context.check_hostname = False
|
||||||
|
context.verify_mode = ssl.CERT_NONE
|
||||||
|
|
||||||
|
if alpn:
|
||||||
|
context.set_alpn_protocols([p for p in alpn.split(",")])
|
||||||
|
|
||||||
|
with socket.create_connection((host, port)) as sock:
|
||||||
|
with context.wrap_socket(sock, server_hostname=sni) as ssock:
|
||||||
|
# Verify handshake and certificate
|
||||||
|
print(f'Connected to {ssock.version()} using {ssock.cipher()}')
|
||||||
|
print(f'Server certificate validated and details: {ssock.getpeercert()}')
|
||||||
|
print("OK")
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
parser = argparse.ArgumentParser(description="Test TLS Server")
|
||||||
|
parser.add_argument("--server", required=True,
|
||||||
|
help="Server address to test (e.g., 127.1.2.3:8443)")
|
||||||
|
parser.add_argument("--ca", required=False, default=None,
|
||||||
|
help="CA certificate file used to validate the server certificate"
|
||||||
|
"Omit to use insecure connection")
|
||||||
|
parser.add_argument("--sni", required=False, default=None,
|
||||||
|
help="SNI to send in ClientHello")
|
||||||
|
parser.add_argument("--alpn", required=False, default='h2',
|
||||||
|
help="ALPN to send in ClientHello")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
exit_status = check_tls(
|
||||||
|
server=args.server,
|
||||||
|
ca_cert=args.ca,
|
||||||
|
sni=args.sni,
|
||||||
|
alpn=args.alpn)
|
||||||
|
|
||||||
|
sys.exit(exit_status)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
3
app/internal/utils/testcerts/.gitignore
vendored
Normal file
3
app/internal/utils/testcerts/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
# This directory is used for certificate generation in certloader_test.go
|
||||||
|
/*
|
||||||
|
!/.gitignore
|
11
requirements.txt
Normal file
11
requirements.txt
Normal file
@ -0,0 +1,11 @@
|
|||||||
|
blinker==1.8.2
|
||||||
|
cffi==1.17.0
|
||||||
|
click==8.1.7
|
||||||
|
cryptography==43.0.0
|
||||||
|
Flask==3.0.3
|
||||||
|
itsdangerous==2.2.0
|
||||||
|
Jinja2==3.1.4
|
||||||
|
MarkupSafe==2.1.5
|
||||||
|
pycparser==2.22
|
||||||
|
PySocks==1.7.1
|
||||||
|
Werkzeug==3.0.4
|
Loading…
x
Reference in New Issue
Block a user