diff --git a/config/options.go b/config/options.go index 02dabcc46..22bfb5b54 100644 --- a/config/options.go +++ b/config/options.go @@ -30,6 +30,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/config" + "github.com/pomerium/pomerium/pkg/hpke" ) // DisableHeaderKey is the key used to check whether to disable setting header @@ -997,6 +998,16 @@ func (o *Options) GetSharedKey() ([]byte, error) { return base64.StdEncoding.DecodeString(sharedKey) } +// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key. +func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) { + sharedKey, err := o.GetSharedKey() + if err != nil { + return nil, err + } + + return hpke.DerivePrivateKey(sharedKey), nil +} + // GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount. func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string { return o.GoogleCloudServerlessAuthenticationServiceAccount diff --git a/go.mod b/go.mod index 110ef6ae3..5367c9489 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/gorilla/handlers v1.5.1 github.com/gorilla/mux v1.8.0 github.com/gorilla/websocket v1.5.0 + github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 github.com/hashicorp/go-multierror v1.1.1 github.com/hashicorp/golang-lru v0.5.4 github.com/jackc/pgconn v1.13.0 diff --git a/go.sum b/go.sum index 1f6569bad..0a7b85bd3 100644 --- a/go.sum +++ b/go.sum @@ -507,6 +507,8 @@ github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3 github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A= github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod h1:D+FIZ+7OahH3ePw/izIEeH5I06eKs1IKI4Xr64/Am3M= github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA= +github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index d7ce27c5a..3e4c0d74e 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -60,10 +60,16 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er return fmt.Errorf("invalid signing key: %w", err) } + hpkePrivateKey, err := cfg.Options.GetHPKEPrivateKey() + if err != nil { + return fmt.Errorf("invalid hpke private key: %w", err) + } + hpkePublicKey := hpkePrivateKey.PublicKey() + root.HandleFunc("/healthz", handlers.HealthCheck) root.HandleFunc("/ping", handlers.HealthCheck) root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL)) root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL)) - root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey)) + root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey)) return nil } diff --git a/internal/controlplane/server_test.go b/internal/controlplane/server_test.go index 981002af3..c38a17f2a 100644 --- a/internal/controlplane/server_test.go +++ b/internal/controlplane/server_test.go @@ -34,6 +34,7 @@ func TestServerHTTP(t *testing.T) { } cfg.Options.AuthenticateURLString = "https://authenticate.localhost.pomerium.io" cfg.Options.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" + cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs=" src := config.NewStaticSource(cfg) srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New()) @@ -66,15 +67,23 @@ func TestServerHTTP(t *testing.T) { require.NoError(t, err) expect := map[string]any{ - "keys": []any{map[string]any{ - "alg": "ES256", - "crv": "P-256", - "kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c", - "kty": "EC", - "use": "sig", - "x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo", - "y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ", - }}, + "keys": []any{ + map[string]any{ + "alg": "ES256", + "crv": "P-256", + "kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c", + "kty": "EC", + "use": "sig", + "x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo", + "y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ", + }, + map[string]any{ + "kty": "OKP", + "kid": "pomerium/hpke", + "crv": "X25519", + "x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg", + }, + }, } assert.Equal(t, expect, actual) }) diff --git a/internal/handlers/jwks.go b/internal/handlers/jwks.go index 7abee06b0..149a73b4f 100644 --- a/internal/handlers/jwks.go +++ b/internal/handlers/jwks.go @@ -1,11 +1,16 @@ package handlers import ( + "bytes" "encoding/base64" + "encoding/json" "errors" + "fmt" + "hash/fnv" "net/http" + "strconv" + "time" - "github.com/go-jose/go-jose/v3" "github.com/rs/cors" "github.com/pomerium/pomerium/internal/httputil" @@ -13,9 +18,14 @@ import ( ) // JWKSHandler returns the /.well-known/pomerium/jwks.json handler. -func JWKSHandler(rawSigningKey string) http.Handler { +func JWKSHandler( + rawSigningKey string, + additionalKeys ...any, +) http.Handler { return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - var jwks jose.JSONWebKeySet + var jwks struct { + Keys []any `json:"keys"` + } if rawSigningKey != "" { decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey) if err != nil { @@ -27,7 +37,22 @@ func JWKSHandler(rawSigningKey string) http.Handler { } jwks.Keys = append(jwks.Keys, *jwk) } - httputil.RenderJSON(w, http.StatusOK, jwks) + jwks.Keys = append(jwks.Keys, additionalKeys...) + + bs, err := json.Marshal(jwks) + if err != nil { + return err + } + + hasher := fnv.New64() + _, _ = hasher.Write(bs) + h := hasher.Sum64() + + w.Header().Set("Cache-Control", "max-age=60") + w.Header().Set("Content-Type", "application/json") + w.Header().Set("Content-Length", strconv.Itoa(len(bs))) + w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h)) + http.ServeContent(w, r, "jwks.json", time.Time{}, bytes.NewReader(bs)) return nil })) } diff --git a/internal/handlers/jwks_test.go b/internal/handlers/jwks_test.go index 3c3442b89..335cd77eb 100644 --- a/internal/handlers/jwks_test.go +++ b/internal/handlers/jwks_test.go @@ -1,22 +1,71 @@ -package handlers +package handlers_test import ( + "crypto/ecdsa" + "encoding/base64" + "encoding/json" "net/http" "net/http/httptest" "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/handlers" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/hpke" ) func TestJWKSHandler(t *testing.T) { t.Parallel() + signingKey, err := cryptutil.NewSigningKey() + require.NoError(t, err) + + rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey) + require.NoError(t, err) + + jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey) + require.NoError(t, err) + + hpkePrivateKey, err := hpke.GeneratePrivateKey() + require.NoError(t, err) + t.Run("cors", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodOptions, "/", nil) r.Header.Set("Origin", "https://www.example.com") r.Header.Set("Access-Control-Request-Method", "GET") - JWKSHandler("").ServeHTTP(w, r) + handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r) assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) }) + t.Run("keys", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/", nil) + handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + + var expect any = map[string]any{ + "keys": []any{ + map[string]any{ + "kty": "EC", + "kid": jwkSigningKey.KeyID, + "crv": "P-256", + "alg": "ES256", + "use": "sig", + "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()), + }, + map[string]any{ + "kty": "OKP", + "kid": "pomerium/hpke", + "crv": "X25519", + "x": hpkePrivateKey.PublicKey().String(), + }, + }, + } + var actual any + err := json.Unmarshal(w.Body.Bytes(), &actual) + assert.NoError(t, err) + assert.Equal(t, expect, actual) + }) } diff --git a/pkg/hpke/hpke.go b/pkg/hpke/hpke.go index 75d69029b..5cfb031ac 100644 --- a/pkg/hpke/hpke.go +++ b/pkg/hpke/hpke.go @@ -4,6 +4,7 @@ package hpke import ( "crypto/rand" "encoding/base64" + "encoding/json" "fmt" "github.com/cloudflare/circl/hpke" @@ -25,44 +26,63 @@ type PrivateKey struct { } // DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key. -func DerivePrivateKey(seed []byte) PrivateKey { +func DerivePrivateKey(seed []byte) *PrivateKey { pk := kdfID.Extract(seed, nil) data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize())) _, key := kemID.Scheme().DeriveKeyPair(data) - return PrivateKey{key: key} + return &PrivateKey{key: key} } // GeneratePrivateKey generates an HPKE private key. -func GeneratePrivateKey() (PrivateKey, error) { +func GeneratePrivateKey() (*PrivateKey, error) { _, privateKey, err := kemID.Scheme().GenerateKeyPair() if err != nil { - return PrivateKey{}, err + return nil, err } - return PrivateKey{key: privateKey}, nil + return &PrivateKey{key: privateKey}, nil } // PrivateKeyFromString takes a string and returns a PrivateKey. -func PrivateKeyFromString(raw string) (PrivateKey, error) { +func PrivateKeyFromString(raw string) (*PrivateKey, error) { bs, err := decode(raw) if err != nil { - return PrivateKey{}, err + return nil, err } key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs) if err != nil { - return PrivateKey{}, err + return nil, err } - return PrivateKey{key: key}, nil + return &PrivateKey{key: key}, nil } // PublicKey returns the public key for the private key. -func (key PrivateKey) PublicKey() PublicKey { - return PublicKey{key: key.key.Public()} +func (key *PrivateKey) PublicKey() *PublicKey { + if key == nil || key.key == nil { + return nil + } + + return &PublicKey{key: key.key.Public()} +} + +// MarshalJSON returns the JSON Web Key representation of the private key. +func (key *PrivateKey) MarshalJSON() ([]byte, error) { + return json.Marshal(JWK{ + Type: jwkType, + ID: jwkID, + Curve: jwkCurve, + X: key.PublicKey().String(), + D: key.String(), + }) } // String converts the private key into a string. -func (key PrivateKey) String() string { +func (key *PrivateKey) String() string { + if key == nil || key.key == nil { + return "" + } + bs, err := key.key.MarshalBinary() if err != nil { // this should not happen @@ -78,22 +98,52 @@ type PublicKey struct { } // PublicKeyFromString converts a string into a public key. -func PublicKeyFromString(raw string) (PublicKey, error) { +func PublicKeyFromString(raw string) (*PublicKey, error) { bs, err := decode(raw) if err != nil { - return PublicKey{}, err + return nil, err } key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs) if err != nil { - return PublicKey{}, err + return nil, err } - return PublicKey{key: key}, nil + return &PublicKey{key: key}, nil +} + +// Equals returns true if the two keys are equivalent. +func (key *PublicKey) Equals(other *PublicKey) bool { + if key == nil && other == nil { + return true + } else if key == nil || other == nil { + return false + } + + if key.key == nil && other.key == nil { + return true + } else if key.key == nil || other.key == nil { + return false + } + return key.key.Equal(other.key) +} + +// MarshalJSON returns the JSON Web Key representation of the public key. +func (key *PublicKey) MarshalJSON() ([]byte, error) { + return json.Marshal(JWK{ + Type: jwkType, + ID: jwkID, + Curve: jwkCurve, + X: key.String(), + }) } // String converts a public key into a string. -func (key PublicKey) String() string { +func (key *PublicKey) String() string { + if key == nil || key.key == nil { + return "" + } + bs, err := key.key.MarshalBinary() if err != nil { // this should not happen @@ -105,10 +155,17 @@ func (key PublicKey) String() string { // Seal seales a message using HPKE. func Seal( - senderPrivateKey PrivateKey, - receiverPublicKey PublicKey, + senderPrivateKey *PrivateKey, + receiverPublicKey *PublicKey, message []byte, ) (sealed []byte, err error) { + if senderPrivateKey == nil { + return nil, fmt.Errorf("hpke: sender private key cannot be nil") + } + if receiverPublicKey == nil { + return nil, fmt.Errorf("hpke: receiver public key cannot be nil") + } + sender, err := suite.NewSender(receiverPublicKey.key, nil) if err != nil { return nil, fmt.Errorf("hpke: error creating sender: %w", err) @@ -129,10 +186,17 @@ func Seal( // Open opens a message using HPKE. func Open( - receiverPrivateKey PrivateKey, - senderPublicKey PublicKey, + receiverPrivateKey *PrivateKey, + senderPublicKey *PublicKey, sealed []byte, ) (message []byte, err error) { + if receiverPrivateKey == nil { + return nil, fmt.Errorf("hpke: receiver private key cannot be nil") + } + if senderPublicKey == nil { + return nil, fmt.Errorf("hpke: sender public key cannot be nil") + } + encSize := kemID.Scheme().SharedKeySize() if len(sealed) < encSize { return nil, fmt.Errorf("hpke: invalid sealed message") diff --git a/pkg/hpke/jwks.go b/pkg/hpke/jwks.go new file mode 100644 index 000000000..3e71676ee --- /dev/null +++ b/pkg/hpke/jwks.go @@ -0,0 +1,89 @@ +package hpke + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + + "github.com/gregjones/httpcache" +) + +const ( + defaultMaxBodySize = 1024 * 1024 * 4 + + jwkType = "OKP" + jwkID = "pomerium/hpke" + jwkCurve = "X25519" +) + +// JWK is the JSON Web Key representation of an HPKE key. +// Defined in RFC8037. +type JWK struct { + Type string `json:"kty"` + ID string `json:"kid"` + Curve string `json:"crv"` + X string `json:"x"` + D string `json:"d,omitempty"` +} + +// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint. +func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) + if err != nil { + return nil, fmt.Errorf("hpke: error building jwks http request: %w", err) + } + + res, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) + } + defer res.Body.Close() + + if res.StatusCode/100 != 2 { + return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) + } + + bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize)) + if err != nil { + return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) + } + + var jwks struct { + Keys []JWK `json:"keys"` + } + err = json.Unmarshal(bs, &jwks) + if err != nil { + return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err) + } + + for _, key := range jwks.Keys { + if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve { + return PublicKeyFromString(key.X) + } + } + return nil, fmt.Errorf("hpke key not found in JWKS endpoint") +} + +// A KeyFetcher fetches public keys. +type KeyFetcher interface { + FetchPublicKey(ctx context.Context) (*PublicKey, error) +} + +type jwksKeyFetcher struct { + client *http.Client + endpoint string +} + +func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) { + return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint) +} + +// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache. +func NewKeyFetcher(endpoint string) KeyFetcher { + return &jwksKeyFetcher{ + client: httpcache.NewMemoryCacheTransport().Client(), + endpoint: endpoint, + } +} diff --git a/pkg/hpke/jwks_test.go b/pkg/hpke/jwks_test.go new file mode 100644 index 000000000..7a85f917d --- /dev/null +++ b/pkg/hpke/jwks_test.go @@ -0,0 +1,33 @@ +package hpke + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/handlers" +) + +func TestFetchPublicKeyFromJWKS(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + hpkePrivateKey, err := GeneratePrivateKey() + require.NoError(t, err) + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + })) + t.Cleanup(srv.Close) + + publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) + assert.NoError(t, err) + assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String()) +} diff --git a/pkg/hpke/url.go b/pkg/hpke/url.go index 29188b4eb..49cbbbc19 100644 --- a/pkg/hpke/url.go +++ b/pkg/hpke/url.go @@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool { // EncryptURLValues encrypts URL values using the Seal method. func EncryptURLValues( - senderPrivateKey PrivateKey, - receiverPublicKey PublicKey, + senderPrivateKey *PrivateKey, + receiverPublicKey *PublicKey, values url.Values, ) (encrypted url.Values, err error) { values = withoutHPKEParams(values) @@ -37,34 +37,34 @@ func EncryptURLValues( // DecryptURLValues decrypts URL values using the Open method. func DecryptURLValues( - receiverPrivateKey PrivateKey, + receiverPrivateKey *PrivateKey, encrypted url.Values, -) (senderPublicKey PublicKey, values url.Values, err error) { +) (senderPublicKey *PublicKey, values url.Values, err error) { if !encrypted.Has(ParamSenderPublicKey) { - return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters") + return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters") } if !encrypted.Has(ParamQuery) { - return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") + return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") } senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) } sealed, err := decode(encrypted.Get(ParamQuery)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) } message, err := Open(receiverPrivateKey, senderPublicKey, sealed) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) + return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) } decrypted, err := url.ParseQuery(string(message)) if err != nil { - return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) + return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) } values = withoutHPKEParams(encrypted)