mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
ocsp: reload on ocsp response changes (#2286)
This commit is contained in:
parent
f9675f61cc
commit
b372ab4bcc
4 changed files with 262 additions and 51 deletions
|
@ -12,6 +12,7 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/caddyserver/certmagic"
|
"github.com/caddyserver/certmagic"
|
||||||
|
"github.com/rs/zerolog"
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
@ -28,6 +29,12 @@ var (
|
||||||
renewCertLock sync.Mutex
|
renewCertLock sync.Mutex
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
ocspRespCacheSize = 50000
|
||||||
|
renewalInterval = time.Minute * 10
|
||||||
|
renewalTimeout = time.Hour
|
||||||
|
)
|
||||||
|
|
||||||
// Manager manages TLS certificates.
|
// Manager manages TLS certificates.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
src config.Source
|
src config.Source
|
||||||
|
@ -39,12 +46,14 @@ type Manager struct {
|
||||||
acmeMgr atomic.Value
|
acmeMgr atomic.Value
|
||||||
srv *http.Server
|
srv *http.Server
|
||||||
|
|
||||||
|
*ocspCache
|
||||||
|
|
||||||
config.ChangeDispatcher
|
config.ChangeDispatcher
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new autocert manager.
|
// New creates a new autocert manager.
|
||||||
func New(src config.Source) (*Manager, error) {
|
func New(src config.Source) (*Manager, error) {
|
||||||
return newManager(context.Background(), src, certmagic.DefaultACME, time.Minute*10)
|
return newManager(context.Background(), src, certmagic.DefaultACME, renewalInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
func newManager(ctx context.Context,
|
func newManager(ctx context.Context,
|
||||||
|
@ -52,6 +61,15 @@ func newManager(ctx context.Context,
|
||||||
acmeTemplate certmagic.ACMEManager,
|
acmeTemplate certmagic.ACMEManager,
|
||||||
checkInterval time.Duration,
|
checkInterval time.Duration,
|
||||||
) (*Manager, error) {
|
) (*Manager, error) {
|
||||||
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
|
return c.Str("service", "autocert-manager")
|
||||||
|
})
|
||||||
|
|
||||||
|
ocspRespCache, err := newOCSPCache(ocspRespCacheSize)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
certmagicConfig := certmagic.NewDefault()
|
certmagicConfig := certmagic.NewDefault()
|
||||||
// set certmagic default storage cache, otherwise cert renewal loop will be based off
|
// set certmagic default storage cache, otherwise cert renewal loop will be based off
|
||||||
// certmagic's own default location
|
// certmagic's own default location
|
||||||
|
@ -67,13 +85,14 @@ func newManager(ctx context.Context,
|
||||||
src: src,
|
src: src,
|
||||||
acmeTemplate: acmeTemplate,
|
acmeTemplate: acmeTemplate,
|
||||||
certmagic: certmagicConfig,
|
certmagic: certmagicConfig,
|
||||||
|
ocspCache: ocspRespCache,
|
||||||
}
|
}
|
||||||
err := mgr.update(src.GetConfig())
|
err = mgr.update(ctx, src.GetConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
mgr.src.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
||||||
err := mgr.update(cfg)
|
err := mgr.update(ctx, cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("autocert: error updating config")
|
log.Error(ctx).Err(err).Msg("autocert: error updating config")
|
||||||
return
|
return
|
||||||
|
@ -91,9 +110,9 @@ func newManager(ctx context.Context,
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
err := mgr.renewConfigCerts()
|
err := mgr.renewConfigCerts(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(context.TODO()).Err(err).Msg("autocert: error updating config")
|
log.Error(ctx).Err(err).Msg("autocert: error updating config")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -128,7 +147,10 @@ func (mgr *Manager) getCertMagicConfig(cfg *config.Config) (*certmagic.Config, e
|
||||||
return mgr.certmagic, nil
|
return mgr.certmagic, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *Manager) renewConfigCerts() error {
|
func (mgr *Manager) renewConfigCerts(ctx context.Context) error {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, renewalTimeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
mgr.mu.Lock()
|
mgr.mu.Lock()
|
||||||
defer mgr.mu.Unlock()
|
defer mgr.mu.Unlock()
|
||||||
|
|
||||||
|
@ -138,47 +160,69 @@ func (mgr *Manager) renewConfigCerts() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
needsRenewal := false
|
needsReload := false
|
||||||
|
var renew, ocsp []string
|
||||||
|
log.Debug(ctx).Strs("domains", sourceHostnames(cfg)).Msg("checking domains")
|
||||||
for _, domain := range sourceHostnames(cfg) {
|
for _, domain := range sourceHostnames(cfg) {
|
||||||
cert, err := cm.CacheManagedCertificate(domain)
|
cert, err := cm.CacheManagedCertificate(domain)
|
||||||
if err == nil && cert.NeedsRenewal(cm) {
|
if err != nil {
|
||||||
needsRenewal = true
|
log.Error(ctx).Err(err).Str("domain", domain).Msg("get cert")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if cert.NeedsRenewal(cm) {
|
||||||
|
renew = append(renew, domain)
|
||||||
|
needsReload = true
|
||||||
|
}
|
||||||
|
if mgr.ocspCache.updated(domain, cert.OCSPStaple) {
|
||||||
|
ocsp = append(ocsp, domain)
|
||||||
|
needsReload = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !needsRenewal {
|
if !needsReload {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
||||||
|
if len(renew) > 0 {
|
||||||
|
c = c.Strs("renew_domains", renew)
|
||||||
|
}
|
||||||
|
if len(ocsp) > 0 {
|
||||||
|
c = c.Strs("ocsp_refresh", ocsp)
|
||||||
|
}
|
||||||
|
return c
|
||||||
|
})
|
||||||
|
log.Info(ctx).Msg("updating certificates")
|
||||||
|
|
||||||
cfg = mgr.src.GetConfig().Clone()
|
cfg = mgr.src.GetConfig().Clone()
|
||||||
mgr.updateServer(cfg)
|
mgr.updateServer(ctx, cfg)
|
||||||
if err := mgr.updateAutocert(cfg); err != nil {
|
if err := mgr.updateAutocert(ctx, cfg); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr.config = cfg
|
mgr.config = cfg
|
||||||
mgr.Trigger(context.TODO(), cfg)
|
mgr.Trigger(ctx, cfg)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *Manager) update(cfg *config.Config) error {
|
func (mgr *Manager) update(ctx context.Context, cfg *config.Config) error {
|
||||||
cfg = cfg.Clone()
|
cfg = cfg.Clone()
|
||||||
|
|
||||||
mgr.mu.Lock()
|
mgr.mu.Lock()
|
||||||
defer mgr.mu.Unlock()
|
defer mgr.mu.Unlock()
|
||||||
defer func() { mgr.config = cfg }()
|
defer func() { mgr.config = cfg }()
|
||||||
|
|
||||||
mgr.updateServer(cfg)
|
mgr.updateServer(ctx, cfg)
|
||||||
return mgr.updateAutocert(cfg)
|
return mgr.updateAutocert(ctx, cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// obtainCert obtains a certificate for given domain, use cached manager if cert exists there.
|
// obtainCert obtains a certificate for given domain, use cached manager if cert exists there.
|
||||||
func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
|
func (mgr *Manager) obtainCert(ctx context.Context, domain string, cm *certmagic.Config) (certmagic.Certificate, error) {
|
||||||
cert, err := cm.CacheManagedCertificate(domain)
|
cert, err := cm.CacheManagedCertificate(domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Info(context.TODO()).Str("domain", domain).Msg("obtaining certificate")
|
log.Info(ctx).Str("domain", domain).Msg("obtaining certificate")
|
||||||
err = cm.ObtainCert(context.Background(), domain, false)
|
err = cm.ObtainCert(ctx, domain, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(context.TODO()).Err(err).Msg("autocert failed to obtain client certificate")
|
log.Error(ctx).Err(err).Msg("autocert failed to obtain client certificate")
|
||||||
return certmagic.Certificate{}, errObtainCertFailed
|
return certmagic.Certificate{}, errObtainCertFailed
|
||||||
}
|
}
|
||||||
metrics.RecordAutocertRenewal()
|
metrics.RecordAutocertRenewal()
|
||||||
|
@ -188,22 +232,22 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
|
||||||
}
|
}
|
||||||
|
|
||||||
// renewCert attempts to renew given certificate.
|
// renewCert attempts to renew given certificate.
|
||||||
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
|
func (mgr *Manager) renewCert(ctx context.Context, domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
|
||||||
expired := time.Now().After(cert.Leaf.NotAfter)
|
expired := time.Now().After(cert.Leaf.NotAfter)
|
||||||
log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate")
|
log.Info(ctx).Str("domain", domain).Msg("renewing certificate")
|
||||||
renewCertLock.Lock()
|
renewCertLock.Lock()
|
||||||
err := cm.RenewCert(context.Background(), domain, false)
|
err := cm.RenewCert(ctx, domain, false)
|
||||||
renewCertLock.Unlock()
|
renewCertLock.Unlock()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if expired {
|
if expired {
|
||||||
return certmagic.Certificate{}, errRenewCertFailed
|
return certmagic.Certificate{}, errRenewCertFailed
|
||||||
}
|
}
|
||||||
log.Warn(context.TODO()).Err(err).Msg("renew client certificated failed, use existing cert")
|
log.Warn(ctx).Err(err).Msg("renew client certificated failed, use existing cert")
|
||||||
}
|
}
|
||||||
return cm.CacheManagedCertificate(domain)
|
return cm.CacheManagedCertificate(domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
func (mgr *Manager) updateAutocert(ctx context.Context, cfg *config.Config) error {
|
||||||
if !cfg.Options.AutocertOptions.Enable {
|
if !cfg.Options.AutocertOptions.Enable {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -214,22 +258,22 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range sourceHostnames(cfg) {
|
for _, domain := range sourceHostnames(cfg) {
|
||||||
cert, err := mgr.obtainCert(domain, cm)
|
cert, err := mgr.obtainCert(ctx, domain, cm)
|
||||||
if err != nil && errors.Is(err, errObtainCertFailed) {
|
if err != nil && errors.Is(err, errObtainCertFailed) {
|
||||||
return fmt.Errorf("autocert: failed to obtain client certificate: %w", err)
|
return fmt.Errorf("autocert: failed to obtain client certificate: %w", err)
|
||||||
}
|
}
|
||||||
if err == nil && cert.NeedsRenewal(cm) {
|
if err == nil && cert.NeedsRenewal(cm) {
|
||||||
cert, err = mgr.renewCert(domain, cert, cm)
|
cert, err = mgr.renewCert(ctx, domain, cert, cm)
|
||||||
}
|
}
|
||||||
if err != nil && errors.Is(err, errRenewCertFailed) {
|
if err != nil && errors.Is(err, errRenewCertFailed) {
|
||||||
return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
|
return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(context.TODO()).Err(err).Msg("autocert: failed to obtain client certificate")
|
log.Error(ctx).Err(err).Msg("autocert: failed to obtain client certificate")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info(context.TODO()).Strs("names", cert.Names).Msg("autocert: added certificate")
|
log.Info(ctx).Strs("names", cert.Names).Msg("autocert: added certificate")
|
||||||
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
|
cfg.AutoCertificates = append(cfg.AutoCertificates, cert.Certificate)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -238,7 +282,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *Manager) updateServer(cfg *config.Config) {
|
func (mgr *Manager) updateServer(ctx context.Context, cfg *config.Config) {
|
||||||
if mgr.srv != nil {
|
if mgr.srv != nil {
|
||||||
// nothing to do if the address hasn't changed
|
// nothing to do if the address hasn't changed
|
||||||
if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr {
|
if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr {
|
||||||
|
@ -265,10 +309,10 @@ func (mgr *Manager) updateServer(cfg *config.Config) {
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
go func() {
|
go func() {
|
||||||
log.Info(context.TODO()).Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
log.Info(ctx).Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
||||||
err := hsrv.ListenAndServe()
|
err := hsrv.ListenAndServe()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(context.TODO()).Err(err).Msg("failed to run http redirect server")
|
log.Error(ctx).Err(err).Msg("failed to run http redirect server")
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
mgr.srv = hsrv
|
mgr.srv = hsrv
|
||||||
|
|
|
@ -3,9 +3,9 @@ package autocert
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/ecdsa"
|
||||||
|
"crypto/elliptic"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"crypto/rsa"
|
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"crypto/x509/pkix"
|
"crypto/x509/pkix"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
|
@ -25,19 +25,70 @@ import (
|
||||||
"github.com/caddyserver/certmagic"
|
"github.com/caddyserver/certmagic"
|
||||||
"github.com/go-chi/chi"
|
"github.com/go-chi/chi"
|
||||||
"github.com/go-chi/chi/middleware"
|
"github.com/go-chi/chi/middleware"
|
||||||
"github.com/google/go-cmp/cmp"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
"golang.org/x/crypto/ocsp"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
type M = map[string]interface{}
|
type M = map[string]interface{}
|
||||||
|
|
||||||
func newMockACME(srv *httptest.Server) http.Handler {
|
type testCA struct {
|
||||||
|
key *ecdsa.PrivateKey
|
||||||
|
cert *x509.Certificate
|
||||||
|
certPEM []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestCA() (*testCA, error) {
|
||||||
|
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||||
|
Subject: pkix.Name{
|
||||||
|
CommonName: "Test CA",
|
||||||
|
},
|
||||||
|
NotBefore: time.Now(),
|
||||||
|
NotAfter: time.Now().Add(time.Minute * 10),
|
||||||
|
|
||||||
|
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature,
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
IsCA: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cert, err := x509.ParseCertificate(der)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &testCA{
|
||||||
|
key,
|
||||||
|
cert,
|
||||||
|
pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMockACME(ca *testCA, srv *httptest.Server) http.Handler {
|
||||||
var certBuffer bytes.Buffer
|
var certBuffer bytes.Buffer
|
||||||
|
|
||||||
|
var certs []*x509.Certificate
|
||||||
|
findCert := func(serial *big.Int) *x509.Certificate {
|
||||||
|
for _, c := range certs {
|
||||||
|
if c.SerialNumber.Cmp(serial) == 0 {
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
r := chi.NewRouter()
|
r := chi.NewRouter()
|
||||||
r.Use(middleware.Logger)
|
r.Use(middleware.Logger)
|
||||||
r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) {
|
r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -78,6 +129,22 @@ func newMockACME(srv *httptest.Server) http.Handler {
|
||||||
"finalize": srv.URL + "/acme/finalize",
|
"finalize": srv.URL + "/acme/finalize",
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
r.Post("/ocsp/request", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
reqData, _ := io.ReadAll(r.Body)
|
||||||
|
ocspReq, _ := ocsp.ParseRequest(reqData)
|
||||||
|
ocspResp := ocsp.Response{
|
||||||
|
Status: ocsp.Good,
|
||||||
|
SerialNumber: ocspReq.SerialNumber,
|
||||||
|
ThisUpdate: time.Now(),
|
||||||
|
NextUpdate: time.Now().Add(time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
cert := findCert(ocspReq.SerialNumber)
|
||||||
|
data, _ := ocsp.CreateResponse(ca.cert, cert, ocspResp, ca.key)
|
||||||
|
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(data)
|
||||||
|
})
|
||||||
r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) {
|
r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) {
|
||||||
var payload struct {
|
var payload struct {
|
||||||
CSR string `json:"csr"`
|
CSR string `json:"csr"`
|
||||||
|
@ -85,7 +152,6 @@ func newMockACME(srv *httptest.Server) http.Handler {
|
||||||
readJWSPayload(r.Body, &payload)
|
readJWSPayload(r.Body, &payload)
|
||||||
bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR)
|
bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR)
|
||||||
csr, _ := x509.ParseCertificateRequest(bs)
|
csr, _ := x509.ParseCertificateRequest(bs)
|
||||||
caKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
|
||||||
tpl := &x509.Certificate{
|
tpl := &x509.Certificate{
|
||||||
SerialNumber: big.NewInt(time.Now().Unix()),
|
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||||
DNSNames: csr.DNSNames,
|
DNSNames: csr.DNSNames,
|
||||||
|
@ -100,10 +166,15 @@ func newMockACME(srv *httptest.Server) http.Handler {
|
||||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||||
BasicConstraintsValid: true,
|
BasicConstraintsValid: true,
|
||||||
IsCA: false,
|
IsCA: false,
|
||||||
|
|
||||||
|
IssuingCertificateURL: []string{srv.URL + "/certs/ca"},
|
||||||
|
OCSPServer: []string{srv.URL + "/ocsp/request"},
|
||||||
}
|
}
|
||||||
der, _ := x509.CreateCertificate(rand.Reader, tpl, tpl, csr.PublicKey, caKey)
|
der, _ := x509.CreateCertificate(rand.Reader, tpl, ca.cert, csr.PublicKey, ca.key)
|
||||||
certBuffer.Reset()
|
certBuffer.Reset()
|
||||||
_ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der})
|
_ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
cert, _ := x509.ParseCertificate(der)
|
||||||
|
certs = append(certs, cert)
|
||||||
|
|
||||||
w.Header().Set("Replay-Nonce", "NONCE")
|
w.Header().Set("Replay-Nonce", "NONCE")
|
||||||
w.Header().Set("Content-Type", "application/json")
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
@ -120,6 +191,11 @@ func newMockACME(srv *httptest.Server) http.Handler {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
_, _ = w.Write(certBuffer.Bytes())
|
_, _ = w.Write(certBuffer.Bytes())
|
||||||
})
|
})
|
||||||
|
r.Get("/certs/ca", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/pkix-cert")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
_, _ = w.Write(ca.cert.Raw)
|
||||||
|
})
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -132,7 +208,11 @@ func TestConfig(t *testing.T) {
|
||||||
mockACME.ServeHTTP(w, r)
|
mockACME.ServeHTTP(w, r)
|
||||||
}))
|
}))
|
||||||
defer srv.Close()
|
defer srv.Close()
|
||||||
mockACME = newMockACME(srv)
|
|
||||||
|
ca, err := newTestCA()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockACME = newMockACME(ca, srv)
|
||||||
|
|
||||||
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||||
_ = os.MkdirAll(tmpdir, 0o755)
|
_ = os.MkdirAll(tmpdir, 0o755)
|
||||||
|
@ -158,7 +238,7 @@ func TestConfig(t *testing.T) {
|
||||||
AutocertOptions: config.AutocertOptions{
|
AutocertOptions: config.AutocertOptions{
|
||||||
Enable: true,
|
Enable: true,
|
||||||
UseStaging: true,
|
UseStaging: true,
|
||||||
MustStaple: false,
|
MustStaple: true,
|
||||||
Folder: tmpdir,
|
Folder: tmpdir,
|
||||||
},
|
},
|
||||||
HTTPRedirectAddr: addr,
|
HTTPRedirectAddr: addr,
|
||||||
|
@ -172,21 +252,44 @@ func TestConfig(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var certs []tls.Certificate
|
domainRenewed := make(chan bool)
|
||||||
for i := 0; i < 10; i++ {
|
ocspUpdated := make(chan bool)
|
||||||
cfg := mgr.GetConfig()
|
|
||||||
assert.LessOrEqual(t, len(cfg.AutoCertificates), 1)
|
|
||||||
if len(cfg.AutoCertificates) == 1 && certs == nil {
|
|
||||||
certs = cfg.AutoCertificates
|
|
||||||
}
|
|
||||||
|
|
||||||
if !cmp.Equal(certs, cfg.AutoCertificates) {
|
var initialOCSPStaple []byte
|
||||||
|
var certValidTime *time.Time
|
||||||
|
mgr.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
||||||
|
log.Info(ctx).Msg("OnConfigChange")
|
||||||
|
cert := cfg.AutoCertificates[0]
|
||||||
|
if initialOCSPStaple == nil {
|
||||||
|
initialOCSPStaple = cert.OCSPStaple
|
||||||
|
} else {
|
||||||
|
if bytes.Compare(initialOCSPStaple, cert.OCSPStaple) != 0 {
|
||||||
|
log.Info(ctx).Msg("OCSP updated")
|
||||||
|
ocspUpdated <- true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if certValidTime == nil {
|
||||||
|
certValidTime = &cert.Leaf.NotAfter
|
||||||
|
} else {
|
||||||
|
if !certValidTime.Equal(cert.Leaf.NotAfter) {
|
||||||
|
log.Info(ctx).Msg("domain renewed")
|
||||||
|
domainRenewed <- true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
domainRenewedOK := false
|
||||||
|
ocspUpdatedOK := false
|
||||||
|
|
||||||
|
for !domainRenewedOK || !ocspUpdatedOK {
|
||||||
|
select {
|
||||||
|
case <-time.After(time.Second * 10):
|
||||||
|
t.Error("timeout waiting for certs renewal")
|
||||||
return
|
return
|
||||||
|
case domainRenewedOK = <-domainRenewed:
|
||||||
|
case ocspUpdatedOK = <-ocspUpdated:
|
||||||
}
|
}
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
|
||||||
}
|
}
|
||||||
t.Fatalf("expected renewed certs, but certs never changed")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRedirect(t *testing.T) {
|
func TestRedirect(t *testing.T) {
|
||||||
|
|
33
internal/autocert/ocsp.go
Normal file
33
internal/autocert/ocsp.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package autocert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
|
||||||
|
lru "github.com/hashicorp/golang-lru"
|
||||||
|
)
|
||||||
|
|
||||||
|
type ocspCache struct {
|
||||||
|
*lru.Cache
|
||||||
|
}
|
||||||
|
|
||||||
|
func newOCSPCache(size int) (*ocspCache, error) {
|
||||||
|
c, err := lru.New(size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ocspCache{c}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// updated checks if OCSP response for this certificate was updated
|
||||||
|
func (c ocspCache) updated(key string, ocspResp []byte) bool {
|
||||||
|
current, there := c.Get(key)
|
||||||
|
if !there {
|
||||||
|
_ = c.Add(key, ocspResp)
|
||||||
|
return false // to avoid triggering reload first time we see this response
|
||||||
|
}
|
||||||
|
if bytes.Equal(current.([]byte), ocspResp) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
_ = c.Add(key, ocspResp)
|
||||||
|
return true
|
||||||
|
}
|
31
internal/autocert/ocsp_test.go
Normal file
31
internal/autocert/ocsp_test.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
package autocert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestOcspCache(t *testing.T) {
|
||||||
|
c, err := newOCSPCache(10)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cases := []struct {
|
||||||
|
data []byte
|
||||||
|
isUpdated bool
|
||||||
|
}{
|
||||||
|
{nil, false},
|
||||||
|
{nil, false},
|
||||||
|
{[]byte("a"), true},
|
||||||
|
{[]byte("a"), false},
|
||||||
|
{[]byte("b"), true},
|
||||||
|
{[]byte("b"), false},
|
||||||
|
{nil, true},
|
||||||
|
{nil, false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, tc := range cases {
|
||||||
|
assert.Equal(t, tc.isUpdated, c.updated("key", tc.data), "#%d: %v", i, tc)
|
||||||
|
}
|
||||||
|
}
|
Loading…
Add table
Add a link
Reference in a new issue