mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
hpke: add HPKE key to JWKS endpoint
This commit is contained in:
parent
52c967b8a5
commit
c17b07987d
9 changed files with 246 additions and 6 deletions
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// DisableHeaderKey is the key used to check whether to disable setting header
|
||||
|
@ -997,6 +998,16 @@ func (o *Options) GetSharedKey() ([]byte, error) {
|
|||
return base64.StdEncoding.DecodeString(sharedKey)
|
||||
}
|
||||
|
||||
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
|
||||
func (o *Options) GetHPKEPrivateKey() (hpke.PrivateKey, error) {
|
||||
sharedKey, err := o.GetSharedKey()
|
||||
if err != nil {
|
||||
return hpke.PrivateKey{}, err
|
||||
}
|
||||
|
||||
return hpke.DerivePrivateKey(sharedKey), nil
|
||||
}
|
||||
|
||||
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
|
||||
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
|
||||
return o.GoogleCloudServerlessAuthenticationServiceAccount
|
||||
|
|
1
go.mod
1
go.mod
|
@ -31,6 +31,7 @@ require (
|
|||
github.com/gorilla/handlers v1.5.1
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/jackc/pgconn v1.13.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -507,6 +507,8 @@ github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3
|
|||
github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A=
|
||||
github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod h1:D+FIZ+7OahH3ePw/izIEeH5I06eKs1IKI4Xr64/Am3M=
|
||||
github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY=
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA=
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
|
||||
|
|
|
@ -60,10 +60,16 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
|
|||
return fmt.Errorf("invalid signing key: %w", err)
|
||||
}
|
||||
|
||||
hpkePrivateKey, err := cfg.Options.GetHPKEPrivateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid hpke private key: %w", err)
|
||||
}
|
||||
hpkePublicKey := hpkePrivateKey.PublicKey()
|
||||
|
||||
root.HandleFunc("/healthz", handlers.HealthCheck)
|
||||
root.HandleFunc("/ping", handlers.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey))
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"errors"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -13,9 +12,14 @@ import (
|
|||
)
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
func JWKSHandler(
|
||||
rawSigningKey string,
|
||||
additionalKeys ...any,
|
||||
) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
var jwks struct {
|
||||
Keys []any `json:"keys"`
|
||||
}
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
|
@ -27,6 +31,7 @@ func JWKSHandler(rawSigningKey string) http.Handler {
|
|||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
jwks.Keys = append(jwks.Keys, additionalKeys...)
|
||||
httputil.RenderJSON(w, http.StatusOK, jwks)
|
||||
return nil
|
||||
}))
|
||||
|
|
|
@ -1,22 +1,71 @@
|
|||
package handlers
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signingKey, err := cryptutil.NewSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
hpkePrivateKey, err := hpke.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
t.Run("keys", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
|
||||
var expect any = map[string]any{
|
||||
"keys": []any{
|
||||
map[string]any{
|
||||
"kty": "EC",
|
||||
"kid": jwkSigningKey.KeyID,
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"use": "sig",
|
||||
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()),
|
||||
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()),
|
||||
},
|
||||
map[string]any{
|
||||
"kty": "OKP",
|
||||
"kid": "pomerium/hpke",
|
||||
"crv": "X25519",
|
||||
"x": hpkePrivateKey.PublicKey().String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
var actual any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &actual)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ package hpke
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudflare/circl/hpke"
|
||||
|
@ -58,11 +59,30 @@ func PrivateKeyFromString(raw string) (PrivateKey, error) {
|
|||
|
||||
// PublicKey returns the public key for the private key.
|
||||
func (key PrivateKey) PublicKey() PublicKey {
|
||||
if key.key == nil {
|
||||
return PublicKey{}
|
||||
}
|
||||
|
||||
return PublicKey{key: key.key.Public()}
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the private key.
|
||||
func (key PrivateKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
Curve: jwkCurve,
|
||||
X: key.PublicKey().String(),
|
||||
D: key.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// String converts the private key into a string.
|
||||
func (key PrivateKey) String() string {
|
||||
if key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
bs, err := key.key.MarshalBinary()
|
||||
if err != nil {
|
||||
// this should not happen
|
||||
|
@ -92,8 +112,32 @@ func PublicKeyFromString(raw string) (PublicKey, error) {
|
|||
return PublicKey{key: key}, nil
|
||||
}
|
||||
|
||||
// Equals returns true if the two keys are equivalent.
|
||||
func (key PublicKey) Equals(other PublicKey) bool {
|
||||
if key.key == nil && other.key == nil {
|
||||
return true
|
||||
} else if key.key == nil || other.key == nil {
|
||||
return false
|
||||
}
|
||||
return key.key.Equal(other.key)
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the public key.
|
||||
func (key PublicKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
Curve: jwkCurve,
|
||||
X: key.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// String converts a public key into a string.
|
||||
func (key PublicKey) String() string {
|
||||
if key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
bs, err := key.key.MarshalBinary()
|
||||
if err != nil {
|
||||
// this should not happen
|
||||
|
|
89
pkg/hpke/jwks.go
Normal file
89
pkg/hpke/jwks.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package hpke
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gregjones/httpcache"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBodySize = 1024 * 1024 * 4
|
||||
|
||||
jwkType = "OKP"
|
||||
jwkID = "pomerium/hpke"
|
||||
jwkCurve = "X25519"
|
||||
)
|
||||
|
||||
// JWK is the JSON Web Key representation of an HPKE key.
|
||||
// Defined in RFC8037.
|
||||
type JWK struct {
|
||||
Type string `json:"kty"`
|
||||
ID string `json:"kid"`
|
||||
Curve string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
D string `json:"d,omitempty"`
|
||||
}
|
||||
|
||||
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
|
||||
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error building jwks http request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error reading jwks endpoint: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
var jwks struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
err = json.Unmarshal(bs, &jwks)
|
||||
if err != nil {
|
||||
return PublicKey{}, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
|
||||
return PublicKeyFromString(key.X)
|
||||
}
|
||||
}
|
||||
return PublicKey{}, fmt.Errorf("hpke key not found in JWKS endpoint")
|
||||
}
|
||||
|
||||
// A KeyFetcher fetches public keys.
|
||||
type KeyFetcher interface {
|
||||
FetchPublicKey(ctx context.Context) (PublicKey, error)
|
||||
}
|
||||
|
||||
type jwksKeyFetcher struct {
|
||||
client *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (PublicKey, error) {
|
||||
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
|
||||
}
|
||||
|
||||
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
|
||||
func NewKeyFetcher(endpoint string) KeyFetcher {
|
||||
return &jwksKeyFetcher{
|
||||
client: httpcache.NewMemoryCacheTransport().Client(),
|
||||
endpoint: endpoint,
|
||||
}
|
||||
}
|
33
pkg/hpke/jwks_test.go
Normal file
33
pkg/hpke/jwks_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package hpke
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
)
|
||||
|
||||
func TestFetchPublicKeyFromJWKS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
hpkePrivateKey, err := GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue