ocsp: reload on ocsp response changes (#2286)

This commit is contained in:
wasaga 2021-06-11 15:58:01 -04:00 committed by GitHub
parent f9675f61cc
commit b372ab4bcc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 262 additions and 51 deletions

View file

@ -12,6 +12,7 @@ import (
"time" "time"
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
"github.com/rs/zerolog"
"go.uber.org/zap" "go.uber.org/zap"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -28,6 +29,12 @@ var (
renewCertLock sync.Mutex renewCertLock sync.Mutex
) )
const (
ocspRespCacheSize = 50000
renewalInterval = time.Minute * 10
renewalTimeout = time.Hour
)
// Manager manages TLS certificates. // Manager manages TLS certificates.
type Manager struct { type Manager struct {
src config.Source src config.Source
@ -39,12 +46,14 @@ type Manager struct {
acmeMgr atomic.Value acmeMgr atomic.Value
srv *http.Server srv *http.Server
*ocspCache
config.ChangeDispatcher config.ChangeDispatcher
} }
// New creates a new autocert manager. // New creates a new autocert manager.
func New(src config.Source) (*Manager, error) { func New(src config.Source) (*Manager, error) {
return newManager(context.Background(), src, certmagic.DefaultACME, time.Minute*10) return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval)
} }
func newManager(ctx context.Context, func newManager(ctx context.Context,
@ -52,6 +61,15 @@ func newManager(ctx context.Context,
acmeTemplate certmagic.ACMEManager, acmeTemplate certmagic.ACMEManager,
checkInterval time.Duration, checkInterval time.Duration,
) (*Manager, error) { ) (*Manager, error) {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("service", "autocert-manager")
})
ocspRespCache, err := newOCSPCache(ocspRespCacheSize)
if err != nil {
return nil, err
}
certmagicConfig := certmagic.NewDefault() certmagicConfig := certmagic.NewDefault()
// set certmagic default storage cache, otherwise cert renewal loop will be based off // set certmagic default storage cache, otherwise cert renewal loop will be based off
// certmagic's own default location // certmagic's own default location
@ -67,13 +85,14 @@ func newManager(ctx context.Context,
src: src, src: src,
acmeTemplate: acmeTemplate, acmeTemplate: acmeTemplate,
certmagic: certmagicConfig, certmagic: certmagicConfig,
ocspCache: ocspRespCache,
} }
err := mgr.update(src.GetConfig()) err = mgr.update(ctx, src.GetConfig())
if err != nil { if err != nil {
return nil, err return nil, err
} }
mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) { mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
err := mgr.update(cfg) err := mgr.update(ctx, cfg)
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("autocert: error updating config") log.Error(ctx).Err(err).Msg("autocert: error updating config")
return return
@ -91,9 +110,9 @@ func newManager(ctx context.Context,
case <-ctx.Done(): case <-ctx.Done():
return return
case <-ticker.C: case <-ticker.C:
err := mgr.renewConfigCerts() err := mgr.renewConfigCerts(ctx)
if err != nil { if err != nil {
log.Error(context.TODO()).Err(err).Msg("autocert: error updating config") log.Error(ctx).Err(err).Msg("autocert: error updating config")
return return
} }
} }
@ -128,7 +147,10 @@ func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, e
return mgr.certmagic, nil return mgr.certmagic, nil
} }
func (mgr *Manager) renewConfigCerts() error { func (mgr *Manager) renewConfigCerts(ctx context.Context) error {
ctx, cancel := context.WithTimeout(ctx, renewalTimeout)
defer cancel()
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
@ -138,47 +160,69 @@ func (mgr *Manager) renewConfigCerts() error {
return err return err
} }
needsRenewal := false needsReload := false
var renew, ocsp []string
log.Debug(ctx).Strs("domains", sourceHostnames(cfg)).Msg("checking domains")
for _, domain := range sourceHostnames(cfg) { for _, domain := range sourceHostnames(cfg) {
cert, err := cm.CacheManagedCertificate(domain) cert, err := cm.CacheManagedCertificate(domain)
if err == nil && cert.NeedsRenewal(cm) { if err != nil {
needsRenewal = true log.Error(ctx).Err(err).Str("domain", domain).Msg("get cert")
continue
}
if cert.NeedsRenewal(cm) {
renew = append(renew, domain)
needsReload = true
}
if mgr.ocspCache.updated(domain, cert.OCSPStaple) {
ocsp = append(ocsp, domain)
needsReload = true
} }
} }
if !needsRenewal { if !needsReload {
return nil return nil
} }
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
if len(renew) > 0 {
c = c.Strs("renew_domains", renew)
}
if len(ocsp) > 0 {
c = c.Strs("ocsp_refresh", ocsp)
}
return c
})
log.Info(ctx).Msg("updating certificates")
cfg = mgr.src.GetConfig().Clone() cfg = mgr.src.GetConfig().Clone()
mgr.updateServer(cfg) mgr.updateServer(ctx, cfg)
if err := mgr.updateAutocert(cfg); err != nil { if err := mgr.updateAutocert(ctx, cfg); err != nil {
return err return err
} }
mgr.config = cfg mgr.config = cfg
mgr.Trigger(context.TODO(), cfg) mgr.Trigger(ctx, cfg)
return nil return nil
} }
func (mgr *Manager) update(cfg *config.Config) error { func (mgr *Manager) update(ctx context.Context, cfg *config.Config) error {
cfg = cfg.Clone() cfg = cfg.Clone()
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
defer func() { mgr.config = cfg }() defer func() { mgr.config = cfg }()
mgr.updateServer(cfg) mgr.updateServer(ctx, cfg)
return mgr.updateAutocert(cfg) return mgr.updateAutocert(ctx, cfg)
} }
// obtainCert obtains a certificate for given domain, use cached manager if cert exists there. // obtainCert obtains a certificate for given domain, use cached manager if cert exists there.
func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) { func (mgr *Manager) obtainCert(ctx context.Context, domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
cert, err := cm.CacheManagedCertificate(domain) cert, err := cm.CacheManagedCertificate(domain)
if err != nil { if err != nil {
log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate") log.Info(ctx).Str("domain", domain).Msg("obtaining certificate")
err = cm.ObtainCert(context.Background(), domain, false) err = cm.ObtainCert(ctx, domain, false)
if err != nil { if err != nil {
log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate") log.Error(ctx).Err(err).Msg("autocert failed to obtain client certificate")
return certmagic.Certificate{}, errObtainCertFailed return certmagic.Certificate{}, errObtainCertFailed
} }
metrics.RecordAutocertRenewal() metrics.RecordAutocertRenewal()
@ -188,22 +232,22 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
} }
// renewCert attempts to renew given certificate. // renewCert attempts to renew given certificate.
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { func (mgr *Manager) renewCert(ctx context.Context, domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
expired := time.Now().After(cert.Leaf.NotAfter) expired := time.Now().After(cert.Leaf.NotAfter)
log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate") log.Info(ctx).Str("domain", domain).Msg("renewing certificate")
renewCertLock.Lock() renewCertLock.Lock()
err := cm.RenewCert(context.Background(), domain, false) err := cm.RenewCert(ctx, domain, false)
renewCertLock.Unlock() renewCertLock.Unlock()
if err != nil { if err != nil {
if expired { if expired {
return certmagic.Certificate{}, errRenewCertFailed return certmagic.Certificate{}, errRenewCertFailed
} }
log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert") log.Warn(ctx).Err(err).Msg("renew client certificated failed, use existing cert")
} }
return cm.CacheManagedCertificate(domain) return cm.CacheManagedCertificate(domain)
} }
func (mgr *Manager) updateAutocert(cfg *config.Config) error { func (mgr *Manager) updateAutocert(ctx context.Context, cfg *config.Config) error {
if !cfg.Options.AutocertOptions.Enable { if !cfg.Options.AutocertOptions.Enable {
return nil return nil
} }
@ -214,22 +258,22 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
} }
for _, domain := range sourceHostnames(cfg) { for _, domain := range sourceHostnames(cfg) {
cert, err := mgr.obtainCert(domain, cm) cert, err := mgr.obtainCert(ctx, domain, cm)
if err != nil && errors.Is(err, errObtainCertFailed) { if err != nil && errors.Is(err, errObtainCertFailed) {
return fmt.Errorf("autocert: failed to obtain client certificate: %w", err) return fmt.Errorf("autocert: failed to obtain client certificate: %w", err)
} }
if err == nil && cert.NeedsRenewal(cm) { if err == nil && cert.NeedsRenewal(cm) {
cert, err = mgr.renewCert(domain, cert, cm) cert, err = mgr.renewCert(ctx, domain, cert, cm)
} }
if err != nil && errors.Is(err, errRenewCertFailed) { if err != nil && errors.Is(err, errRenewCertFailed) {
return fmt.Errorf("autocert: failed to renew client certificate: %w", err) return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
} }
if err != nil { if err != nil {
log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate") log.Error(ctx).Err(err).Msg("autocert: failed to obtain client certificate")
continue continue
} }
log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate") log.Info(ctx).Strs("names", cert.Names).Msg("autocert: added certificate")
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate) cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
} }
@ -238,7 +282,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
return nil return nil
} }
func (mgr *Manager) updateServer(cfg *config.Config) { func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) {
if mgr.srv != nil { if mgr.srv != nil {
// nothing to do if the address hasn't changed // nothing to do if the address hasn't changed
if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr { if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr {
@ -265,10 +309,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) {
}), }),
} }
go func() { go func() {
log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server") log.Info(ctx).Str("addr", hsrv.Addr).Msg("starting http redirect server")
err := hsrv.ListenAndServe() err := hsrv.ListenAndServe()
if err != nil { if err != nil {
log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server") log.Error(ctx).Err(err).Msg("failed to run http redirect server")
} }
}() }()
mgr.srv = hsrv mgr.srv = hsrv

