mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 08:50:42 +02:00
autocert: support certificate renewal (#1516)
This commit is contained in:
parent
04c582121d
commit
ac19c5041f
2 changed files with 239 additions and 2 deletions
|
@ -22,6 +22,9 @@ import (
|
|||
var (
|
||||
errObtainCertFailed = errors.New("obtain cert failed")
|
||||
errRenewCertFailed = errors.New("renew cert failed")
|
||||
|
||||
checkInterval = time.Minute * 10
|
||||
acmeTemplate = certmagic.DefaultACME
|
||||
)
|
||||
|
||||
// Manager manages TLS certificates.
|
||||
|
@ -64,6 +67,18 @@ func New(src config.Source) (*Manager, error) {
|
|||
cfg = mgr.GetConfig()
|
||||
mgr.Trigger(cfg)
|
||||
})
|
||||
go func() {
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
err := mgr.renewConfigCerts()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("autocert: error updating config")
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
return mgr, nil
|
||||
}
|
||||
|
||||
|
@ -77,10 +92,10 @@ func (mgr *Manager) getCertMagicConfig(options *config.Options) (*certmagic.Conf
|
|||
return nil, fmt.Errorf("config: failed caching cert: %w", err)
|
||||
}
|
||||
}
|
||||
acmeMgr := certmagic.NewACMEManager(mgr.certmagic, certmagic.DefaultACME)
|
||||
acmeMgr := certmagic.NewACMEManager(mgr.certmagic, acmeTemplate)
|
||||
acmeMgr.Agreed = true
|
||||
if options.AutocertOptions.UseStaging {
|
||||
acmeMgr.CA = certmagic.LetsEncryptStagingCA
|
||||
acmeMgr.CA = acmeMgr.TestCA
|
||||
}
|
||||
acmeMgr.DisableTLSALPNChallenge = true
|
||||
mgr.certmagic.Issuer = acmeMgr
|
||||
|
@ -89,6 +104,38 @@ func (mgr *Manager) getCertMagicConfig(options *config.Options) (*certmagic.Conf
|
|||
return mgr.certmagic, nil
|
||||
}
|
||||
|
||||
func (mgr *Manager) renewConfigCerts() error {
|
||||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
cfg := mgr.config
|
||||
cm, err := mgr.getCertMagicConfig(cfg.Options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
needsRenewal := false
|
||||
for _, domain := range sourceHostnames(cfg) {
|
||||
cert, err := cm.CacheManagedCertificate(domain)
|
||||
if err == nil && cert.NeedsRenewal(cm) {
|
||||
needsRenewal = true
|
||||
}
|
||||
}
|
||||
if !needsRenewal {
|
||||
return nil
|
||||
}
|
||||
|
||||
cfg = mgr.src.GetConfig().Clone()
|
||||
mgr.updateServer(cfg)
|
||||
if err := mgr.updateAutocert(cfg); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mgr.config = cfg
|
||||
mgr.Trigger(cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mgr *Manager) update(cfg *config.Config) error {
|
||||
cfg = cfg.Clone()
|
||||
|
||||
|
@ -154,6 +201,7 @@ func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
|||
log.Error().Err(err).Msg("autocert: failed to obtain client certificate")
|
||||
continue
|
||||
}
|
||||
|
||||
log.Info().Strs("names", cert.Names).Msg("autocert: added certificate")
|
||||
cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate)
|
||||
}
|
||||
|
|
|
@ -1,17 +1,194 @@
|
|||
package autocert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
type M = map[string]interface{}
|
||||
|
||||
func newMockACME(srv *httptest.Server) http.Handler {
|
||||
var certBuffer bytes.Buffer
|
||||
|
||||
r := chi.NewRouter()
|
||||
r.Use(middleware.Logger)
|
||||
r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"keyChange": srv.URL + "/acme/key-change",
|
||||
"newAccount": srv.URL + "/acme/new-acct",
|
||||
"newNonce": srv.URL + "/acme/new-nonce",
|
||||
"newOrder": srv.URL + "/acme/new-order",
|
||||
"revokeCert": srv.URL + "/acme/revoke-cert",
|
||||
})
|
||||
})
|
||||
r.Head("/acme/new-nonce", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Replay-Nonce", "NONCE")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
r.Post("/acme/new-acct", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Replay-Nonce", "NONCE")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"status": "valid",
|
||||
})
|
||||
})
|
||||
r.Post("/acme/new-order", func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
Identifiers []struct {
|
||||
Type string `json:"type"`
|
||||
Value string `json:"value"`
|
||||
} `json:"identifiers"`
|
||||
}
|
||||
readJWSPayload(r.Body, &payload)
|
||||
w.Header().Set("Replay-Nonce", "NONCE")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"status": "pending",
|
||||
"finalize": srv.URL + "/acme/finalize",
|
||||
})
|
||||
})
|
||||
r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) {
|
||||
var payload struct {
|
||||
CSR string `json:"csr"`
|
||||
}
|
||||
readJWSPayload(r.Body, &payload)
|
||||
bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR)
|
||||
csr, _ := x509.ParseCertificateRequest(bs)
|
||||
caKey, _ := rsa.GenerateKey(rand.Reader, 2048)
|
||||
tpl := &x509.Certificate{
|
||||
SerialNumber: big.NewInt(time.Now().Unix()),
|
||||
DNSNames: csr.DNSNames,
|
||||
IPAddresses: csr.IPAddresses,
|
||||
Subject: pkix.Name{
|
||||
CommonName: csr.DNSNames[0],
|
||||
},
|
||||
NotBefore: time.Now(),
|
||||
NotAfter: time.Now().Add(time.Second * 2),
|
||||
|
||||
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth},
|
||||
BasicConstraintsValid: true,
|
||||
IsCA: false,
|
||||
}
|
||||
der, _ := x509.CreateCertificate(rand.Reader, tpl, tpl, csr.PublicKey, caKey)
|
||||
certBuffer.Reset()
|
||||
_ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||
|
||||
w.Header().Set("Replay-Nonce", "NONCE")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusCreated)
|
||||
_ = json.NewEncoder(w).Encode(M{
|
||||
"status": "valid",
|
||||
"finalize": srv.URL + "/acme/finalize",
|
||||
"certificate": srv.URL + "/acme/certificate",
|
||||
})
|
||||
})
|
||||
r.Post("/acme/certificate", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Replay-Nonce", "NONCE")
|
||||
w.Header().Set("Content-Type", "application/pem-certificate-chain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write(certBuffer.Bytes())
|
||||
})
|
||||
return r
|
||||
}
|
||||
|
||||
func TestConfig(t *testing.T) {
|
||||
var mockACME http.Handler
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
mockACME.ServeHTTP(w, r)
|
||||
}))
|
||||
defer srv.Close()
|
||||
mockACME = newMockACME(srv)
|
||||
|
||||
tmpdir := filepath.Join(os.TempDir(), uuid.New().String())
|
||||
_ = os.MkdirAll(tmpdir, 0755)
|
||||
defer os.RemoveAll(tmpdir)
|
||||
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
addr := li.Addr().String()
|
||||
_ = li.Close()
|
||||
|
||||
oAcmeTemplate := acmeTemplate
|
||||
defer func() { acmeTemplate = oAcmeTemplate }()
|
||||
acmeTemplate = certmagic.ACMEManager{
|
||||
CA: srv.URL + "/acme/directory",
|
||||
TestCA: srv.URL + "/acme/directory",
|
||||
}
|
||||
|
||||
oCheckInterval := checkInterval
|
||||
defer func() { checkInterval = oCheckInterval }()
|
||||
checkInterval = time.Second
|
||||
|
||||
p1 := config.Policy{
|
||||
From: "http://from.example.com", To: "http://to.example.com",
|
||||
}
|
||||
_ = p1.Validate()
|
||||
|
||||
mgr, err := New(config.NewStaticSource(&config.Config{
|
||||
Options: &config.Options{
|
||||
AutocertOptions: config.AutocertOptions{
|
||||
Enable: true,
|
||||
UseStaging: true,
|
||||
MustStaple: false,
|
||||
Folder: tmpdir,
|
||||
},
|
||||
HTTPRedirectAddr: addr,
|
||||
Policies: []config.Policy{p1},
|
||||
},
|
||||
}))
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
var certs []tls.Certificate
|
||||
for i := 0; i < 10; i++ {
|
||||
cfg := mgr.GetConfig()
|
||||
assert.LessOrEqual(t, len(cfg.Options.Certificates), 1)
|
||||
if len(cfg.Options.Certificates) == 1 && certs == nil {
|
||||
certs = cfg.Options.Certificates
|
||||
}
|
||||
|
||||
if !cmp.Equal(certs, cfg.Options.Certificates) {
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(time.Second)
|
||||
}
|
||||
t.Fatalf("expected renewed certs, but certs never changed")
|
||||
}
|
||||
|
||||
func TestRedirect(t *testing.T) {
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
|
@ -71,3 +248,15 @@ func waitFor(addr string) error {
|
|||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func readJWSPayload(r io.Reader, dst interface{}) {
|
||||
var req struct {
|
||||
Protected string `json:"protected"`
|
||||
Payload string `json:"payload"`
|
||||
Signature string `json:"signature"`
|
||||
}
|
||||
_ = json.NewDecoder(r).Decode(&req)
|
||||
|
||||
bs, _ := base64.RawURLEncoding.DecodeString(req.Payload)
|
||||
_ = json.Unmarshal(bs, dst)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue