package autocert import ( "bytes" "context" "crypto/ecdsa" "crypto/elliptic" "crypto/rand" "crypto/x509" "crypto/x509/pkix" "encoding/base64" "encoding/json" "encoding/pem" "fmt" "io" "math/big" "net" "net/http" "net/http/httptest" "os" "testing" "time" "github.com/caddyserver/certmagic" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/mholt/acmez/acme" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "golang.org/x/crypto/ocsp" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/log" ) type M = map[string]interface{} type testCA struct { key *ecdsa.PrivateKey cert *x509.Certificate certPEM []byte } func newTestCA() (*testCA, error) { key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) if err != nil { return nil, err } tpl := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().Unix()), Subject: pkix.Name{ CommonName: "Test CA", }, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Minute * 10), KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageDigitalSignature, BasicConstraintsValid: true, IsCA: true, } der, err := x509.CreateCertificate(rand.Reader, tpl, tpl, &key.PublicKey, key) if err != nil { return nil, err } cert, err := x509.ParseCertificate(der) if err != nil { return nil, err } return &testCA{ key, cert, pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}), }, nil } func newMockACME(ca *testCA, srv *httptest.Server) http.Handler { var certBuffer bytes.Buffer var certs []*x509.Certificate findCert := func(serial *big.Int) *x509.Certificate { for _, c := range certs { if c.SerialNumber.Cmp(serial) == 0 { return c } } return nil } r := chi.NewRouter() r.Use(middleware.Logger) r.Get("/acme/directory", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") _ = json.NewEncoder(w).Encode(M{ "keyChange": srv.URL + "/acme/key-change", "newAccount": srv.URL + "/acme/new-acct", "newNonce": srv.URL + "/acme/new-nonce", "newOrder": srv.URL + "/acme/new-order", "revokeCert": srv.URL + "/acme/revoke-cert", }) }) r.Head("/acme/new-nonce", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Replay-Nonce", "NONCE") w.WriteHeader(http.StatusOK) }) r.Post("/acme/new-acct", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(M{ "status": "valid", }) }) r.Post("/acme/new-order", func(w http.ResponseWriter, r *http.Request) { var payload struct { Identifiers []struct { Type string `json:"type"` Value string `json:"value"` } `json:"identifiers"` } readJWSPayload(r.Body, &payload) w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(M{ "status": "pending", "finalize": srv.URL + "/acme/finalize", }) }) r.Post("/ocsp/request", func(w http.ResponseWriter, r *http.Request) { reqData, _ := io.ReadAll(r.Body) ocspReq, _ := ocsp.ParseRequest(reqData) ocspResp := ocsp.Response{ Status: ocsp.Good, SerialNumber: ocspReq.SerialNumber, ThisUpdate: time.Now(), NextUpdate: time.Now().Add(time.Second), } cert := findCert(ocspReq.SerialNumber) data, _ := ocsp.CreateResponse(ca.cert, cert, ocspResp, ca.key) w.WriteHeader(http.StatusOK) _, _ = w.Write(data) }) r.Post("/acme/finalize", func(w http.ResponseWriter, r *http.Request) { var payload struct { CSR string `json:"csr"` } readJWSPayload(r.Body, &payload) bs, _ := base64.RawURLEncoding.DecodeString(payload.CSR) csr, _ := x509.ParseCertificateRequest(bs) tpl := &x509.Certificate{ SerialNumber: big.NewInt(time.Now().Unix()), DNSNames: csr.DNSNames, IPAddresses: csr.IPAddresses, Subject: pkix.Name{ CommonName: csr.DNSNames[0], }, NotBefore: time.Now(), NotAfter: time.Now().Add(time.Second * 2), KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, BasicConstraintsValid: true, IsCA: false, IssuingCertificateURL: []string{srv.URL + "/certs/ca"}, OCSPServer: []string{srv.URL + "/ocsp/request"}, } der, _ := x509.CreateCertificate(rand.Reader, tpl, ca.cert, csr.PublicKey, ca.key) certBuffer.Reset() _ = pem.Encode(&certBuffer, &pem.Block{Type: "CERTIFICATE", Bytes: der}) cert, _ := x509.ParseCertificate(der) certs = append(certs, cert) w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusCreated) _ = json.NewEncoder(w).Encode(M{ "status": "valid", "finalize": srv.URL + "/acme/finalize", "certificate": srv.URL + "/acme/certificate", }) }) r.Post("/acme/certificate", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Replay-Nonce", "NONCE") w.Header().Set("Content-Type", "application/pem-certificate-chain") w.WriteHeader(http.StatusOK) _, _ = w.Write(certBuffer.Bytes()) }) r.Get("/certs/ca", func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/pkix-cert") w.WriteHeader(http.StatusOK) _, _ = w.Write(ca.cert.Raw) }) return r } func TestConfig(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() var mockACME http.Handler srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mockACME.ServeHTTP(w, r) })) defer srv.Close() ca, err := newTestCA() require.NoError(t, err) mockACME = newMockACME(ca, srv) tmpdir := t.TempDir() li, err := net.Listen("tcp", "127.0.0.1:0") if !assert.NoError(t, err) { return } addr := li.Addr().String() _ = li.Close() to, err := config.ParseWeightedUrls("http://to.example.com") require.NoError(t, err) p1 := config.Policy{ From: "http://from.example.com", To: to, } _ = p1.Validate() mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{ Options: &config.Options{ AutocertOptions: config.AutocertOptions{ Enable: true, UseStaging: true, Email: "pomerium-test@example.com", MustStaple: true, Folder: tmpdir, }, HTTPRedirectAddr: addr, Policies: []config.Policy{p1}, }, }), certmagic.ACMEIssuer{ CA: srv.URL + "/acme/directory", TestCA: srv.URL + "/acme/directory", }, time.Millisecond*100) if !assert.NoError(t, err) { return } domainRenewed := make(chan bool) ocspUpdated := make(chan bool) var initialOCSPStaple []byte var certValidTime *time.Time mgr.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) { if len(cfg.AutoCertificates) == 0 { return } cert := cfg.AutoCertificates[0] if initialOCSPStaple == nil { initialOCSPStaple = cert.OCSPStaple } else { if !bytes.Equal(initialOCSPStaple, cert.OCSPStaple) { log.Info(ctx).Msg("OCSP updated") ocspUpdated <- true } } if certValidTime == nil { certValidTime = &cert.Leaf.NotAfter } else { if !certValidTime.Equal(cert.Leaf.NotAfter) { log.Info(ctx).Msg("domain renewed") domainRenewed <- true } } }) domainRenewedOK := false ocspUpdatedOK := false for !domainRenewedOK || !ocspUpdatedOK { select { case <-time.After(time.Second * 10): t.Error("timeout waiting for certs renewal") return case domainRenewedOK = <-domainRenewed: case ocspUpdatedOK = <-ocspUpdated: } } } func TestRedirect(t *testing.T) { li, err := net.Listen("tcp", "127.0.0.1:0") if !assert.NoError(t, err) { return } addr := li.Addr().String() _ = li.Close() src := config.NewStaticSource(&config.Config{ Options: &config.Options{ HTTPRedirectAddr: addr, SetResponseHeaders: map[string]string{ "X-Frame-Options": "SAMEORIGIN", "X-XSS-Protection": "1; mode=block", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", }, }, }) _, err = New(src) if !assert.NoError(t, err) { return } err = waitFor(addr) if !assert.NoError(t, err) { return } client := &http.Client{ CheckRedirect: func(req *http.Request, via []*http.Request) error { return http.ErrUseLastResponse }, } res, err := client.Get(fmt.Sprintf("http://%s", addr)) if !assert.NoError(t, err) { return } defer res.Body.Close() assert.Equal(t, http.StatusMovedPermanently, res.StatusCode, "should redirect to https") for k, v := range src.GetConfig().Options.SetResponseHeaders { assert.NotEqual(t, v, res.Header.Get(k), "should ignore options header") } } func waitFor(addr string) error { var err error deadline := time.Now().Add(time.Second * 30) for time.Now().Before(deadline) { var conn net.Conn conn, err = net.Dial("tcp", addr) if err == nil { conn.Close() return nil } time.Sleep(time.Second) } return err } func readJWSPayload(r io.Reader, dst interface{}) { var req struct { Protected string `json:"protected"` Payload string `json:"payload"` Signature string `json:"signature"` } _ = json.NewDecoder(r).Decode(&req) bs, _ := base64.RawURLEncoding.DecodeString(req.Payload) _ = json.Unmarshal(bs, dst) } func newACMEIssuer() *certmagic.ACMEIssuer { return &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, } } func Test_configureCertificateAuthority(t *testing.T) { type args struct { acmeMgr *certmagic.ACMEIssuer opts config.AutocertOptions } type test struct { args args expected *certmagic.ACMEIssuer wantErr bool } tests := map[string]func(t *testing.T) test{ "ok/default": func(t *testing.T) test { return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{}, }, expected: &certmagic.ACMEIssuer{ Agreed: true, CA: certmagic.DefaultACME.CA, Email: " ", TestCA: certmagic.DefaultACME.TestCA, }, wantErr: false, } }, "ok/staging": func(t *testing.T) test { return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ UseStaging: true, }, }, expected: &certmagic.ACMEIssuer{ Agreed: true, CA: certmagic.DefaultACME.TestCA, Email: " ", TestCA: certmagic.DefaultACME.TestCA, }, wantErr: false, } }, "ok/custom-ca-staging": func(t *testing.T) test { return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ CA: "test-ca.example.com/directory", Email: "test@example.com", UseStaging: true, }, }, expected: &certmagic.ACMEIssuer{ Agreed: true, CA: "test-ca.example.com/directory", Email: "test@example.com", TestCA: certmagic.DefaultACME.TestCA, }, wantErr: false, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { if err := configureCertificateAuthority(tc.args.acmeMgr, tc.args.opts); (err != nil) != tc.wantErr { t.Errorf("configureCertificateAuthority() error = %v, wantErr %v", err, tc.wantErr) } if !cmp.Equal(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{})) { t.Errorf("configureCertificateAuthority() diff = %s", cmp.Diff(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{}))) } }) } } func Test_configureExternalAccountBinding(t *testing.T) { type args struct { acmeMgr *certmagic.ACMEIssuer opts config.AutocertOptions } type test struct { args args expected *certmagic.ACMEIssuer wantErr bool } tests := map[string]func(t *testing.T) test{ "ok": func(t *testing.T) test { return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ EABKeyID: "keyID", EABMACKey: "29D7t6-mOuEV5vvBRX0UYF5T7x6fomidhM1kMJco-yw", }, }, expected: &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, ExternalAccount: &acme.EAB{ KeyID: "keyID", MACKey: "29D7t6-mOuEV5vvBRX0UYF5T7x6fomidhM1kMJco-yw", }, }, wantErr: false, } }, "fail/error-decoding-mac-key": func(t *testing.T) test { return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ EABKeyID: "keyID", EABMACKey: ">invalid-base-64-data<", }, }, wantErr: true, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { err := configureExternalAccountBinding(tc.args.acmeMgr, tc.args.opts) if (err != nil) != tc.wantErr { t.Errorf("configureExternalAccountBinding() error = %v, wantErr %v", err, tc.wantErr) } if err == nil && !cmp.Equal(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{})) { t.Errorf("configureCertificateAuthority() diff = %s", cmp.Diff(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{}))) } }) } } func Test_configureTrustedRoots(t *testing.T) { ca, err := newTestCA() require.NoError(t, err) type args struct { acmeMgr *certmagic.ACMEIssuer opts config.AutocertOptions } type test struct { args args expected *certmagic.ACMEIssuer wantErr bool cleanup func() } tests := map[string]func(t *testing.T) test{ "ok/pem": func(t *testing.T) test { roots, err := x509.SystemCertPool() require.NoError(t, err) ok := roots.AppendCertsFromPEM(ca.certPEM) require.Equal(t, true, ok) return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ TrustedCA: base64.StdEncoding.EncodeToString(ca.certPEM), }, }, expected: &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, TrustedRoots: roots, }, wantErr: false, } }, "ok/file": func(t *testing.T) test { roots, err := x509.SystemCertPool() require.NoError(t, err) ok := roots.AppendCertsFromPEM(ca.certPEM) require.Equal(t, true, ok) f, err := os.CreateTemp("", "pomerium-test-ca") require.NoError(t, err) n, err := f.Write(ca.certPEM) require.NoError(t, err) require.Equal(t, len(ca.certPEM), n) return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ TrustedCAFile: f.Name(), }, }, expected: &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, TrustedRoots: roots, }, wantErr: false, cleanup: func() { os.Remove(f.Name()) }, } }, "fail/pem": func(t *testing.T) test { roots, err := x509.SystemCertPool() require.NoError(t, err) return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ TrustedCA: ">invalid-base-64-ca-pem<", }, }, expected: &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, TrustedRoots: roots, }, wantErr: true, } }, "fail/file": func(t *testing.T) test { roots, err := x509.SystemCertPool() require.NoError(t, err) return test{ args: args{ acmeMgr: newACMEIssuer(), opts: config.AutocertOptions{ TrustedCAFile: "some-non-existing-file", }, }, expected: &certmagic.ACMEIssuer{ CA: certmagic.DefaultACME.CA, TestCA: certmagic.DefaultACME.TestCA, TrustedRoots: roots, }, wantErr: true, } }, } for name, run := range tests { tc := run(t) t.Run(name, func(t *testing.T) { err := configureTrustedRoots(tc.args.acmeMgr, tc.args.opts) if (err != nil) != tc.wantErr { t.Errorf("configureTrustedRoots() error = %v, wantErr %v", err, tc.wantErr) } if err == nil && !cmp.Equal(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{}, x509.CertPool{})) { t.Errorf("configureCertificateAuthority() diff = %s", cmp.Diff(tc.expected, tc.args.acmeMgr, cmpopts.IgnoreUnexported(certmagic.ACMEIssuer{}, x509.CertPool{}))) } if err == nil && !cmp.Equal(tc.expected.TrustedRoots.Subjects(), tc.args.acmeMgr.TrustedRoots.Subjects()) { t.Errorf("configureCertificateAuthority() subjects diff = %s", cmp.Diff(tc.expected.TrustedRoots.Subjects(), tc.args.acmeMgr.TrustedRoots.Subjects())) } if tc.cleanup != nil { tc.cleanup() } }) } }