View file

@ -3,9 +3,9 @@ package autocert
import ( import (
"bytes" "bytes"
"context" "context"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand" "crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509" "crypto/x509"
"crypto/x509/pkix" "crypto/x509/pkix"
"encoding/base64" "encoding/base64"
@ -25,19 +25,70 @@ import (
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
"github.com/go-chi/chi" "github.com/go-chi/chi"
"github.com/go-chi/chi/middleware" "github.com/go-chi/chi/middleware"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/ocsp"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
) )
type M = map[string]interface{} type M = map[string]interface{}
func newMockACME(srv *httptest.Server) http.Handler { type testCA struct {
key *ecdsa.PrivateKey
cert *x509.Certificate
certPEM []byte
}
func newTestCA() (*testCA, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, err
}
tpl := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().Unix()),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now(),
NotAfter: time.Now().Add(time.Minute * 10),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
BasicConstraintsValid: true,
IsCA: true,
}
der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key)
if err != nil {
return nil, err
}
cert, err := x509.ParseCertificate(der)
if err != nil {
return nil, err
}
return &testCA{
key,
cert,
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}),
}, nil
}
func newMockACME(ca *testCA, srv *httptest.Server) http.Handler {
var certBuffer bytes.Buffer var certBuffer bytes.Buffer
var certs []*x509.Certificate
findCert := func(serial *big.Int) *x509.Certificate {
for _, c := range certs {
if c.SerialNumber.Cmp(serial) == 0 {
return c
}
}
return nil
}
r := chi.NewRouter() r := chi.NewRouter()
r.Use(middleware.Logger) r.Use(middleware.Logger)
r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) { r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) {
@ -78,6 +129,22 @@ func newMockACME(srv *httptest.Server) http.Handler {
"finalize": srv.URL + "/acme/finalize", "finalize": srv.URL + "/acme/finalize",
}) })
}) })
r.Post("/ocsp/request", func(w http.ResponseWriter, r *http.Request) {
reqData, _ := io.ReadAll(r.Body)
ocspReq, _ := ocsp.ParseRequest(reqData)
ocspResp := ocsp.Response{
Status: ocsp.Good,
SerialNumber: ocspReq.SerialNumber,
ThisUpdate: time.Now(),
NextUpdate: time.Now().Add(time.Second),
}
cert := findCert(ocspReq.SerialNumber)
data, _ := ocsp.CreateResponse(ca.cert, cert, ocspResp, ca.key)
w.WriteHeader(http.StatusOK)
_, _ = w.Write(data)
})
r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) { r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) {
var payload struct { var payload struct {
CSR string `json:"csr"` CSR string `json:"csr"`
@ -85,7 +152,6 @@ func newMockACME(srv *httptest.Server) http.Handler {
readJWSPayload(r.Body, &payload) readJWSPayload(r.Body, &payload)
bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR) bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR)
csr, _ := x509.ParseCertificateRequest(bs) csr, _ := x509.ParseCertificateRequest(bs)
caKey, _ := rsa.GenerateKey(rand.Reader, 2048)
tpl := &x509.Certificate{ tpl := &x509.Certificate{
SerialNumber: big.NewInt(time.Now().Unix()), SerialNumber: big.NewInt(time.Now().Unix()),
DNSNames: csr.DNSNames, DNSNames: csr.DNSNames,
@ -100,10 +166,15 @@ func newMockACME(srv *httptest.Server) http.Handler {
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
BasicConstraintsValid: true, BasicConstraintsValid: true,
IsCA: false, IsCA: false,
IssuingCertificateURL: []string{srv.URL + "/certs/ca"},
OCSPServer: []string{srv.URL + "/ocsp/request"},
} }
der, _ := x509.CreateCertificate(rand.Reader, tpl, tpl, csr.PublicKey, caKey) der, _ := x509.CreateCertificate(rand.Reader, tpl, ca.cert, csr.PublicKey, ca.key)
certBuffer.Reset() certBuffer.Reset()
_ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der}) _ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der})
cert, _ := x509.ParseCertificate(der)
certs = append(certs, cert)
w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Replay-Nonce", "NONCE")
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
@ -120,6 +191,11 @@ func newMockACME(srv *httptest.Server) http.Handler {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
_, _ = w.Write(certBuffer.Bytes()) _, _ = w.Write(certBuffer.Bytes())
}) })
r.Get("/certs/ca", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/pkix-cert")
w.WriteHeader(http.StatusOK)
_, _ = w.Write(ca.cert.Raw)
})
return r return r
} }
@ -132,7 +208,11 @@ func TestConfig(t *testing.T) {
mockACME.ServeHTTP(w, r) mockACME.ServeHTTP(w, r)
})) }))
defer srv.Close() defer srv.Close()
mockACME = newMockACME(srv)
ca, err := newTestCA()
require.NoError(t, err)
mockACME = newMockACME(ca, srv)
tmpdir := filepath.Join(os.TempDir(), uuid.New().String()) tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
_ = os.MkdirAll(tmpdir, 0o755) _ = os.MkdirAll(tmpdir, 0o755)
@ -158,7 +238,7 @@ func TestConfig(t *testing.T) {
AutocertOptions: config.AutocertOptions{ AutocertOptions: config.AutocertOptions{
Enable: true, Enable: true,
UseStaging: true, UseStaging: true,
MustStaple: false, MustStaple: true,
Folder: tmpdir, Folder: tmpdir,
}, },
HTTPRedirectAddr: addr, HTTPRedirectAddr: addr,
@ -172,21 +252,44 @@ func TestConfig(t *testing.T) {
return return
} }
var certs []tls.Certificate domainRenewed := make(chan bool)
for i := 0; i < 10; i++ { ocspUpdated := make(chan bool)
cfg := mgr.GetConfig()
assert.LessOrEqual(t, len(cfg.AutoCertificates), 1)
if len(cfg.AutoCertificates) == 1 && certs == nil {
certs = cfg.AutoCertificates
}
if !cmp.Equal(certs, cfg.AutoCertificates) { var initialOCSPStaple []byte
var certValidTime *time.Time
mgr.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
log.Info(ctx).Msg("OnConfigChange")
cert := cfg.AutoCertificates[0]
if initialOCSPStaple == nil {
initialOCSPStaple = cert.OCSPStaple
} else {
if bytes.Compare(initialOCSPStaple, cert.OCSPStaple) != 0 {
log.Info(ctx).Msg("OCSP updated")
ocspUpdated <- true
}
}
if certValidTime == nil {
certValidTime = &cert.Leaf.NotAfter
} else {
if !certValidTime.Equal(cert.Leaf.NotAfter) {
log.Info(ctx).Msg("domain renewed")
domainRenewed <- true
}
}
})
domainRenewedOK := false
ocspUpdatedOK := false
for !domainRenewedOK || !ocspUpdatedOK {
select {
case <-time.After(time.Second * 10):
t.Error("timeout waiting for certs renewal")
return return
case domainRenewedOK = <-domainRenewed:
case ocspUpdatedOK = <-ocspUpdated:
} }
time.Sleep(time.Second)
} }
t.Fatalf("expected renewed certs, but certs never changed")
} }
func TestRedirect(t *testing.T) { func TestRedirect(t *testing.T) {

33
internal/autocert/ocsp.go Normal file
View file

@ -0,0 +1,33 @@
package autocert
import (
"bytes"
lru "github.com/hashicorp/golang-lru"
)
type ocspCache struct {
*lru.Cache
}
func newOCSPCache(size int) (*ocspCache, error) {
c, err := lru.New(size)
if err != nil {
return nil, err
}
return &ocspCache{c}, nil
}
// updated checks if OCSP response for this certificate was updated
func (c ocspCache) updated(key string, ocspResp []byte) bool {
current, there := c.Get(key)
if !there {
_ = c.Add(key, ocspResp)
return false // to avoid triggering reload first time we see this response
}
if bytes.Equal(current.([]byte), ocspResp) {
return false
}
_ = c.Add(key, ocspResp)
return true
}

View file

@ -0,0 +1,31 @@
package autocert
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOcspCache(t *testing.T) {
c, err := newOCSPCache(10)
require.NoError(t, err)
cases := []struct {
data []byte
isUpdated bool
}{
{nil, false},
{nil, false},
{[]byte("a"), true},
{[]byte("a"), false},
{[]byte("b"), true},
{[]byte("b"), false},
{nil, true},
{nil, false},
}
for i, tc := range cases {
assert.Equal(t, tc.isUpdated, c.updated("key", tc.data), "#%d: %v", i, tc)
}
}