From d4f5a048651adc426534976bd0c92a5062ad9609 Mon Sep 17 00:00:00 2001 From: Toby Date: Thu, 25 Nov 2021 14:53:54 -0800 Subject: [PATCH] feat: reload server keypair every 10 minutes --- cmd/kploader.go | 57 +++++++++++++++++++++++++++++++++++++++++++++++++ cmd/server.go | 6 +++--- 2 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 cmd/kploader.go diff --git a/cmd/kploader.go b/cmd/kploader.go new file mode 100644 index 0000000..97b006b --- /dev/null +++ b/cmd/kploader.go @@ -0,0 +1,57 @@ +package main + +import ( + "crypto/tls" + "github.com/sirupsen/logrus" + "sync" + "time" +) + +const ( + keypairReloadInterval = 10 * time.Minute +) + +type keypairLoader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string +} + +func newKeypairLoader(certPath, keyPath string) (*keypairLoader, error) { + result := &keypairLoader{ + certPath: certPath, + keyPath: keyPath, + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + result.cert = &cert + go func() { + for { + time.Sleep(keypairReloadInterval) + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + logrus.WithFields(logrus.Fields{ + "error": err, + "cert": certPath, + "key": keyPath, + }).Warning("Failed to reload keypair") + continue + } + result.certMu.Lock() + result.cert = &cert + result.certMu.Unlock() + } + }() + return result, nil +} + +func (kpr *keypairLoader) GetCertificateFunc() func(*tls.ClientHelloInfo) (*tls.Certificate, error) { + return func(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) { + kpr.certMu.RLock() + defer kpr.certMu.RUnlock() + return kpr.cert, nil + } +} diff --git a/cmd/server.go b/cmd/server.go index 5bee50c..828fd83 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -38,7 +38,7 @@ func server(config *serverConfig) { tlsConfig = tc } else { // Local cert mode - cert, err := tls.LoadX509KeyPair(config.CertFile, config.KeyFile) + kpl, err := newKeypairLoader(config.CertFile, config.KeyFile) if err != nil { logrus.WithFields(logrus.Fields{ "error": err, @@ -47,8 +47,8 @@ func server(config *serverConfig) { }).Fatal("Failed to load the certificate") } tlsConfig = &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS13, + GetCertificate: kpl.GetCertificateFunc(), + MinVersion: tls.VersionTLS13, } } if config.ALPN != "" {