hpke: add HPKE key to JWKS endpoint (#3762)

* hpke: add HPKE key to JWKS endpoint

* fix test, add http caching headers

* fix error message

* use pointers
This commit is contained in:
Caleb Doxsey 2022-11-23 08:45:59 -07:00 committed by GitHub
parent 52c967b8a5
commit ba07afc245
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 336 additions and 47 deletions

View file

@ -30,6 +30,7 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/hpke"
)
// DisableHeaderKey is the key used to check whether to disable setting header
@ -997,6 +998,16 @@ func (o *Options) GetSharedKey() ([]byte, error) {
return base64.StdEncoding.DecodeString(sharedKey)
}
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) {
sharedKey, err := o.GetSharedKey()
if err != nil {
return nil, err
}
return hpke.DerivePrivateKey(sharedKey), nil
}
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
return o.GoogleCloudServerlessAuthenticationServiceAccount

1
go.mod
View file

@ -31,6 +31,7 @@ require (
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/golang-lru v0.5.4
github.com/jackc/pgconn v1.13.0

2
go.sum
View file

@ -507,6 +507,8 @@ github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3
github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A=
github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod h1:D+FIZ+7OahH3ePw/izIEeH5I06eKs1IKI4Xr64/Am3M=
github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY=
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA=
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=

View file

@ -60,10 +60,16 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
return fmt.Errorf("invalid signing key: %w", err)
}
hpkePrivateKey, err := cfg.Options.GetHPKEPrivateKey()
if err != nil {
return fmt.Errorf("invalid hpke private key: %w", err)
}
hpkePublicKey := hpkePrivateKey.PublicKey()
root.HandleFunc("/healthz", handlers.HealthCheck)
root.HandleFunc("/ping", handlers.HealthCheck)
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey))
return nil
}

View file

@ -34,6 +34,7 @@ func TestServerHTTP(t *testing.T) {
}
cfg.Options.AuthenticateURLString = "https://authenticate.localhost.pomerium.io"
cfg.Options.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
src := config.NewStaticSource(cfg)
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
@ -66,15 +67,23 @@ func TestServerHTTP(t *testing.T) {
require.NoError(t, err)
expect := map[string]any{
"keys": []any{map[string]any{
"alg": "ES256",
"crv": "P-256",
"kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c",
"kty": "EC",
"use": "sig",
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
}},
"keys": []any{
map[string]any{
"alg": "ES256",
"crv": "P-256",
"kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c",
"kty": "EC",
"use": "sig",
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
},
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg",
},
},
}
assert.Equal(t, expect, actual)
})

View file

@ -1,11 +1,16 @@
package handlers
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"hash/fnv"
"net/http"
"strconv"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/rs/cors"
"github.com/pomerium/pomerium/internal/httputil"
@ -13,9 +18,14 @@ import (
)
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
func JWKSHandler(rawSigningKey string) http.Handler {
func JWKSHandler(
rawSigningKey string,
additionalKeys ...any,
) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
var jwks jose.JSONWebKeySet
var jwks struct {
Keys []any `json:"keys"`
}
if rawSigningKey != "" {
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
if err != nil {
@ -27,7 +37,22 @@ func JWKSHandler(rawSigningKey string) http.Handler {
}
jwks.Keys = append(jwks.Keys, *jwk)
}
httputil.RenderJSON(w, http.StatusOK, jwks)
jwks.Keys = append(jwks.Keys, additionalKeys...)
bs, err := json.Marshal(jwks)
if err != nil {
return err
}
hasher := fnv.New64()
_, _ = hasher.Write(bs)
h := hasher.Sum64()
w.Header().Set("Cache-Control", "max-age=60")
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(bs)))
w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h))
http.ServeContent(w, r, "jwks.json", time.Time{}, bytes.NewReader(bs))
return nil
}))
}

View file

