diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 71a3f7cec..f502f9363 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -3,6 +3,7 @@ package autocert import ( "context" + "errors" "fmt" "net/http" "sort" @@ -16,6 +17,11 @@ import ( "github.com/pomerium/pomerium/internal/log" ) +var ( + errObtainCertFailed = errors.New("obtain cert failed") + errRenewCertFailed = errors.New("renew cert failed") +) + // Manager manages TLS certificates. type Manager struct { src config.Source @@ -85,6 +91,35 @@ func (mgr *Manager) update(cfg *config.Config) error { return mgr.updateAutocert(cfg) } +// obtainCert obtains a certificate for given domain, use cached manager if cert exists there. +func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) { + cert, err := cm.CacheManagedCertificate(domain) + if err != nil { + log.Info().Str("domain", domain).Msg("obtaining certificate") + err = cm.ObtainCert(context.Background(), domain, false) + if err != nil { + log.Error().Err(err).Msg("autocert failed to obtain client certificate") + return certmagic.Certificate{}, errObtainCertFailed + } + cert, err = cm.CacheManagedCertificate(domain) + } + return cert, err +} + +// renewCert attempts to renew given certificate. +func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { + expired := time.Now().After(cert.Leaf.NotAfter) + log.Info().Str("domain", domain).Msg("renewing certificate") + err := cm.RenewCert(context.Background(), domain, false) + if err != nil { + if expired { + return certmagic.Certificate{}, errRenewCertFailed + } + log.Warn().Err(err).Msg("renew client certificated failed, use existing cert") + } + return cm.CacheManagedCertificate(domain) +} + func (mgr *Manager) updateAutocert(cfg *config.Config) error { if !cfg.Options.AutocertOptions.Enable { return nil @@ -96,33 +131,22 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error { } for _, domain := range sourceHostnames(cfg) { - cert, err := cm.CacheManagedCertificate(domain) - if err != nil { - log.Info().Str("domain", domain).Msg("obtaining certificate") - err = cm.ObtainCert(context.Background(), domain, false) - if err != nil { - return fmt.Errorf("autocert: failed to obtain client certificate: %w", err) - } - cert, err = cm.CacheManagedCertificate(domain) + cert, err := mgr.obtainCert(domain, cm) + if err != nil && errors.Is(err, errObtainCertFailed) { + return fmt.Errorf("autocert: failed to obtain client certificate: %w", err) } if err == nil && cert.NeedsRenewal(cm) { - expired := time.Now().After(cert.Leaf.NotAfter) - log.Info().Str("domain", domain).Msg("renewing certificate") - err = cm.RenewCert(context.Background(), domain, false) - if err != nil && expired { - return fmt.Errorf("autocert: failed to renew client certificate: %w", err) - } - if !expired { - log.Warn().Err(err).Msg("renew client certificated failed, use existing cert") - } - cert, err = cm.CacheManagedCertificate(domain) + cert, err = mgr.renewCert(domain, cert, cm) } - if err == nil { - log.Info().Strs("names", cert.Names).Msg("autocert: added certificate") - cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate) - } else { + if err != nil && errors.Is(err, errRenewCertFailed) { + return fmt.Errorf("autocert: failed to renew client certificate: %w", err) + } + if err != nil { log.Error().Err(err).Msg("autocert: failed to obtain client certificate") + continue } + log.Info().Strs("names", cert.Names).Msg("autocert: added certificate") + cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate) } return nil