From 9a75d2ac8f337421e51855147bd39e24efa072a1 Mon Sep 17 00:00:00 2001 From: Leo Chen Date: Fri, 15 Nov 2024 00:33:09 +0800 Subject: [PATCH] add key algorithm check --- internal/domains/deploy.go | 49 ++++++++++++++++++++++++++++++++++---- 1 file changed, 44 insertions(+), 5 deletions(-) diff --git a/internal/domains/deploy.go b/internal/domains/deploy.go index f9a6f406..85bc14c6 100644 --- a/internal/domains/deploy.go +++ b/internal/domains/deploy.go @@ -6,12 +6,16 @@ import ( "strings" "time" + "crypto/rsa" + "crypto/ecdsa" + "github.com/pocketbase/pocketbase/models" "golang.org/x/exp/slices" "github.com/usual2970/certimate/internal/applicant" "github.com/usual2970/certimate/internal/deployer" + "github.com/usual2970/certimate/internal/domain" "github.com/usual2970/certimate/internal/utils/app" "github.com/usual2970/certimate/internal/pkg/utils/x509" @@ -51,9 +55,9 @@ func deploy(ctx context.Context, record *models.Record) error { expiredAt := currRecord.GetDateTime("expiredAt").Time() // 检查证书是否包含设置的所有域名 - included := isAllDomainsIncludedInCert(cert, currRecord.GetString("domain")) + changed := isCertChanged(cert, currRecord) - if cert != "" && time.Until(expiredAt) > time.Hour*24*10 && currRecord.GetBool("deployed") && included { + if cert != "" && time.Until(expiredAt) > time.Hour*24*10 && currRecord.GetBool("deployed") && changed { app.GetApp().Logger().Info("证书在有效期内") history.record(checkPhase, "证书在有效期内且已部署,跳过", &RecordInfo{ Info: []string{fmt.Sprintf("证书有效期至 %s", expiredAt.Format("2006-01-02"))}, @@ -68,7 +72,7 @@ func deploy(ctx context.Context, record *models.Record) error { // ############2.申请证书 history.record(applyPhase, "开始申请", nil) - if cert != "" && time.Until(expiredAt) > time.Hour*24 && included { + if cert != "" && time.Until(expiredAt) > time.Hour*24 && changed { history.record(applyPhase, "证书在有效期内,跳过", &RecordInfo{ Info: []string{fmt.Sprintf("证书有效期至 %s", expiredAt.Format("2006-01-02"))}, }) @@ -130,7 +134,7 @@ func deploy(ctx context.Context, record *models.Record) error { return nil } -func isAllDomainsIncludedInCert(certificate, domains string) bool { +func isCertChanged(certificate string, record *models.Record) bool { // 如果证书为空,直接返回false if certificate == "" { return false @@ -144,12 +148,47 @@ func isAllDomainsIncludedInCert(certificate, domains string) bool { } // 遍历域名列表,检查是否都在证书中,找到第一个不存在证书中域名时提前返回false - for _, domain := range strings.Split(domains, ";") { + for _, domain := range strings.Split(record.GetString("domain"), ";") { if !slices.Contains(cert.DNSNames, domain) && !slices.Contains(cert.DNSNames, "*."+removeLastSubdomain(domain)) { return false } } + // 解析applyConfig + applyConfig := &domain.ApplyConfig{} + record.UnmarshalJSONField("applyConfig", applyConfig) + + + // 检查证书加密算法是否一致 + switch pubkey := cert.PublicKey.(type) { + case *rsa.PublicKey: + bitSize := pubkey.N.BitLen() + switch bitSize { + case 2048: + // RSA2048 + if applyConfig.KeyAlgorithm != "" && applyConfig.KeyAlgorithm != "RSA2048" { return false } + case 3072: + // RSA3072 + if applyConfig.KeyAlgorithm != "RSA3072" { return false } + case 4096: + // RSA4096 + if applyConfig.KeyAlgorithm != "RSA4096" { return false } + case 8192: + // RSA8192 + if applyConfig.KeyAlgorithm != "RSA8192" { return false } + } + case *ecdsa.PublicKey: + bitSize := pubkey.Curve.Params().BitSize + switch bitSize { + case 256: + // EC256 + if applyConfig.KeyAlgorithm != "EC256" { return false } + case 384: + // EC384 + if applyConfig.KeyAlgorithm != "EC384" { return false } + } + } + return true }