diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 33bc2a82b..a8992e385 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -12,6 +12,7 @@ import ( "time" "github.com/caddyserver/certmagic" + "github.com/rs/zerolog" "go.uber.org/zap" "github.com/pomerium/pomerium/config" @@ -28,6 +29,12 @@ var ( renewCertLock sync.Mutex ) +const ( + ocspRespCacheSize = 50000 + renewalInterval = time.Minute * 10 + renewalTimeout = time.Hour +) + // Manager manages TLS certificates. type Manager struct { src config.Source @@ -39,12 +46,14 @@ type Manager struct { acmeMgr atomic.Value srv *http.Server + *ocspCache + config.ChangeDispatcher } // New creates a new autocert manager. func New(src config.Source) (*Manager, error) { - return newManager(context.Background(), src, certmagic.DefaultACME, time.Minute*10) + return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval) } func newManager(ctx context.Context, @@ -52,6 +61,15 @@ func newManager(ctx context.Context, acmeTemplate certmagic.ACMEManager, checkInterval time.Duration, ) (*Manager, error) { + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + return c.Str("service", "autocert-manager") + }) + + ocspRespCache, err := newOCSPCache(ocspRespCacheSize) + if err != nil { + return nil, err + } + certmagicConfig := certmagic.NewDefault() // set certmagic default storage cache, otherwise cert renewal loop will be based off // certmagic's own default location @@ -67,13 +85,14 @@ func newManager(ctx context.Context, src: src, acmeTemplate: acmeTemplate, certmagic: certmagicConfig, + ocspCache: ocspRespCache, } - err := mgr.update(src.GetConfig()) + err = mgr.update(ctx, src.GetConfig()) if err != nil { return nil, err } mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) { - err := mgr.update(cfg) + err := mgr.update(ctx, cfg) if err != nil { log.Error(ctx).Err(err).Msg("autocert: error updating config") return @@ -91,9 +110,9 @@ func newManager(ctx context.Context, case <-ctx.Done(): return case <-ticker.C: - err := mgr.renewConfigCerts() + err := mgr.renewConfigCerts(ctx) if err != nil { - log.Error(context.TODO()).Err(err).Msg("autocert: error updating config") + log.Error(ctx).Err(err).Msg("autocert: error updating config") return } } @@ -128,7 +147,10 @@ func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, e return mgr.certmagic, nil } -func (mgr *Manager) renewConfigCerts() error { +func (mgr *Manager) renewConfigCerts(ctx context.Context) error { + ctx, cancel := context.WithTimeout(ctx, renewalTimeout) + defer cancel() + mgr.mu.Lock() defer mgr.mu.Unlock() @@ -138,47 +160,69 @@ func (mgr *Manager) renewConfigCerts() error { return err } - needsRenewal := false + needsReload := false + var renew, ocsp []string + log.Debug(ctx).Strs("domains", sourceHostnames(cfg)).Msg("checking domains") for _, domain := range sourceHostnames(cfg) { cert, err := cm.CacheManagedCertificate(domain) - if err == nil && cert.NeedsRenewal(cm) { - needsRenewal = true + if err != nil { + log.Error(ctx).Err(err).Str("domain", domain).Msg("get cert") + continue + } + if cert.NeedsRenewal(cm) { + renew = append(renew, domain) + needsReload = true + } + if mgr.ocspCache.updated(domain, cert.OCSPStaple) { + ocsp = append(ocsp, domain) + needsReload = true } } - if !needsRenewal { + if !needsReload { return nil } + ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context { + if len(renew) > 0 { + c = c.Strs("renew_domains", renew) + } + if len(ocsp) > 0 { + c = c.Strs("ocsp_refresh", ocsp) + } + return c + }) + log.Info(ctx).Msg("updating certificates") + cfg = mgr.src.GetConfig().Clone() - mgr.updateServer(cfg) - if err := mgr.updateAutocert(cfg); err != nil { + mgr.updateServer(ctx, cfg) + if err := mgr.updateAutocert(ctx, cfg); err != nil { return err } mgr.config = cfg - mgr.Trigger(context.TODO(), cfg) + mgr.Trigger(ctx, cfg) return nil } -func (mgr *Manager) update(cfg *config.Config) error { +func (mgr *Manager) update(ctx context.Context, cfg *config.Config) error { cfg = cfg.Clone() mgr.mu.Lock() defer mgr.mu.Unlock() defer func() { mgr.config = cfg }() - mgr.updateServer(cfg) - return mgr.updateAutocert(cfg) + mgr.updateServer(ctx, cfg) + return mgr.updateAutocert(ctx, cfg) } // obtainCert obtains a certificate for given domain, use cached manager if cert exists there. -func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) { +func (mgr *Manager) obtainCert(ctx context.Context, domain string, cm *certmagic.Config) (certmagic.Certificate, error) { cert, err := cm.CacheManagedCertificate(domain) if err != nil { - log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate") - err = cm.ObtainCert(context.Background(), domain, false) + log.Info(ctx).Str("domain", domain).Msg("obtaining certificate") + err = cm.ObtainCert(ctx, domain, false) if err != nil { - log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate") + log.Error(ctx).Err(err).Msg("autocert failed to obtain client certificate") return certmagic.Certificate{}, errObtainCertFailed } metrics.RecordAutocertRenewal() @@ -188,22 +232,22 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C } // renewCert attempts to renew given certificate. -func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { +func (mgr *Manager) renewCert(ctx context.Context, domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { expired := time.Now().After(cert.Leaf.NotAfter) - log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate") + log.Info(ctx).Str("domain", domain).Msg("renewing certificate") renewCertLock.Lock() - err := cm.RenewCert(context.Background(), domain, false) + err := cm.RenewCert(ctx, domain, false) renewCertLock.Unlock() if err != nil { if expired { return certmagic.Certificate{}, errRenewCertFailed } - log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert") + log.Warn(ctx).Err(err).Msg("renew client certificated failed, use existing cert") } return cm.CacheManagedCertificate(domain) } -func (mgr *Manager) updateAutocert(cfg *config.Config) error { +func (mgr *Manager) updateAutocert(ctx context.Context, cfg *config.Config) error { if !cfg.Options.AutocertOptions.Enable { return nil } @@ -214,22 +258,22 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error { } for _, domain := range sourceHostnames(cfg) { - cert, err := mgr.obtainCert(domain, cm) + cert, err := mgr.obtainCert(ctx, domain, cm) if err != nil && errors.Is(err, errObtainCertFailed) { return fmt.Errorf("autocert: failed to obtain client certificate: %w", err) } if err == nil && cert.NeedsRenewal(cm) { - cert, err = mgr.renewCert(domain, cert, cm) + cert, err = mgr.renewCert(ctx, domain, cert, cm) } if err != nil && errors.Is(err, errRenewCertFailed) { return fmt.Errorf("autocert: failed to renew client certificate: %w", err) } if err != nil { - log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate") + log.Error(ctx).Err(err).Msg("autocert: failed to obtain client certificate") continue } - log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate") + log.Info(ctx).Strs("names", cert.Names).Msg("autocert: added certificate") cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate) } @@ -238,7 +282,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error { return nil } -func (mgr *Manager) updateServer(cfg *config.Config) { +func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) { if mgr.srv != nil { // nothing to do if the address hasn't changed if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr { @@ -265,10 +309,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) { }), } go func() { - log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server") + log.Info(ctx).Str("addr", hsrv.Addr).Msg("starting http redirect server") err := hsrv.ListenAndServe() if err != nil { - log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server") + log.Error(ctx).Err(err).Msg("failed to run http redirect server") } }() mgr.srv = hsrv diff --git a/internal/autocert/manager_test.go b/internal/autocert/manager_test.go index d9f0814e6..5ab201908 100644 --- a/internal/autocert/manager_test.go +++ b/internal/autocert/manager_test.go @@ -3,9 +3,9 @@ package autocert import ( "bytes" "context" + "crypto/ecdsa" + "crypto/elliptic" "crypto/rand" - "crypto/rsa" - "crypto/tls" "crypto/x509" "crypto/x509/pkix" "encoding/base64" @@ -25,19 +25,70 @@ import ( "github.com/caddyserver/certmagic" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" - "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "golang.org/x/crypto/ocsp" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/log" ) type M = map[string]interface{} -func newMockACME(srv *httptest.Server) http.Handler { +type testCA struct { + key *ecdsa.PrivateKey + cert *x509.Certificate + certPEM []byte +} + +func newTestCA() (*testCA, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + tpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().Unix()), + Subject: pkix.Name{ + CommonName: "Test CA", + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Minute * 10), + + KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, + BasicConstraintsValid: true, + IsCA: true, + } + + der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) + if err != nil { + return nil, err + } + cert, err := x509.ParseCertificate(der) + if err != nil { + return nil, err + } + + return &testCA{ + key, + cert, + pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), + }, nil +} + +func newMockACME(ca *testCA, srv *httptest.Server) http.Handler { var certBuffer bytes.Buffer + var certs []*x509.Certificate + findCert := func(serial *big.Int) *x509.Certificate { + for _, c := range certs { + if c.SerialNumber.Cmp(serial) == 0 { + return c + } + } + return nil + } + r := chi.NewRouter() r.Use(middleware.Logger) r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) { @@ -78,6 +129,22 @@ func newMockACME(srv *httptest.Server) http.Handler { "finalize": srv.URL + "/acme/finalize", }) }) + r.Post("/ocsp/request", func(w http.ResponseWriter, r *http.Request) { + reqData, _ := io.ReadAll(r.Body) + ocspReq, _ := ocsp.ParseRequest(reqData) + ocspResp := ocsp.Response{ + Status: ocsp.Good, + SerialNumber: ocspReq.SerialNumber, + ThisUpdate: time.Now(), + NextUpdate: time.Now().Add(time.Second), + } + + cert := findCert(ocspReq.SerialNumber) + data, _ := ocsp.CreateResponse(ca.cert, cert, ocspResp, ca.key) + + w.WriteHeader(http.StatusOK) + _, _ = w.Write(data) + }) r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) { var payload struct { CSR string `json:"csr"` @@ -85,7 +152,6 @@ func newMockACME(srv *httptest.Server) http.Handler { readJWSPayload(r.Body, &payload) bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR) csr, _ := x509.ParseCertificateRequest(bs) - caKey, _ := rsa.GenerateKey(rand.Reader, 2048) tpl := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().Unix()), DNSNames: csr.DNSNames, @@ -100,10 +166,15 @@ func newMockACME(srv *httptest.Server) http.Handler { ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, IsCA: false, + + IssuingCertificateURL: []string{srv.URL + "/certs/ca"}, + OCSPServer: []string{srv.URL + "/ocsp/request"}, } - der, _ := x509.CreateCertificate(rand.Reader, tpl, tpl, csr.PublicKey, caKey) + der, _ := x509.CreateCertificate(rand.Reader, tpl, ca.cert, csr.PublicKey, ca.key) certBuffer.Reset() _ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der}) + cert, _ := x509.ParseCertificate(der) + certs = append(certs, cert) w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Content-Type", "application/json") @@ -120,6 +191,11 @@ func newMockACME(srv *httptest.Server) http.Handler { w.WriteHeader(http.StatusOK) _, _ = w.Write(certBuffer.Bytes()) }) + r.Get("/certs/ca", func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/pkix-cert") + w.WriteHeader(http.StatusOK) + _, _ = w.Write(ca.cert.Raw) + }) return r } @@ -132,7 +208,11 @@ func TestConfig(t *testing.T) { mockACME.ServeHTTP(w, r) })) defer srv.Close() - mockACME = newMockACME(srv) + + ca, err := newTestCA() + require.NoError(t, err) + + mockACME = newMockACME(ca, srv) tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) _ = os.MkdirAll(tmpdir, 0o755) @@ -158,7 +238,7 @@ func TestConfig(t *testing.T) { AutocertOptions: config.AutocertOptions{ Enable: true, UseStaging: true, - MustStaple: false, + MustStaple: true, Folder: tmpdir, }, HTTPRedirectAddr: addr, @@ -172,21 +252,44 @@ func TestConfig(t *testing.T) { return } - var certs []tls.Certificate - for i := 0; i < 10; i++ { - cfg := mgr.GetConfig() - assert.LessOrEqual(t, len(cfg.AutoCertificates), 1) - if len(cfg.AutoCertificates) == 1 && certs == nil { - certs = cfg.AutoCertificates - } + domainRenewed := make(chan bool) + ocspUpdated := make(chan bool) - if !cmp.Equal(certs, cfg.AutoCertificates) { + var initialOCSPStaple []byte + var certValidTime *time.Time + mgr.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) { + log.Info(ctx).Msg("OnConfigChange") + cert := cfg.AutoCertificates[0] + if initialOCSPStaple == nil { + initialOCSPStaple = cert.OCSPStaple + } else { + if bytes.Compare(initialOCSPStaple, cert.OCSPStaple) != 0 { + log.Info(ctx).Msg("OCSP updated") + ocspUpdated <- true + } + } + if certValidTime == nil { + certValidTime = &cert.Leaf.NotAfter + } else { + if !certValidTime.Equal(cert.Leaf.NotAfter) { + log.Info(ctx).Msg("domain renewed") + domainRenewed <- true + } + } + }) + + domainRenewedOK := false + ocspUpdatedOK := false + + for !domainRenewedOK || !ocspUpdatedOK { + select { + case <-time.After(time.Second * 10): + t.Error("timeout waiting for certs renewal") return + case domainRenewedOK = <-domainRenewed: + case ocspUpdatedOK = <-ocspUpdated: } - - time.Sleep(time.Second) } - t.Fatalf("expected renewed certs, but certs never changed") } func TestRedirect(t *testing.T) { diff --git a/internal/autocert/ocsp.go b/internal/autocert/ocsp.go new file mode 100644 index 000000000..067f40907 --- /dev/null +++ b/internal/autocert/ocsp.go @@ -0,0 +1,33 @@ +package autocert + +import ( + "bytes" + + lru "github.com/hashicorp/golang-lru" +) + +type ocspCache struct { + *lru.Cache +} + +func newOCSPCache(size int) (*ocspCache, error) { + c, err := lru.New(size) + if err != nil { + return nil, err + } + return &ocspCache{c}, nil +} + +// updated checks if OCSP response for this certificate was updated +func (c ocspCache) updated(key string, ocspResp []byte) bool { + current, there := c.Get(key) + if !there { + _ = c.Add(key, ocspResp) + return false // to avoid triggering reload first time we see this response + } + if bytes.Equal(current.([]byte), ocspResp) { + return false + } + _ = c.Add(key, ocspResp) + return true +} diff --git a/internal/autocert/ocsp_test.go b/internal/autocert/ocsp_test.go new file mode 100644 index 000000000..9a935b5d6 --- /dev/null +++ b/internal/autocert/ocsp_test.go @@ -0,0 +1,31 @@ +package autocert + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestOcspCache(t *testing.T) { + c, err := newOCSPCache(10) + require.NoError(t, err) + + cases := []struct { + data []byte + isUpdated bool + }{ + {nil, false}, + {nil, false}, + {[]byte("a"), true}, + {[]byte("a"), false}, + {[]byte("b"), true}, + {[]byte("b"), false}, + {nil, true}, + {nil, false}, + } + + for i, tc := range cases { + assert.Equal(t, tc.isUpdated, c.updated("key", tc.data), "#%d: %v", i, tc) + } +}