mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
hpke: add HPKE key to JWKS endpoint (#3762)
* hpke: add HPKE key to JWKS endpoint * fix test, add http caching headers * fix error message * use pointers
This commit is contained in:
parent
52c967b8a5
commit
ba07afc245
11 changed files with 336 additions and 47 deletions
|
@ -30,6 +30,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/config"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// DisableHeaderKey is the key used to check whether to disable setting header
|
||||
|
@ -997,6 +998,16 @@ func (o *Options) GetSharedKey() ([]byte, error) {
|
|||
return base64.StdEncoding.DecodeString(sharedKey)
|
||||
}
|
||||
|
||||
// GetHPKEPrivateKey gets the hpke.PrivateKey dervived from the shared key.
|
||||
func (o *Options) GetHPKEPrivateKey() (*hpke.PrivateKey, error) {
|
||||
sharedKey, err := o.GetSharedKey()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hpke.DerivePrivateKey(sharedKey), nil
|
||||
}
|
||||
|
||||
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
|
||||
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
|
||||
return o.GoogleCloudServerlessAuthenticationServiceAccount
|
||||
|
|
1
go.mod
1
go.mod
|
@ -31,6 +31,7 @@ require (
|
|||
github.com/gorilla/handlers v1.5.1
|
||||
github.com/gorilla/mux v1.8.0
|
||||
github.com/gorilla/websocket v1.5.0
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
|
||||
github.com/hashicorp/go-multierror v1.1.1
|
||||
github.com/hashicorp/golang-lru v0.5.4
|
||||
github.com/jackc/pgconn v1.13.0
|
||||
|
|
2
go.sum
2
go.sum
|
@ -507,6 +507,8 @@ github.com/gostaticanalysis/nilerr v0.1.1 h1:ThE+hJP0fEp4zWLkWHWcRyI2Od0p7DlgYG3
|
|||
github.com/gostaticanalysis/nilerr v0.1.1/go.mod h1:wZYb6YI5YAxxq0i1+VJbY0s2YONW0HU0GPE3+5PWN4A=
|
||||
github.com/gostaticanalysis/testutil v0.3.1-0.20210208050101-bfb5c8eec0e4/go.mod h1:D+FIZ+7OahH3ePw/izIEeH5I06eKs1IKI4Xr64/Am3M=
|
||||
github.com/gostaticanalysis/testutil v0.4.0 h1:nhdCmubdmDF6VEatUNjgUZBJKWRqugoISdUv3PPQgHY=
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 h1:+ngKgrYPPJrOjhax5N+uePQ0Fh1Z7PheYoUI/0nzkPA=
|
||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79/go.mod h1:FecbI9+v66THATjSRHfNgh1IVFe/9kFxbXtjV0ctIMA=
|
||||
github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs=
|
||||
github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk=
|
||||
github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY=
|
||||
|
|
|
@ -60,10 +60,16 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
|
|||
return fmt.Errorf("invalid signing key: %w", err)
|
||||
}
|
||||
|
||||
hpkePrivateKey, err := cfg.Options.GetHPKEPrivateKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid hpke private key: %w", err)
|
||||
}
|
||||
hpkePublicKey := hpkePrivateKey.PublicKey()
|
||||
|
||||
root.HandleFunc("/healthz", handlers.HealthCheck)
|
||||
root.HandleFunc("/ping", handlers.HealthCheck)
|
||||
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey))
|
||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey))
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -34,6 +34,7 @@ func TestServerHTTP(t *testing.T) {
|
|||
}
|
||||
cfg.Options.AuthenticateURLString = "https://authenticate.localhost.pomerium.io"
|
||||
cfg.Options.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||
|
||||
src := config.NewStaticSource(cfg)
|
||||
srv, err := NewServer(cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||
|
@ -66,15 +67,23 @@ func TestServerHTTP(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
expect := map[string]any{
|
||||
"keys": []any{map[string]any{
|
||||
"alg": "ES256",
|
||||
"crv": "P-256",
|
||||
"kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c",
|
||||
"kty": "EC",
|
||||
"use": "sig",
|
||||
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
|
||||
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
|
||||
}},
|
||||
"keys": []any{
|
||||
map[string]any{
|
||||
"alg": "ES256",
|
||||
"crv": "P-256",
|
||||
"kid": "5b419ade1895fec2d2def6cd33b1b9a018df60db231dc5ecb85cbed6d942813c",
|
||||
"kty": "EC",
|
||||
"use": "sig",
|
||||
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
|
||||
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
|
||||
},
|
||||
map[string]any{
|
||||
"kty": "OKP",
|
||||
"kid": "pomerium/hpke",
|
||||
"crv": "X25519",
|
||||
"x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg",
|
||||
},
|
||||
},
|
||||
}
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
|
|
|
@ -1,11 +1,16 @@
|
|||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"github.com/rs/cors"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
|
@ -13,9 +18,14 @@ import (
|
|||
)
|
||||
|
||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||
func JWKSHandler(rawSigningKey string) http.Handler {
|
||||
func JWKSHandler(
|
||||
rawSigningKey string,
|
||||
additionalKeys ...any,
|
||||
) http.Handler {
|
||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
var jwks jose.JSONWebKeySet
|
||||
var jwks struct {
|
||||
Keys []any `json:"keys"`
|
||||
}
|
||||
if rawSigningKey != "" {
|
||||
decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey)
|
||||
if err != nil {
|
||||
|
@ -27,7 +37,22 @@ func JWKSHandler(rawSigningKey string) http.Handler {
|
|||
}
|
||||
jwks.Keys = append(jwks.Keys, *jwk)
|
||||
}
|
||||
httputil.RenderJSON(w, http.StatusOK, jwks)
|
||||
jwks.Keys = append(jwks.Keys, additionalKeys...)
|
||||
|
||||
bs, err := json.Marshal(jwks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
hasher := fnv.New64()
|
||||
_, _ = hasher.Write(bs)
|
||||
h := hasher.Sum64()
|
||||
|
||||
w.Header().Set("Cache-Control", "max-age=60")
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Header().Set("Content-Length", strconv.Itoa(len(bs)))
|
||||
w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h))
|
||||
http.ServeContent(w, r, "jwks.json", time.Time{}, bytes.NewReader(bs))
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -1,22 +1,71 @@
|
|||
package handlers
|
||||
package handlers_test
|
||||
|
||||
import (
|
||||
"crypto/ecdsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
func TestJWKSHandler(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
signingKey, err := cryptutil.NewSigningKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
hpkePrivateKey, err := hpke.GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("cors", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||
r.Header.Set("Origin", "https://www.example.com")
|
||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||
JWKSHandler("").ServeHTTP(w, r)
|
||||
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||
})
|
||||
t.Run("keys", func(t *testing.T) {
|
||||
w := httptest.NewRecorder()
|
||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
|
||||
var expect any = map[string]any{
|
||||
"keys": []any{
|
||||
map[string]any{
|
||||
"kty": "EC",
|
||||
"kid": jwkSigningKey.KeyID,
|
||||
"crv": "P-256",
|
||||
"alg": "ES256",
|
||||
"use": "sig",
|
||||
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()),
|
||||
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()),
|
||||
},
|
||||
map[string]any{
|
||||
"kty": "OKP",
|
||||
"kid": "pomerium/hpke",
|
||||
"crv": "X25519",
|
||||
"x": hpkePrivateKey.PublicKey().String(),
|
||||
},
|
||||
},
|
||||
}
|
||||
var actual any
|
||||
err := json.Unmarshal(w.Body.Bytes(), &actual)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
}
|
||||
|
|
106
pkg/hpke/hpke.go
106
pkg/hpke/hpke.go
|
@ -4,6 +4,7 @@ package hpke
|
|||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/cloudflare/circl/hpke"
|
||||
|
@ -25,44 +26,63 @@ type PrivateKey struct {
|
|||
}
|
||||
|
||||
// DerivePrivateKey derives a private key from a seed. The same seed will always result in the same private key.
|
||||
func DerivePrivateKey(seed []byte) PrivateKey {
|
||||
func DerivePrivateKey(seed []byte) *PrivateKey {
|
||||
pk := kdfID.Extract(seed, nil)
|
||||
data := kdfID.Expand(pk, kdfExpandInfo, uint(kemID.Scheme().SeedSize()))
|
||||
_, key := kemID.Scheme().DeriveKeyPair(data)
|
||||
return PrivateKey{key: key}
|
||||
return &PrivateKey{key: key}
|
||||
}
|
||||
|
||||
// GeneratePrivateKey generates an HPKE private key.
|
||||
func GeneratePrivateKey() (PrivateKey, error) {
|
||||
func GeneratePrivateKey() (*PrivateKey, error) {
|
||||
_, privateKey, err := kemID.Scheme().GenerateKeyPair()
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
return PrivateKey{key: privateKey}, nil
|
||||
return &PrivateKey{key: privateKey}, nil
|
||||
}
|
||||
|
||||
// PrivateKeyFromString takes a string and returns a PrivateKey.
|
||||
func PrivateKeyFromString(raw string) (PrivateKey, error) {
|
||||
func PrivateKeyFromString(raw string) (*PrivateKey, error) {
|
||||
bs, err := decode(raw)
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := kemID.Scheme().UnmarshalBinaryPrivateKey(bs)
|
||||
if err != nil {
|
||||
return PrivateKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return PrivateKey{key: key}, nil
|
||||
return &PrivateKey{key: key}, nil
|
||||
}
|
||||
|
||||
// PublicKey returns the public key for the private key.
|
||||
func (key PrivateKey) PublicKey() PublicKey {
|
||||
return PublicKey{key: key.key.Public()}
|
||||
func (key *PrivateKey) PublicKey() *PublicKey {
|
||||
if key == nil || key.key == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &PublicKey{key: key.key.Public()}
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the private key.
|
||||
func (key *PrivateKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
Curve: jwkCurve,
|
||||
X: key.PublicKey().String(),
|
||||
D: key.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// String converts the private key into a string.
|
||||
func (key PrivateKey) String() string {
|
||||
func (key *PrivateKey) String() string {
|
||||
if key == nil || key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
bs, err := key.key.MarshalBinary()
|
||||
if err != nil {
|
||||
// this should not happen
|
||||
|
@ -78,22 +98,52 @@ type PublicKey struct {
|
|||
}
|
||||
|
||||
// PublicKeyFromString converts a string into a public key.
|
||||
func PublicKeyFromString(raw string) (PublicKey, error) {
|
||||
func PublicKeyFromString(raw string) (*PublicKey, error) {
|
||||
bs, err := decode(raw)
|
||||
if err != nil {
|
||||
return PublicKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(bs)
|
||||
if err != nil {
|
||||
return PublicKey{}, err
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return PublicKey{key: key}, nil
|
||||
return &PublicKey{key: key}, nil
|
||||
}
|
||||
|
||||
// Equals returns true if the two keys are equivalent.
|
||||
func (key *PublicKey) Equals(other *PublicKey) bool {
|
||||
if key == nil && other == nil {
|
||||
return true
|
||||
} else if key == nil || other == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if key.key == nil && other.key == nil {
|
||||
return true
|
||||
} else if key.key == nil || other.key == nil {
|
||||
return false
|
||||
}
|
||||
return key.key.Equal(other.key)
|
||||
}
|
||||
|
||||
// MarshalJSON returns the JSON Web Key representation of the public key.
|
||||
func (key *PublicKey) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(JWK{
|
||||
Type: jwkType,
|
||||
ID: jwkID,
|
||||
Curve: jwkCurve,
|
||||
X: key.String(),
|
||||
})
|
||||
}
|
||||
|
||||
// String converts a public key into a string.
|
||||
func (key PublicKey) String() string {
|
||||
func (key *PublicKey) String() string {
|
||||
if key == nil || key.key == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
bs, err := key.key.MarshalBinary()
|
||||
if err != nil {
|
||||
// this should not happen
|
||||
|
@ -105,10 +155,17 @@ func (key PublicKey) String() string {
|
|||
|
||||
// Seal seales a message using HPKE.
|
||||
func Seal(
|
||||
senderPrivateKey PrivateKey,
|
||||
receiverPublicKey PublicKey,
|
||||
senderPrivateKey *PrivateKey,
|
||||
receiverPublicKey *PublicKey,
|
||||
message []byte,
|
||||
) (sealed []byte, err error) {
|
||||
if senderPrivateKey == nil {
|
||||
return nil, fmt.Errorf("hpke: sender private key cannot be nil")
|
||||
}
|
||||
if receiverPublicKey == nil {
|
||||
return nil, fmt.Errorf("hpke: receiver public key cannot be nil")
|
||||
}
|
||||
|
||||
sender, err := suite.NewSender(receiverPublicKey.key, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error creating sender: %w", err)
|
||||
|
@ -129,10 +186,17 @@ func Seal(
|
|||
|
||||
// Open opens a message using HPKE.
|
||||
func Open(
|
||||
receiverPrivateKey PrivateKey,
|
||||
senderPublicKey PublicKey,
|
||||
receiverPrivateKey *PrivateKey,
|
||||
senderPublicKey *PublicKey,
|
||||
sealed []byte,
|
||||
) (message []byte, err error) {
|
||||
if receiverPrivateKey == nil {
|
||||
return nil, fmt.Errorf("hpke: receiver private key cannot be nil")
|
||||
}
|
||||
if senderPublicKey == nil {
|
||||
return nil, fmt.Errorf("hpke: sender public key cannot be nil")
|
||||
}
|
||||
|
||||
encSize := kemID.Scheme().SharedKeySize()
|
||||
if len(sealed) < encSize {
|
||||
return nil, fmt.Errorf("hpke: invalid sealed message")
|
||||
|
|
89
pkg/hpke/jwks.go
Normal file
89
pkg/hpke/jwks.go
Normal file
|
@ -0,0 +1,89 @@
|
|||
package hpke
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gregjones/httpcache"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultMaxBodySize = 1024 * 1024 * 4
|
||||
|
||||
jwkType = "OKP"
|
||||
jwkID = "pomerium/hpke"
|
||||
jwkCurve = "X25519"
|
||||
)
|
||||
|
||||
// JWK is the JSON Web Key representation of an HPKE key.
|
||||
// Defined in RFC8037.
|
||||
type JWK struct {
|
||||
Type string `json:"kty"`
|
||||
ID string `json:"kid"`
|
||||
Curve string `json:"crv"`
|
||||
X string `json:"x"`
|
||||
D string `json:"d,omitempty"`
|
||||
}
|
||||
|
||||
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
|
||||
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
|
||||
}
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
||||
if res.StatusCode/100 != 2 {
|
||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
|
||||
}
|
||||
|
||||
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
|
||||
}
|
||||
|
||||
var jwks struct {
|
||||
Keys []JWK `json:"keys"`
|
||||
}
|
||||
err = json.Unmarshal(bs, &jwks)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
|
||||
}
|
||||
|
||||
for _, key := range jwks.Keys {
|
||||
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
|
||||
return PublicKeyFromString(key.X)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
|
||||
}
|
||||
|
||||
// A KeyFetcher fetches public keys.
|
||||
type KeyFetcher interface {
|
||||
FetchPublicKey(ctx context.Context) (*PublicKey, error)
|
||||
}
|
||||
|
||||
type jwksKeyFetcher struct {
|
||||
client *http.Client
|
||||
endpoint string
|
||||
}
|
||||
|
||||
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
|
||||
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
|
||||
}
|
||||
|
||||
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
|
||||
func NewKeyFetcher(endpoint string) KeyFetcher {
|
||||
return &jwksKeyFetcher{
|
||||
client: httpcache.NewMemoryCacheTransport().Client(),
|
||||
endpoint: endpoint,
|
||||
}
|
||||
}
|
33
pkg/hpke/jwks_test.go
Normal file
33
pkg/hpke/jwks_test.go
Normal file
|
@ -0,0 +1,33 @@
|
|||
package hpke
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
)
|
||||
|
||||
func TestFetchPublicKeyFromJWKS(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
hpkePrivateKey, err := GeneratePrivateKey()
|
||||
require.NoError(t, err)
|
||||
|
||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||
}))
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
publicKey, err := FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
|
||||
}
|
|
@ -18,8 +18,8 @@ func IsEncryptedURL(values url.Values) bool {
|
|||
|
||||
// EncryptURLValues encrypts URL values using the Seal method.
|
||||
func EncryptURLValues(
|
||||
senderPrivateKey PrivateKey,
|
||||
receiverPublicKey PublicKey,
|
||||
senderPrivateKey *PrivateKey,
|
||||
receiverPublicKey *PublicKey,
|
||||
values url.Values,
|
||||
) (encrypted url.Values, err error) {
|
||||
values = withoutHPKEParams(values)
|
||||
|
@ -37,34 +37,34 @@ func EncryptURLValues(
|
|||
|
||||
// DecryptURLValues decrypts URL values using the Open method.
|
||||
func DecryptURLValues(
|
||||
receiverPrivateKey PrivateKey,
|
||||
receiverPrivateKey *PrivateKey,
|
||||
encrypted url.Values,
|
||||
) (senderPublicKey PublicKey, values url.Values, err error) {
|
||||
) (senderPublicKey *PublicKey, values url.Values, err error) {
|
||||
if !encrypted.Has(ParamSenderPublicKey) {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
|
||||
return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters")
|
||||
}
|
||||
if !encrypted.Has(ParamQuery) {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
|
||||
return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters")
|
||||
}
|
||||
|
||||
senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err)
|
||||
}
|
||||
|
||||
sealed, err := decode(encrypted.Get(ParamQuery))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err)
|
||||
}
|
||||
|
||||
message, err := Open(receiverPrivateKey, senderPublicKey, sealed)
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err)
|
||||
}
|
||||
|
||||
decrypted, err := url.ParseQuery(string(message))
|
||||
if err != nil {
|
||||
return senderPublicKey, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
|
||||
return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err)
|
||||
}
|
||||
|
||||
values = withoutHPKEParams(encrypted)
|
||||
|
|
Loading…
Add table
Reference in a new issue