hpke: move published public keys to a new endpoint (#4048)

hpke: move published public keys to a new endpoint (#4044)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
backport-actions-token[bot] 2023-03-08 09:18:37 -07:00 committed by GitHub
parent 7afa9d4a95
commit ee1fefb218
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 136 additions and 71 deletions

View file

@ -30,13 +30,10 @@ func TestAuthorize_handleResult(t *testing.T) {
opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
hpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err)
signingKey, err := opt.GetSigningKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL
@ -228,13 +225,10 @@ func TestRequireLogin(t *testing.T) {
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
hpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err)
signingKey, err := opt.GetSigningKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/derivecert"
"github.com/pomerium/pomerium/pkg/hpke"
@ -248,7 +249,7 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) {
return nil, err
}
jwksURL := authenticateURL.ResolveReference(&url.URL{
Path: "/.well-known/pomerium/jwks.json",
Path: urlutil.HPKEPublicKeyPath,
}).String()
return hpke.NewKeyFetcher(jwksURL, transport), nil
}

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/urlutil"
)
func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
@ -68,6 +69,7 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
root.HandleFunc("/ping", handlers.HealthCheck)
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey, hpkePublicKey))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey))
root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(handlers.HPKEPublicKeyHandler(hpkePublicKey))
return nil
}

View file

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"testing"
@ -77,14 +78,22 @@ func TestServerHTTP(t *testing.T) {
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
},
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg",
},
},
}
assert.Equal(t, expect, actual)
})
t.Run("hpke-public-key", func(t *testing.T) {
res, err := http.Get(fmt.Sprintf("http://localhost:%s/.well-known/pomerium/hpke-public-key", src.GetConfig().HTTPPort))
require.NoError(t, err)
defer res.Body.Close()
bs, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, []byte{
0x4f, 0x47, 0x1b, 0x36, 0xb2, 0x5b, 0x3b, 0xd8,
0xa7, 0xf8, 0x58, 0x28, 0xc0, 0xa0, 0x0f, 0xf8,
0x75, 0xfa, 0x0a, 0x2f, 0x2a, 0xe7, 0x48, 0x28,
0xa4, 0xeb, 0x79, 0xda, 0xc7, 0x61, 0x78, 0x78,
}, bs)
})
}

View file

@ -0,0 +1,33 @@
package handlers
import (
"bytes"
"fmt"
"hash/fnv"
"net/http"
"strconv"
"time"
"github.com/rs/cors"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/hpke"
)
// HPKEPublicKeyHandler returns a handler which returns the HPKE public key.
func HPKEPublicKeyHandler(publicKey *hpke.PublicKey) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
bs := publicKey.Bytes()
hasher := fnv.New64()
_, _ = hasher.Write(bs)
h := hasher.Sum64()
w.Header().Set("Cache-Control", "max-age=60")
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", strconv.Itoa(len(bs)))
w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h))
http.ServeContent(w, r, "hpke-public-key", time.Time{}, bytes.NewReader(bs))
return nil
}))
}

View file

@ -0,0 +1,34 @@
package handlers_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/hpke"
)
func TestHPKEPublicKeyHandler(t *testing.T) {
t.Parallel()
k1 := hpke.DerivePrivateKey([]byte("TEST"))
t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET")
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
})
t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, k1.PublicKey().Bytes(), w.Body.Bytes())
})
}

View file

@ -17,10 +17,7 @@ import (
)
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
func JWKSHandler(
signingKey []byte,
additionalKeys ...any,
) http.Handler {
func JWKSHandler(signingKey []byte) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
var jwks struct {
Keys []any `json:"keys"`
@ -34,7 +31,6 @@ func JWKSHandler(
jwks.Keys = append(jwks.Keys, *k)
}
}
jwks.Keys = append(jwks.Keys, additionalKeys...)
bs, err := json.Marshal(jwks)
if err != nil {

View file

@ -13,7 +13,6 @@ import (
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
)
func TestJWKSHandler(t *testing.T) {
@ -34,24 +33,18 @@ func TestJWKSHandler(t *testing.T) {
jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2)
require.NoError(t, err)
hpkePrivateKey, err := hpke.GeneratePrivateKey()
require.NoError(t, err)
t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET")
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
handlers.JWKSHandler(nil).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
})
t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.JWKSHandler(
append(rawSigningKey1, rawSigningKey2...),
hpkePrivateKey.PublicKey(),
).ServeHTTP(w, r)
handlers.JWKSHandler(append(rawSigningKey1, rawSigningKey2...)).ServeHTTP(w, r)
var expect any = map[string]any{
"keys": []any{
@ -73,12 +66,6 @@ func TestJWKSHandler(t *testing.T) {
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()),
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()),
},
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": hpkePrivateKey.PublicKey().String(),
},
},
}
var actual any

View file

@ -5,6 +5,9 @@ import (
"net/url"
)
// HPKEPublicKeyPath is the well-known path to the HPKE public key
const HPKEPublicKeyPath = "/.well-known/pomerium/hpke-public-key"
// DefaultDeviceType is the default device type when none is specified.
const DefaultDeviceType = "any"

View file

@ -97,6 +97,16 @@ type PublicKey struct {
key kem.PublicKey
}
// PublicKeyFromBytes converts raw bytes into a public key.
func PublicKeyFromBytes(raw []byte) (*PublicKey, error) {
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(raw)
if err != nil {
return nil, err
}
return &PublicKey{key: key}, nil
}
// PublicKeyFromString converts a string into a public key.
func PublicKeyFromString(raw string) (*PublicKey, error) {
bs, err := decode(raw)
@ -128,6 +138,20 @@ func (key *PublicKey) Equals(other *PublicKey) bool {
return key.key.Equal(other.key)
}
// Bytes returns the public key as raw bytes.
func (key *PublicKey) Bytes() []byte {
if key == nil || key.key == nil {
return nil
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen
panic(fmt.Sprintf("failed to marshal public HPKE key: %v", err))
}
return bs
}
// MarshalJSON returns the JSON Web Key representation of the public key.
func (key *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{

View file

@ -2,7 +2,6 @@ package hpke
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
@ -28,42 +27,28 @@ type JWK struct {
D string `json:"d,omitempty"`
}
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
// FetchPublicKey fetches the HPKE public key from the hpke-public-key endpoint.
func FetchPublicKey(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil {
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
return nil, fmt.Errorf("hpke: error building hpke-public-key http request: %w", err)
}
res, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint: %w", err)
}
defer res.Body.Close()
if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint, invalid status code: %d", res.StatusCode)
}
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil {
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
return nil, fmt.Errorf("hpke: error reading hpke-public-key endpoint: %w", err)
}
var jwks struct {
Keys []JWK `json:"keys"`
}
err = json.Unmarshal(bs, &jwks)
if err != nil {
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
}
for _, key := range jwks.Keys {
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
return PublicKeyFromString(key.X)
}
}
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
return PublicKeyFromBytes(bs)
}
// A KeyFetcher fetches public keys.
@ -71,18 +56,18 @@ type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (*PublicKey, error)
}
type jwksKeyFetcher struct {
type fetcher struct {
client *http.Client
endpoint string
}
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
func (fetcher *fetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKey(ctx, fetcher.client, fetcher.endpoint)
}
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher {
return &jwksKeyFetcher{
return &fetcher{
client: (&httpcache.Transport{
Transport: transport,
Cache: httpcache.NewMemoryCache(),

View file

@ -24,11 +24,11 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) {
require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
}))
t.Cleanup(srv.Close)
publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
publicKey, err := hpke.FetchPublicKey(ctx, http.DefaultClient, srv.URL)
assert.NoError(t, err)
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
}

View file

@ -30,13 +30,10 @@ func testOptions(t *testing.T) *config.Options {
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
htpkePrivateKey, err := opts.GetHPKEPrivateKey()
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err)
signingKey, err := opts.GetSigningKey()
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL