diff --git a/internal/applicant/acme-user.go b/internal/applicant/acme-user.go index 3b74d5ca..daa7a4cf 100644 --- a/internal/applicant/acme-user.go +++ b/internal/applicant/acme-user.go @@ -9,6 +9,7 @@ import ( "github.com/go-acme/lego/v4/lego" "github.com/go-acme/lego/v4/registration" + "golang.org/x/sync/singleflight" "github.com/usual2970/certimate/internal/domain" "github.com/usual2970/certimate/internal/pkg/utils/certs" @@ -79,9 +80,21 @@ type acmeAccountRepository interface { Save(ca, email, key string, resource *registration.Resource) error } -func registerAcmeUser(client *lego.Client, sslProviderConfig *acmeSSLProviderConfig, user *acmeUser) (*registration.Resource, error) { - // TODO: fix 潜在的并发问题 +var registerGroup singleflight.Group +func registerAcmeUser(client *lego.Client, sslProviderConfig *acmeSSLProviderConfig, user *acmeUser) (*registration.Resource, error) { + resp, err, _ := registerGroup.Do(fmt.Sprintf("register_acme_user_%s_%s", sslProviderConfig.Provider, user.GetEmail()), func() (interface{}, error) { + return register(client, sslProviderConfig, user) + }) + + if err != nil { + return nil, err + } + + return resp.(*registration.Resource), nil +} + +func register(client *lego.Client, sslProviderConfig *acmeSSLProviderConfig, user *acmeUser) (*registration.Resource, error) { var reg *registration.Resource var err error switch sslProviderConfig.Provider { diff --git a/internal/applicant/applicant.go b/internal/applicant/applicant.go index a0e67fee..6be0eb12 100644 --- a/internal/applicant/applicant.go +++ b/internal/applicant/applicant.go @@ -7,32 +7,20 @@ import ( "os" "strconv" "strings" + "sync" "github.com/go-acme/lego/v4/certcrypto" "github.com/go-acme/lego/v4/certificate" "github.com/go-acme/lego/v4/challenge" "github.com/go-acme/lego/v4/challenge/dns01" "github.com/go-acme/lego/v4/lego" + "golang.org/x/time/rate" "github.com/usual2970/certimate/internal/domain" - "github.com/usual2970/certimate/internal/pkg/utils/pool" "github.com/usual2970/certimate/internal/pkg/utils/slices" "github.com/usual2970/certimate/internal/repository" ) -const defaultPoolSize = 8 - -var poolInstance *pool.Pool[proxyApplicant, applicantResult] - -type applicantResult struct { - result *ApplyCertResult - err error -} - -func init() { - poolInstance = pool.NewPool[proxyApplicant, applicantResult](defaultPoolSize) -} - type ApplyCertResult struct { CertificateFullChain string IssuerCertificate string @@ -98,17 +86,6 @@ func NewWithApplyNode(node *domain.WorkflowNode) (Applicant, error) { }, nil } -func applyAsync(applicant challenge.Provider, options *applicantOptions) <-chan applicantResult { - return poolInstance.Submit(context.Background(), func(p proxyApplicant) applicantResult { - rs := applicantResult{} - rs.result, rs.err = apply(p.applicant, p.options) - return rs - }, proxyApplicant{ - applicant: applicant, - options: options, - }) -} - func apply(challengeProvider challenge.Provider, options *applicantOptions) (*ApplyCertResult, error) { settingsRepo := repository.NewSettingsRepository() settings, _ := settingsRepo.GetByName(context.Background(), "sslProvider") @@ -209,7 +186,20 @@ type proxyApplicant struct { options *applicantOptions } -func (d *proxyApplicant) Apply() (*ApplyCertResult, error) { - rs := <-applyAsync(d.applicant, d.options) - return rs.result, rs.err +var limiters sync.Map + +const ( + limitBurst = 300 + limitRate float64 = float64(1) / float64(36) +) + +func getLimiter(key string) *rate.Limiter { + limiter, _ := limiters.LoadOrStore(key, rate.NewLimiter(rate.Limit(limitRate), 300)) + return limiter.(*rate.Limiter) +} + +func (d *proxyApplicant) Apply() (*ApplyCertResult, error) { + limiter := getLimiter(fmt.Sprintf("apply_%s", d.options.ContactEmail)) + limiter.Wait(context.Background()) + return apply(d.applicant, d.options) } diff --git a/internal/applicant/applicant_test.go b/internal/applicant/applicant_test.go new file mode 100644 index 00000000..352cea24 --- /dev/null +++ b/internal/applicant/applicant_test.go @@ -0,0 +1,44 @@ +package applicant + +import ( + "testing" + "time" + + "golang.org/x/time/rate" +) + +func TestRateLimit(t *testing.T) { + tests := []struct { + name string + burst int + rate rate.Limit + }{ + { + name: "test1", + burst: 300, + rate: rate.Limit(float64(1) / float64(20)), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + rl := rate.NewLimiter(tt.rate, tt.burst) + if rl.Burst() != tt.burst { + t.Errorf("Burst() = %v, want %v", rl.Burst(), tt.burst) + } + if rl.Limit() != tt.rate { + t.Errorf("Limit() = %v, want %v", rl.Limit(), tt.rate) + } + + t.Log("consume all tokens at once", rl.AllowN(time.Now(), tt.burst)) + + t.Log("consume more", rl.Allow()) + + time.Sleep(time.Second * 5) + t.Log("consume after 5 seconds", rl.Allow()) + + time.Sleep(time.Second * 20) + t.Log("consume after 20 seconds", rl.Allow()) + }) + } +} diff --git a/internal/pkg/utils/pool/pool.go b/internal/pkg/utils/pool/pool.go deleted file mode 100644 index 6f9ede08..00000000 --- a/internal/pkg/utils/pool/pool.go +++ /dev/null @@ -1,46 +0,0 @@ -package pool - -import ( - "context" -) - -type Task[I, O any] func(I) O - -type Pool[I, O any] struct { - ch chan struct{} - size int -} - -func NewPool[I, O any](size int) *Pool[I, O] { - return &Pool[I, O]{ - ch: make(chan struct{}, size), - size: size, - } -} - -func (p *Pool[I, O]) Submit(ctx context.Context, task Task[I, O], input I) <-chan O { - resultChan := make(chan O, 1) - - go func() { - select { - case p.ch <- struct{}{}: - defer func() { - <-p.ch - close(resultChan) - }() - - result := task(input) - select { - case <-ctx.Done(): - return - case resultChan <- result: - } - - case <-ctx.Done(): - close(resultChan) - return - } - }() - - return resultChan -}