mirror of
https://github.com/cmz0228/hysteria-dev.git
synced 2025-06-08 13:29:53 +00:00
feat: local cert loader & sni guard
This commit is contained in:
parent
903666f18e
commit
fd2d20a46a
@ -83,8 +83,9 @@ type serverConfigObfs struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type serverConfigTLS struct {
|
type serverConfigTLS struct {
|
||||||
Cert string `mapstructure:"cert"`
|
Cert string `mapstructure:"cert"`
|
||||||
Key string `mapstructure:"key"`
|
Key string `mapstructure:"key"`
|
||||||
|
SNIGuard string `mapstructure:"sniGuard"` // "disable", "dns-san", "strict"
|
||||||
}
|
}
|
||||||
|
|
||||||
type serverConfigACME struct {
|
type serverConfigACME struct {
|
||||||
@ -290,31 +291,46 @@ func (c *serverConfig) fillTLSConfig(hyConfig *server.Config) error {
|
|||||||
if c.TLS != nil && c.ACME != nil {
|
if c.TLS != nil && c.ACME != nil {
|
||||||
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
return configError{Field: "tls", Err: errors.New("cannot set both tls and acme")}
|
||||||
}
|
}
|
||||||
|
// SNI guard
|
||||||
|
var sniGuard utils.SNIGuardFunc
|
||||||
|
switch strings.ToLower(c.TLS.SNIGuard) {
|
||||||
|
case "", "dns-san":
|
||||||
|
sniGuard = utils.SNIGuardDNSSAN
|
||||||
|
case "strict":
|
||||||
|
sniGuard = utils.SNIGuardStrict
|
||||||
|
case "disable":
|
||||||
|
sniGuard = nil
|
||||||
|
default:
|
||||||
|
return configError{Field: "tls.sniGuard", Err: errors.New("unsupported SNI guard")}
|
||||||
|
}
|
||||||
if c.TLS != nil {
|
if c.TLS != nil {
|
||||||
// Local TLS cert
|
// Local TLS cert
|
||||||
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
if c.TLS.Cert == "" || c.TLS.Key == "" {
|
||||||
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
return configError{Field: "tls", Err: errors.New("empty cert or key path")}
|
||||||
}
|
}
|
||||||
|
certLoader := &utils.LocalCertificateLoader{
|
||||||
|
CertFile: c.TLS.Cert,
|
||||||
|
KeyFile: c.TLS.Key,
|
||||||
|
SNIGuard: sniGuard,
|
||||||
|
}
|
||||||
// Try loading the cert-key pair here to catch errors early
|
// Try loading the cert-key pair here to catch errors early
|
||||||
// (e.g. invalid files or insufficient permissions)
|
// (e.g. invalid files or insufficient permissions)
|
||||||
certPEMBlock, err := os.ReadFile(c.TLS.Cert)
|
err := certLoader.InitializeCache()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return configError{Field: "tls.cert", Err: err}
|
var pathErr *os.PathError
|
||||||
}
|
if errors.As(err, &pathErr) {
|
||||||
keyPEMBlock, err := os.ReadFile(c.TLS.Key)
|
if pathErr.Path == c.TLS.Cert {
|
||||||
if err != nil {
|
return configError{Field: "tls.cert", Err: pathErr}
|
||||||
return configError{Field: "tls.key", Err: err}
|
}
|
||||||
}
|
if pathErr.Path == c.TLS.Key {
|
||||||
_, err = tls.X509KeyPair(certPEMBlock, keyPEMBlock)
|
return configError{Field: "tls.key", Err: pathErr}
|
||||||
if err != nil {
|
}
|
||||||
return configError{Field: "tls", Err: fmt.Errorf("invalid cert-key pair: %w", err)}
|
}
|
||||||
|
return configError{Field: "tls", Err: err}
|
||||||
}
|
}
|
||||||
// Use GetCertificate instead of Certificates so that
|
// Use GetCertificate instead of Certificates so that
|
||||||
// users can update the cert without restarting the server.
|
// users can update the cert without restarting the server.
|
||||||
hyConfig.TLSConfig.GetCertificate = func(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
hyConfig.TLSConfig.GetCertificate = certLoader.GetCertificate
|
||||||
cert, err := tls.LoadX509KeyPair(c.TLS.Cert, c.TLS.Key)
|
|
||||||
return &cert, err
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// ACME
|
// ACME
|
||||||
dataDir := c.ACME.Dir
|
dataDir := c.ACME.Dir
|
||||||
|
189
app/internal/utils/certloader.go
Normal file
189
app/internal/utils/certloader.go
Normal file
@ -0,0 +1,189 @@
|
|||||||
|
package utils
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type LocalCertificateLoader struct {
|
||||||
|
CertFile string
|
||||||
|
KeyFile string
|
||||||
|
SNIGuard SNIGuardFunc
|
||||||
|
|
||||||
|
lock sync.RWMutex
|
||||||
|
cache *localCertificateCache
|
||||||
|
}
|
||||||
|
|
||||||
|
type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error
|
||||||
|
|
||||||
|
// localCertificateCache holds the certificate and its mod times.
|
||||||
|
// this struct is designed to be read-only.
|
||||||
|
//
|
||||||
|
// to update the cache, use LocalCertificateLoader.makeCache and
|
||||||
|
// update the LocalCertificateLoader.cache field.
|
||||||
|
type localCertificateCache struct {
|
||||||
|
certificate *tls.Certificate
|
||||||
|
certModTime time.Time
|
||||||
|
keyModTime time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) InitializeCache() error {
|
||||||
|
cache, err := l.makeCache()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.lock.Lock()
|
||||||
|
defer l.lock.Unlock()
|
||||||
|
l.cache = cache
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
cert, err := l.getCertificateWithCache()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if l.SNIGuard == nil {
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
err = l.SNIGuard(info, cert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return cert, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) {
|
||||||
|
if fi, ferr := os.Stat(l.CertFile); ferr != nil {
|
||||||
|
err = fmt.Errorf("failed to stat certificate file: %w", ferr)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
certModTime = fi.ModTime()
|
||||||
|
}
|
||||||
|
if fi, ferr := os.Stat(l.KeyFile); ferr != nil {
|
||||||
|
err = fmt.Errorf("failed to stat key file: %w", ferr)
|
||||||
|
return
|
||||||
|
} else {
|
||||||
|
keyModTime = fi.ModTime()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err error) {
|
||||||
|
c := &localCertificateCache{}
|
||||||
|
|
||||||
|
c.certModTime, c.keyModTime, err = l.checkModTime()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := tls.LoadX509KeyPair(l.CertFile, l.KeyFile)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.certificate = &cert
|
||||||
|
c.certificate.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cache = c
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) {
|
||||||
|
l.lock.RLock()
|
||||||
|
cache := l.cache
|
||||||
|
l.lock.RUnlock()
|
||||||
|
|
||||||
|
certModTime, keyModTime, terr := l.checkModTime()
|
||||||
|
if terr != nil {
|
||||||
|
if cache != nil {
|
||||||
|
// use cache when file is temporarily unavailable
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
return nil, terr
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache != nil && cache.certModTime.Equal(certModTime) && cache.keyModTime.Equal(keyModTime) {
|
||||||
|
// cache is up-to-date
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cache != nil {
|
||||||
|
if !l.lock.TryLock() {
|
||||||
|
// another goroutine is updating the cache
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
l.lock.Lock()
|
||||||
|
}
|
||||||
|
defer l.lock.Unlock()
|
||||||
|
|
||||||
|
newCache, err := l.makeCache()
|
||||||
|
if err != nil {
|
||||||
|
if cache != nil {
|
||||||
|
// use cache when loading failed
|
||||||
|
return cache.certificate, nil
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
l.cache = newCache
|
||||||
|
return newCache.certificate, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getNameFromClientHello returns a normalized form of hello.ServerName.
|
||||||
|
// If hello.ServerName is empty (i.e. client did not use SNI), then the
|
||||||
|
// associated connection's local address is used to extract an IP address.
|
||||||
|
//
|
||||||
|
// ref: https://github.com/caddyserver/certmagic/blob/3bad5b6bb595b09c14bd86ff0b365d302faaf5e2/handshake.go#L838
|
||||||
|
func getNameFromClientHello(hello *tls.ClientHelloInfo) string {
|
||||||
|
normalizedName := func(serverName string) string {
|
||||||
|
return strings.ToLower(strings.TrimSpace(serverName))
|
||||||
|
}
|
||||||
|
localIPFromConn := func(c net.Conn) string {
|
||||||
|
if c == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
localAddr := c.LocalAddr().String()
|
||||||
|
ip, _, err := net.SplitHostPort(localAddr)
|
||||||
|
if err != nil {
|
||||||
|
ip = localAddr
|
||||||
|
}
|
||||||
|
if scopeIDStart := strings.Index(ip, "%"); scopeIDStart > -1 {
|
||||||
|
ip = ip[:scopeIDStart]
|
||||||
|
}
|
||||||
|
return ip
|
||||||
|
}
|
||||||
|
|
||||||
|
if name := normalizedName(hello.ServerName); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
return localIPFromConn(hello.Conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SNIGuardDNSSAN(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||||
|
if len(cert.Leaf.DNSNames) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return SNIGuardStrict(info, cert)
|
||||||
|
}
|
||||||
|
|
||||||
|
func SNIGuardStrict(info *tls.ClientHelloInfo, cert *tls.Certificate) error {
|
||||||
|
hostname := getNameFromClientHello(info)
|
||||||
|
err := cert.Leaf.VerifyHostname(hostname)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("sni guard: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user