From 57a48a674be1e731dcafb6580da1bb406106c270 Mon Sep 17 00:00:00 2001 From: Haruue Date: Sat, 24 Aug 2024 10:37:08 +0800 Subject: [PATCH] chore: replace rwlock with atomic pointer --- app/internal/utils/certloader.go | 40 ++++++++++++++++++-------------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/app/internal/utils/certloader.go b/app/internal/utils/certloader.go index 390548e..6e4a7be 100644 --- a/app/internal/utils/certloader.go +++ b/app/internal/utils/certloader.go @@ -8,6 +8,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" ) @@ -16,8 +17,8 @@ type LocalCertificateLoader struct { KeyFile string SNIGuard SNIGuardFunc - lock sync.RWMutex - cache *localCertificateCache + lock sync.Mutex + cache atomic.Pointer[localCertificateCache] } type SNIGuardFunc func(info *tls.ClientHelloInfo, cert *tls.Certificate) error @@ -34,14 +35,15 @@ type localCertificateCache struct { } func (l *LocalCertificateLoader) InitializeCache() error { + l.lock.Lock() + defer l.lock.Unlock() + cache, err := l.makeCache() if err != nil { return err } - l.lock.Lock() - defer l.lock.Unlock() - l.cache = cache + l.cache.Store(cache) return nil } @@ -63,18 +65,19 @@ func (l *LocalCertificateLoader) GetCertificate(info *tls.ClientHelloInfo) (*tls } func (l *LocalCertificateLoader) checkModTime() (certModTime, keyModTime time.Time, err error) { - if fi, ferr := os.Stat(l.CertFile); ferr != nil { - err = fmt.Errorf("failed to stat certificate file: %w", ferr) + fi, err := os.Stat(l.CertFile) + if err != nil { + err = fmt.Errorf("failed to stat certificate file: %w", err) return - } else { - certModTime = fi.ModTime() } - if fi, ferr := os.Stat(l.KeyFile); ferr != nil { - err = fmt.Errorf("failed to stat key file: %w", ferr) + certModTime = fi.ModTime() + + fi, err = os.Stat(l.KeyFile) + if err != nil { + err = fmt.Errorf("failed to stat key file: %w", err) return - } else { - keyModTime = fi.ModTime() } + keyModTime = fi.ModTime() return } @@ -101,9 +104,7 @@ func (l *LocalCertificateLoader) makeCache() (cache *localCertificateCache, err } func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, error) { - l.lock.RLock() - cache := l.cache - l.lock.RUnlock() + cache := l.cache.Load() certModTime, keyModTime, terr := l.checkModTime() if terr != nil { @@ -129,6 +130,11 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er } defer l.lock.Unlock() + if l.cache.Load() != cache { + // another goroutine updated the cache + return l.cache.Load().certificate, nil + } + newCache, err := l.makeCache() if err != nil { if cache != nil { @@ -138,7 +144,7 @@ func (l *LocalCertificateLoader) getCertificateWithCache() (*tls.Certificate, er return nil, err } - l.cache = newCache + l.cache.Store(newCache) return newCache.certificate, nil }