@ -1,22 +1,71 @@
package handlers
package handlers_test
import (
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
)
func TestJWKSHandler(t *testing.T) {
t.Parallel()
signingKey, err := cryptutil.NewSigningKey()
require.NoError(t, err)
rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
require.NoError(t, err)
jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey)
require.NoError(t, err)
hpkePrivateKey, err := hpke.GeneratePrivateKey()
require.NoError(t, err)
t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET")
JWKSHandler("").ServeHTTP(w, r)
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
})
t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
var expect any = map[string]any{
"keys": []any{
map[string]any{
"kty": "EC",
"kid": jwkSigningKey.KeyID,
"crv": "P-256",
"alg": "ES256",
"use": "sig",
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()),
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()),
},
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": hpkePrivateKey.PublicKey().String(),
},
},
}
var actual any
err := json.Unmarshal(w.Body.Bytes(), &actual)
assert.NoError(t, err)
assert.Equal(t, expect, actual)
})
}

View file

@ -4,6 +4,7 @@ package hpke
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/cloudflare/circl/hpke"
@ -25,44 +26,63 @@ type PrivateKey struct {
}
// DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key.
func DerivePrivateKey(seed []byte) PrivateKey {
func DerivePrivateKey(seed []byte) *PrivateKey {
pk := kdfID.Extract(seed, nil)
data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize()))
_, key := kemID.Scheme().DeriveKeyPair(data)
return PrivateKey{key: key}
return &PrivateKey{key: key}
}
// GeneratePrivateKey generates an HPKE private key.
func GeneratePrivateKey() (PrivateKey, error) {
func GeneratePrivateKey() (*PrivateKey, error) {
_, privateKey, err := kemID.Scheme().GenerateKeyPair()
if err != nil {
return PrivateKey{}, err
return nil, err
}
return PrivateKey{key: privateKey}, nil
return &PrivateKey{key: privateKey}, nil
}
// PrivateKeyFromString takes a string and returns a PrivateKey.
func PrivateKeyFromString(raw string) (PrivateKey, error) {
func PrivateKeyFromString(raw string) (*PrivateKey, error) {
bs, err := decode(raw)
if err != nil {
return PrivateKey{}, err
return nil, err
}
key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs)
if err != nil {
return PrivateKey{}, err
return nil, err
}
return PrivateKey{key: key}, nil
return &PrivateKey{key: key}, nil
}
// PublicKey returns the public key for the private key.
func (key PrivateKey) PublicKey() PublicKey {
return PublicKey{key: key.key.Public()}
func (key *PrivateKey) PublicKey() *PublicKey {
if key == nil || key.key == nil {
return nil
}
return &PublicKey{key: key.key.Public()}
}
// MarshalJSON returns the JSON Web Key representation of the private key.
func (key *PrivateKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
Curve: jwkCurve,
X: key.PublicKey().String(),
D: key.String(),
})
}
// String converts the private key into a string.
func (key PrivateKey) String() string {
func (key *PrivateKey) String() string {
if key == nil || key.key == nil {
return ""
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen
@ -78,22 +98,52 @@ type PublicKey struct {
}
// PublicKeyFromString converts a string into a public key.
func PublicKeyFromString(raw string) (PublicKey, error) {
func PublicKeyFromString(raw string) (*PublicKey, error) {
bs, err := decode(raw)
if err != nil {
return PublicKey{}, err
return nil, err
}
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs)
if err != nil {
return PublicKey{}, err
return nil, err
}
return PublicKey{key: key}, nil
return &PublicKey{key: key}, nil
}
// Equals returns true if the two keys are equivalent.
func (key *PublicKey) Equals(other *PublicKey) bool {
if key == nil && other == nil {
return true
} else if key == nil || other == nil {
return false
}
if key.key == nil && other.key == nil {
return true
} else if key.key == nil || other.key == nil {
return false
}
return key.key.Equal(other.key)
}
// MarshalJSON returns the JSON Web Key representation of the public key.
func (key *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
Curve: jwkCurve,
X: key.String(),
})
}
// String converts a public key into a string.
func (key PublicKey) String() string {
func (key *PublicKey) String() string {
if key == nil || key.key == nil {
return ""
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen
@ -105,10 +155,17 @@ func (key PublicKey) String() string {
// Seal seales a message using HPKE.
func Seal(
senderPrivateKey PrivateKey,
receiverPublicKey PublicKey,
senderPrivateKey *PrivateKey,
receiverPublicKey *PublicKey,
message []byte,
) (sealed []byte, err error) {
if senderPrivateKey == nil {
return nil, fmt.Errorf("hpke: sender private key cannot be nil")
}
if receiverPublicKey == nil {
return nil, fmt.Errorf("hpke: receiver public key cannot be nil")
}
sender, err := suite.NewSender(receiverPublicKey.key, nil)
if err != nil {
return nil, fmt.Errorf("hpke: error creating sender: %w", err)
@ -129,10 +186,17 @@ func Seal(
// Open opens a message using HPKE.
func Open(
receiverPrivateKey PrivateKey,
senderPublicKey PublicKey,
receiverPrivateKey *PrivateKey,
senderPublicKey *PublicKey,
sealed []byte,
) (message []byte, err error) {
if receiverPrivateKey == nil {
return nil, fmt.Errorf("hpke: receiver private key cannot be nil")
}
if senderPublicKey == nil {
return nil, fmt.Errorf("hpke: sender public key cannot be nil")
}
encSize := kemID.Scheme().SharedKeySize()
if len(sealed) < encSize {
return nil, fmt.Errorf("hpke: invalid sealed message")

89
pkg/hpke/jwks.go Normal file
View file

@ -0,0 +1,89 @@
package hpke
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gregjones/httpcache"
)
const (
defaultMaxBodySize = 1024 * 1024 * 4
jwkType = "OKP"
jwkID = "pomerium/hpke"
jwkCurve = "X25519"
)
// JWK is the JSON Web Key representation of an HPKE key.
// Defined in RFC8037.
type JWK struct {
Type string `json:"kty"`
ID string `json:"kid"`
Curve string `json:"crv"`
X string `json:"x"`
D string `json:"d,omitempty"`
}
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
}
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil {
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
}
var jwks struct {
Keys []JWK `json:"keys"`
}
err = json.Unmarshal(bs, &jwks)
if err != nil {
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
}
for _, key := range jwks.Keys {
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
return PublicKeyFromString(key.X)
}
}
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
}
// A KeyFetcher fetches public keys.
type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (*PublicKey, error)
}
type jwksKeyFetcher struct {
client *http.Client
endpoint string
}
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
}
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
func NewKeyFetcher(endpoint string) KeyFetcher {
return &jwksKeyFetcher{
client: httpcache.NewMemoryCacheTransport().Client(),
endpoint: endpoint,
}
}

33
pkg/hpke/jwks_test.go Normal file
View file

@ -0,0 +1,33 @@
package hpke
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/handlers"
)
func TestFetchPublicKeyFromJWKS(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
hpkePrivateKey, err := GeneratePrivateKey()
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
}))
t.Cleanup(srv.Close)
publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
assert.NoError(t, err)
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
}

View file

@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool {
// EncryptURLValues encrypts URL values using the Seal method.
func EncryptURLValues(
senderPrivateKey PrivateKey,
receiverPublicKey PublicKey,
senderPrivateKey *PrivateKey,
receiverPublicKey *PublicKey,
values url.Values,
) (encrypted url.Values, err error) {
values = withoutHPKEParams(values)
@ -37,34 +37,34 @@ func EncryptURLValues(
// DecryptURLValues decrypts URL values using the Open method.
func DecryptURLValues(
receiverPrivateKey PrivateKey,
receiverPrivateKey *PrivateKey,
encrypted url.Values,
) (senderPublicKey PublicKey, values url.Values, err error) {
) (senderPublicKey *PublicKey, values url.Values, err error) {
if !encrypted.Has(ParamSenderPublicKey) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
}
if !encrypted.Has(ParamQuery) {
return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
}
senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
}
sealed, err := decode(encrypted.Get(ParamQuery))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
}
message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
}
decrypted, err := url.ParseQuery(string(message))
if err != nil {
return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
}
values = withoutHPKEParams(encrypted)