hpke: add HPKE key to JWKS endpoint

This commit is contained in:
Caleb Doxsey 2022-11-22 18:45:05 -07:00
parent 52c967b8a5
commit c17b07987d
9 changed files with 246 additions and 6 deletions

View file

@ -30,6 +30,7 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/hpke"
)
// DisableHeaderKey is the key used to check whether to disable setting header
@ -997,6 +998,16 @@ func (o *Options) GetSharedKey() ([]byte, error) {
return base64.StdEncoding.DecodeString(sharedKey)
}
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) {
sharedKey, err := o.GetSharedKey()
if err != nil {
return hpke.PrivateKey{}, err
}
return hpke.DerivePrivateKey(sharedKey), nil
}
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
return o.GoogleCloudServerlessAuthenticationServiceAccount

1
go.mod
View file

@ -31,6 +31,7 @@ require (
github.com/gorilla/handlers v1.5.1
github.com/gorilla/mux v1.8.0
github.com/gorilla/websocket v1.5.0
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
github.com/hashicorp/go-multierror v1.1.1
github.com/hashicorp/golang-lru v0.5.4
github.com/jackc/pgconn v1.13.0

2
go.sum
View file

@ -507,6 +507,8 @@ github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3
github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A=
github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod h1:D+FIZ+7OahH3ePw/izIEeH5I06eKs1IKI4Xr64/Am3M=
github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY=
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA=
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=

View file

@ -60,10 +60,16 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
return fmt.Errorf("invalid signing key: %w", err)
}
hpkePrivateKey, err := cfg.Options.GetHPKEPrivateKey()
if err != nil {
return fmt.Errorf("invalid hpke private key: %w", err)
}
hpkePublicKey := hpkePrivateKey.PublicKey()
root.HandleFunc("/healthz", handlers.HealthCheck)
root.HandleFunc("/ping", handlers.HealthCheck)
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey))
return nil
}

View file

@ -5,7 +5,6 @@ import (
"errors"
"net/http"
"github.com/go-jose/go-jose/v3"
"github.com/rs/cors"
"github.com/pomerium/pomerium/internal/httputil"
@ -13,9 +12,14 @@ import (
)
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
func JWKSHandler(rawSigningKey string) http.Handler {
func JWKSHandler(
rawSigningKey string,
additionalKeys ...any,
) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
var jwks jose.JSONWebKeySet
var jwks struct {
Keys []any `json:"keys"`
}
if rawSigningKey != "" {
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
if err != nil {
@ -27,6 +31,7 @@ func JWKSHandler(rawSigningKey string) http.Handler {
}
jwks.Keys = append(jwks.Keys, *jwk)
}
jwks.Keys = append(jwks.Keys, additionalKeys...)
httputil.RenderJSON(w, http.StatusOK, jwks)
return nil
}))

View file

@ -1,22 +1,71 @@
package handlers
package handlers_test
import (
"crypto/ecdsa"
"encoding/base64"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
)
func TestJWKSHandler(t *testing.T) {
t.Parallel()
signingKey, err := cryptutil.NewSigningKey()
require.NoError(t, err)
rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
require.NoError(t, err)
jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey)
require.NoError(t, err)
hpkePrivateKey, err := hpke.GeneratePrivateKey()
require.NoError(t, err)
t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET")
JWKSHandler("").ServeHTTP(w, r)
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
})
t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
var expect any = map[string]any{
"keys": []any{
map[string]any{
"kty": "EC",
"kid": jwkSigningKey.KeyID,
"crv": "P-256",
"alg": "ES256",
"use": "sig",
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()),
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()),
},
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": hpkePrivateKey.PublicKey().String(),
},
},
}
var actual any
err := json.Unmarshal(w.Body.Bytes(), &actual)
assert.NoError(t, err)
assert.Equal(t, expect, actual)
})
}

View file

@ -4,6 +4,7 @@ package hpke
import (
"crypto/rand"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/cloudflare/circl/hpke"
@ -58,11 +59,30 @@ func PrivateKeyFromString(raw string) (PrivateKey, error) {
// PublicKey returns the public key for the private key.
func (key PrivateKey) PublicKey() PublicKey {
if key.key == nil {
return PublicKey{}
}
return PublicKey{key: key.key.Public()}
}
// MarshalJSON returns the JSON Web Key representation of the private key.
func (key PrivateKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
Curve: jwkCurve,
X: key.PublicKey().String(),
D: key.String(),
})
}
// String converts the private key into a string.
func (key PrivateKey) String() string {
if key.key == nil {
return ""
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen
@ -92,8 +112,32 @@ func PublicKeyFromString(raw string) (PublicKey, error) {
return PublicKey{key: key}, nil
}
// Equals returns true if the two keys are equivalent.
func (key PublicKey) Equals(other PublicKey) bool {
if key.key == nil && other.key == nil {
return true
} else if key.key == nil || other.key == nil {
return false
}
return key.key.Equal(other.key)
}
// MarshalJSON returns the JSON Web Key representation of the public key.
func (key PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{
Type: jwkType,
ID: jwkID,
Curve: jwkCurve,
X: key.String(),
})
}
// String converts a public key into a string.
func (key PublicKey) String() string {
if key.key == nil {
return ""
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen

89
pkg/hpke/jwks.go Normal file
View file

@ -0,0 +1,89 @@
package hpke
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/gregjones/httpcache"
)
const (
defaultMaxBodySize = 1024 * 1024 * 4
jwkType = "OKP"
jwkID = "pomerium/hpke"
jwkCurve = "X25519"
)
// JWK is the JSON Web Key representation of an HPKE key.
// Defined in RFC8037.
type JWK struct {
Type string `json:"kty"`
ID string `json:"kid"`
Curve string `json:"crv"`
X string `json:"x"`
D string `json:"d,omitempty"`
}
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
}
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %d", res.StatusCode)
}
var jwks struct {
Keys []JWK `json:"keys"`
}
err = json.Unmarshal(bs, &jwks)
if err != nil {
return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
}
for _, key := range jwks.Keys {
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
return PublicKeyFromString(key.X)
}
}
return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint")
}
// A KeyFetcher fetches public keys.
type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (PublicKey, error)
}
type jwksKeyFetcher struct {
client *http.Client
endpoint string
}
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
}
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
func NewKeyFetcher(endpoint string) KeyFetcher {
return &jwksKeyFetcher{
client: httpcache.NewMemoryCacheTransport().Client(),
endpoint: endpoint,
}
}

33
pkg/hpke/jwks_test.go Normal file
View file

@ -0,0 +1,33 @@
package hpke
import (
"context"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/handlers"
)
func TestFetchPublicKeyFromJWKS(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
hpkePrivateKey, err := GeneratePrivateKey()
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
}))
t.Cleanup(srv.Close)
publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
assert.NoError(t, err)
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
}