From ee1fefb2186d97c8de8f1e8cbe09bad37209e4a1 Mon Sep 17 00:00:00 2001 From: "backport-actions-token[bot]" <87506591+backport-actions-token[bot]@users.noreply.github.com> Date: Wed, 8 Mar 2023 09:18:37 -0700 Subject: [PATCH] hpke: move published public keys to a new endpoint (#4048) hpke: move published public keys to a new endpoint (#4044) Co-authored-by: Caleb Doxsey --- authorize/check_response_test.go | 14 +++------ config/config.go | 3 +- internal/controlplane/http.go | 4 ++- internal/controlplane/server_test.go | 21 +++++++++---- internal/handlers/hpke_public_key.go | 33 ++++++++++++++++++++ internal/handlers/hpke_public_key_test.go | 34 +++++++++++++++++++++ internal/handlers/jwks.go | 6 +--- internal/handlers/jwks_test.go | 17 ++--------- internal/urlutil/known.go | 3 ++ pkg/hpke/hpke.go | 24 +++++++++++++++ pkg/hpke/{jwks.go => http.go} | 37 +++++++---------------- pkg/hpke/{jwks_test.go => http_test.go} | 4 +-- proxy/proxy_test.go | 7 ++--- 13 files changed, 136 insertions(+), 71 deletions(-) create mode 100644 internal/handlers/hpke_public_key.go create mode 100644 internal/handlers/hpke_public_key_test.go rename pkg/hpke/{jwks.go => http.go} (52%) rename pkg/hpke/{jwks_test.go => http_test.go} (82%) diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 79c3ec23b..29cc87f47 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -30,13 +30,10 @@ func TestAuthorize_handleResult(t *testing.T) { opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" - htpkePrivateKey, err := opt.GetHPKEPrivateKey() + hpkePrivateKey, err := opt.GetHPKEPrivateKey() require.NoError(t, err) - signingKey, err := opt.GetSigningKey() - require.NoError(t, err) - - authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) + authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL @@ -228,13 +225,10 @@ func TestRequireLogin(t *testing.T) { opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" - htpkePrivateKey, err := opt.GetHPKEPrivateKey() + hpkePrivateKey, err := opt.GetHPKEPrivateKey() require.NoError(t, err) - signingKey, err := opt.GetSigningKey() - require.NoError(t, err) - - authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) + authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL diff --git a/config/config.go b/config/config.go index 38d220253..ca7b77b15 100644 --- a/config/config.go +++ b/config/config.go @@ -14,6 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/metrics" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/derivecert" "github.com/pomerium/pomerium/pkg/hpke" @@ -248,7 +249,7 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) { return nil, err } jwksURL := authenticateURL.ResolveReference(&url.URL{ - Path: "/.well-known/pomerium/jwks.json", + Path: urlutil.HPKEPublicKeyPath, }).String() return hpke.NewKeyFetcher(jwksURL, transport), nil } diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index 85ce1b4a0..520f0cdb6 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/requestid" + "github.com/pomerium/pomerium/internal/urlutil" ) func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) { @@ -68,6 +69,7 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er root.HandleFunc("/ping", handlers.HealthCheck) root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL)) root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL)) - root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey, hpkePublicKey)) + root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey)) + root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(handlers.HPKEPublicKeyHandler(hpkePublicKey)) return nil } diff --git a/internal/controlplane/server_test.go b/internal/controlplane/server_test.go index c38a17f2a..4f953a608 100644 --- a/internal/controlplane/server_test.go +++ b/internal/controlplane/server_test.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "io" "net/http" "testing" @@ -77,14 +78,22 @@ func TestServerHTTP(t *testing.T) { "x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo", "y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ", }, - map[string]any{ - "kty": "OKP", - "kid": "pomerium/hpke", - "crv": "X25519", - "x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg", - }, }, } assert.Equal(t, expect, actual) }) + t.Run("hpke-public-key", func(t *testing.T) { + res, err := http.Get(fmt.Sprintf("http://localhost:%s/.well-known/pomerium/hpke-public-key", src.GetConfig().HTTPPort)) + require.NoError(t, err) + defer res.Body.Close() + + bs, err := io.ReadAll(res.Body) + require.NoError(t, err) + assert.Equal(t, []byte{ + 0x4f, 0x47, 0x1b, 0x36, 0xb2, 0x5b, 0x3b, 0xd8, + 0xa7, 0xf8, 0x58, 0x28, 0xc0, 0xa0, 0x0f, 0xf8, + 0x75, 0xfa, 0x0a, 0x2f, 0x2a, 0xe7, 0x48, 0x28, + 0xa4, 0xeb, 0x79, 0xda, 0xc7, 0x61, 0x78, 0x78, + }, bs) + }) } diff --git a/internal/handlers/hpke_public_key.go b/internal/handlers/hpke_public_key.go new file mode 100644 index 000000000..114519bfd --- /dev/null +++ b/internal/handlers/hpke_public_key.go @@ -0,0 +1,33 @@ +package handlers + +import ( + "bytes" + "fmt" + "hash/fnv" + "net/http" + "strconv" + "time" + + "github.com/rs/cors" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/pkg/hpke" +) + +// HPKEPublicKeyHandler returns a handler which returns the HPKE public key. +func HPKEPublicKeyHandler(publicKey *hpke.PublicKey) http.Handler { + return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + bs := publicKey.Bytes() + + hasher := fnv.New64() + _, _ = hasher.Write(bs) + h := hasher.Sum64() + + w.Header().Set("Cache-Control", "max-age=60") + w.Header().Set("Content-Type", "application/octet-stream") + w.Header().Set("Content-Length", strconv.Itoa(len(bs))) + w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h)) + http.ServeContent(w, r, "hpke-public-key", time.Time{}, bytes.NewReader(bs)) + return nil + })) +} diff --git a/internal/handlers/hpke_public_key_test.go b/internal/handlers/hpke_public_key_test.go new file mode 100644 index 000000000..fc753a948 --- /dev/null +++ b/internal/handlers/hpke_public_key_test.go @@ -0,0 +1,34 @@ +package handlers_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/pkg/hpke" +) + +func TestHPKEPublicKeyHandler(t *testing.T) { + t.Parallel() + + k1 := hpke.DerivePrivateKey([]byte("TEST")) + + t.Run("cors", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodOptions, "/", nil) + r.Header.Set("Origin", "https://www.example.com") + r.Header.Set("Access-Control-Request-Method", "GET") + handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r) + assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) + }) + t.Run("keys", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r) + + assert.Equal(t, k1.PublicKey().Bytes(), w.Body.Bytes()) + }) +} diff --git a/internal/handlers/jwks.go b/internal/handlers/jwks.go index 68118b13f..f540f5d77 100644 --- a/internal/handlers/jwks.go +++ b/internal/handlers/jwks.go @@ -17,10 +17,7 @@ import ( ) // JWKSHandler returns the /.well-known/pomerium/jwks.json handler. -func JWKSHandler( - signingKey []byte, - additionalKeys ...any, -) http.Handler { +func JWKSHandler(signingKey []byte) http.Handler { return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { var jwks struct { Keys []any `json:"keys"` @@ -34,7 +31,6 @@ func JWKSHandler( jwks.Keys = append(jwks.Keys, *k) } } - jwks.Keys = append(jwks.Keys, additionalKeys...) bs, err := json.Marshal(jwks) if err != nil { diff --git a/internal/handlers/jwks_test.go b/internal/handlers/jwks_test.go index d05b0f78b..086adb1bf 100644 --- a/internal/handlers/jwks_test.go +++ b/internal/handlers/jwks_test.go @@ -13,7 +13,6 @@ import ( "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/hpke" ) func TestJWKSHandler(t *testing.T) { @@ -34,24 +33,18 @@ func TestJWKSHandler(t *testing.T) { jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2) require.NoError(t, err) - hpkePrivateKey, err := hpke.GeneratePrivateKey() - require.NoError(t, err) - t.Run("cors", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodOptions, "/", nil) r.Header.Set("Origin", "https://www.example.com") r.Header.Set("Access-Control-Request-Method", "GET") - handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + handlers.JWKSHandler(nil).ServeHTTP(w, r) assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) }) t.Run("keys", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) - handlers.JWKSHandler( - append(rawSigningKey1, rawSigningKey2...), - hpkePrivateKey.PublicKey(), - ).ServeHTTP(w, r) + handlers.JWKSHandler(append(rawSigningKey1, rawSigningKey2...)).ServeHTTP(w, r) var expect any = map[string]any{ "keys": []any{ @@ -73,12 +66,6 @@ func TestJWKSHandler(t *testing.T) { "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()), "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()), }, - map[string]any{ - "kty": "OKP", - "kid": "pomerium/hpke", - "crv": "X25519", - "x": hpkePrivateKey.PublicKey().String(), - }, }, } var actual any diff --git a/internal/urlutil/known.go b/internal/urlutil/known.go index 6ef226cb9..c27c6c985 100644 --- a/internal/urlutil/known.go +++ b/internal/urlutil/known.go @@ -5,6 +5,9 @@ import ( "net/url" ) +// HPKEPublicKeyPath is the well-known path to the HPKE public key +const HPKEPublicKeyPath = "/.well-known/pomerium/hpke-public-key" + // DefaultDeviceType is the default device type when none is specified. const DefaultDeviceType = "any" diff --git a/pkg/hpke/hpke.go b/pkg/hpke/hpke.go index 5cfb031ac..64e2bbd66 100644 --- a/pkg/hpke/hpke.go +++ b/pkg/hpke/hpke.go @@ -97,6 +97,16 @@ type PublicKey struct { key kem.PublicKey } +// PublicKeyFromBytes converts raw bytes into a public key. +func PublicKeyFromBytes(raw []byte) (*PublicKey, error) { + key, err := kemID.Scheme().UnmarshalBinaryPublicKey(raw) + if err != nil { + return nil, err + } + + return &PublicKey{key: key}, nil +} + // PublicKeyFromString converts a string into a public key. func PublicKeyFromString(raw string) (*PublicKey, error) { bs, err := decode(raw) @@ -128,6 +138,20 @@ func (key *PublicKey) Equals(other *PublicKey) bool { return key.key.Equal(other.key) } +// Bytes returns the public key as raw bytes. +func (key *PublicKey) Bytes() []byte { + if key == nil || key.key == nil { + return nil + } + + bs, err := key.key.MarshalBinary() + if err != nil { + // this should not happen + panic(fmt.Sprintf("failed to marshal public HPKE key: %v", err)) + } + return bs +} + // MarshalJSON returns the JSON Web Key representation of the public key. func (key *PublicKey) MarshalJSON() ([]byte, error) { return json.Marshal(JWK{ diff --git a/pkg/hpke/jwks.go b/pkg/hpke/http.go similarity index 52% rename from pkg/hpke/jwks.go rename to pkg/hpke/http.go index 9fcaf9246..3d1b1528d 100644 --- a/pkg/hpke/jwks.go +++ b/pkg/hpke/http.go @@ -2,7 +2,6 @@ package hpke import ( "context" - "encoding/json" "fmt" "io" "net/http" @@ -28,42 +27,28 @@ type JWK struct { D string `json:"d,omitempty"` } -// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint. -func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) { +// FetchPublicKey fetches the HPKE public key from the hpke-public-key endpoint. +func FetchPublicKey(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) if err != nil { - return nil, fmt.Errorf("hpke: error building jwks http request: %w", err) + return nil, fmt.Errorf("hpke: error building hpke-public-key http request: %w", err) } res, err := client.Do(req) if err != nil { - return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) + return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint: %w", err) } defer res.Body.Close() if res.StatusCode/100 != 2 { - return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) + return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint, invalid status code: %d", res.StatusCode) } bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize)) if err != nil { - return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) + return nil, fmt.Errorf("hpke: error reading hpke-public-key endpoint: %w", err) } - - var jwks struct { - Keys []JWK `json:"keys"` - } - err = json.Unmarshal(bs, &jwks) - if err != nil { - return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err) - } - - for _, key := range jwks.Keys { - if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve { - return PublicKeyFromString(key.X) - } - } - return nil, fmt.Errorf("hpke key not found in JWKS endpoint") + return PublicKeyFromBytes(bs) } // A KeyFetcher fetches public keys. @@ -71,18 +56,18 @@ type KeyFetcher interface { FetchPublicKey(ctx context.Context) (*PublicKey, error) } -type jwksKeyFetcher struct { +type fetcher struct { client *http.Client endpoint string } -func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) { - return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint) +func (fetcher *fetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) { + return FetchPublicKey(ctx, fetcher.client, fetcher.endpoint) } // NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache. func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher { - return &jwksKeyFetcher{ + return &fetcher{ client: (&httpcache.Transport{ Transport: transport, Cache: httpcache.NewMemoryCache(), diff --git a/pkg/hpke/jwks_test.go b/pkg/hpke/http_test.go similarity index 82% rename from pkg/hpke/jwks_test.go rename to pkg/hpke/http_test.go index 0bcf25d27..e90f2732b 100644 --- a/pkg/hpke/jwks_test.go +++ b/pkg/hpke/http_test.go @@ -24,11 +24,11 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) { require.NoError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()).ServeHTTP(w, r) })) t.Cleanup(srv.Close) - publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) + publicKey, err := hpke.FetchPublicKey(ctx, http.DefaultClient, srv.URL) assert.NoError(t, err) assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String()) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 20057ae53..b677e0bb4 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -30,13 +30,10 @@ func testOptions(t *testing.T) *config.Options { opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" - htpkePrivateKey, err := opts.GetHPKEPrivateKey() + hpkePrivateKey, err := opts.GetHPKEPrivateKey() require.NoError(t, err) - signingKey, err := opts.GetSigningKey() - require.NoError(t, err) - - authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) + authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opts.AuthenticateURLString = authnSrv.URL