hpke: move published public keys to a new endpoint (#4044)

This commit is contained in:
Caleb Doxsey 2023-03-08 09:17:04 -07:00 committed by GitHub
parent 74463c5468
commit 0f295d4a63
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 136 additions and 71 deletions

View file

@ -30,13 +30,10 @@ func TestAuthorize_handleResult(t *testing.T) {
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
htpkePrivateKey, err := opt.GetHPKEPrivateKey() hpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err) require.NoError(t, err)
signingKey, err := opt.GetSigningKey() authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL opt.AuthenticateURLString = authnSrv.URL
@ -228,13 +225,10 @@ func TestRequireLogin(t *testing.T) {
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
htpkePrivateKey, err := opt.GetHPKEPrivateKey() hpkePrivateKey, err := opt.GetHPKEPrivateKey()
require.NoError(t, err) require.NoError(t, err)
signingKey, err := opt.GetSigningKey() authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opt.AuthenticateURLString = authnSrv.URL opt.AuthenticateURLString = authnSrv.URL

View file

@ -14,6 +14,7 @@ import (
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/derivecert" "github.com/pomerium/pomerium/pkg/derivecert"
"github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/hpke"
@ -253,7 +254,7 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) {
return nil, err return nil, err
} }
jwksURL := authenticateURL.ResolveReference(&url.URL{ jwksURL := authenticateURL.ResolveReference(&url.URL{
Path: "/.well-known/pomerium/jwks.json", Path: urlutil.HPKEPublicKeyPath,
}).String() }).String()
return hpke.NewKeyFetcher(jwksURL, transport), nil return hpke.NewKeyFetcher(jwksURL, transport), nil
} }

View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/middleware" "github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/urlutil"
) )
func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) { func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
@ -68,6 +69,7 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
root.HandleFunc("/ping", handlers.HealthCheck) root.HandleFunc("/ping", handlers.HealthCheck)
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL)) root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL)) root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey, hpkePublicKey)) root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey))
root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(handlers.HPKEPublicKeyHandler(hpkePublicKey))
return nil return nil
} }

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io"
"net/http" "net/http"
"testing" "testing"
@ -77,14 +78,22 @@ func TestServerHTTP(t *testing.T) {
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo", "x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ", "y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
}, },
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg",
},
}, },
} }
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("hpke-public-key", func(t *testing.T) {
res, err := http.Get(fmt.Sprintf("http://localhost:%s/.well-known/pomerium/hpke-public-key", src.GetConfig().HTTPPort))
require.NoError(t, err)
defer res.Body.Close()
bs, err := io.ReadAll(res.Body)
require.NoError(t, err)
assert.Equal(t, []byte{
0x4f, 0x47, 0x1b, 0x36, 0xb2, 0x5b, 0x3b, 0xd8,
0xa7, 0xf8, 0x58, 0x28, 0xc0, 0xa0, 0x0f, 0xf8,
0x75, 0xfa, 0x0a, 0x2f, 0x2a, 0xe7, 0x48, 0x28,
0xa4, 0xeb, 0x79, 0xda, 0xc7, 0x61, 0x78, 0x78,
}, bs)
})
} }

View file

@ -0,0 +1,33 @@
package handlers
import (
"bytes"
"fmt"
"hash/fnv"
"net/http"
"strconv"
"time"
"github.com/rs/cors"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/pkg/hpke"
)
// HPKEPublicKeyHandler returns a handler which returns the HPKE public key.
func HPKEPublicKeyHandler(publicKey *hpke.PublicKey) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
bs := publicKey.Bytes()
hasher := fnv.New64()
_, _ = hasher.Write(bs)
h := hasher.Sum64()
w.Header().Set("Cache-Control", "max-age=60")
w.Header().Set("Content-Type", "application/octet-stream")
w.Header().Set("Content-Length", strconv.Itoa(len(bs)))
w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h))
http.ServeContent(w, r, "hpke-public-key", time.Time{}, bytes.NewReader(bs))
return nil
}))
}

View file

@ -0,0 +1,34 @@
package handlers_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/hpke"
)
func TestHPKEPublicKeyHandler(t *testing.T) {
t.Parallel()
k1 := hpke.DerivePrivateKey([]byte("TEST"))
t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET")
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
})
t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
assert.Equal(t, k1.PublicKey().Bytes(), w.Body.Bytes())
})
}

View file

@ -17,10 +17,7 @@ import (
) )
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler. // JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
func JWKSHandler( func JWKSHandler(signingKey []byte) http.Handler {
signingKey []byte,
additionalKeys ...any,
) http.Handler {
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
var jwks struct { var jwks struct {
Keys []any `json:"keys"` Keys []any `json:"keys"`
@ -34,7 +31,6 @@ func JWKSHandler(
jwks.Keys = append(jwks.Keys, *k) jwks.Keys = append(jwks.Keys, *k)
} }
} }
jwks.Keys = append(jwks.Keys, additionalKeys...)
bs, err := json.Marshal(jwks) bs, err := json.Marshal(jwks)
if err != nil { if err != nil {

View file

@ -16,7 +16,6 @@ import (
"github.com/pomerium/pomerium/internal/deterministicecdsa" "github.com/pomerium/pomerium/internal/deterministicecdsa"
"github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
) )
func TestJWKSHandler(t *testing.T) { func TestJWKSHandler(t *testing.T) {
@ -38,24 +37,18 @@ func TestJWKSHandler(t *testing.T) {
jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2) jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2)
require.NoError(t, err) require.NoError(t, err)
hpkePrivateKey, err := hpke.GeneratePrivateKey()
require.NoError(t, err)
t.Run("cors", func(t *testing.T) { t.Run("cors", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodOptions, "/", nil) r := httptest.NewRequest(http.MethodOptions, "/", nil)
r.Header.Set("Origin", "https://www.example.com") r.Header.Set("Origin", "https://www.example.com")
r.Header.Set("Access-Control-Request-Method", "GET") r.Header.Set("Access-Control-Request-Method", "GET")
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) handlers.JWKSHandler(nil).ServeHTTP(w, r)
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
}) })
t.Run("keys", func(t *testing.T) { t.Run("keys", func(t *testing.T) {
w := httptest.NewRecorder() w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)
handlers.JWKSHandler( handlers.JWKSHandler(append(rawSigningKey1, rawSigningKey2...)).ServeHTTP(w, r)
append(rawSigningKey1, rawSigningKey2...),
hpkePrivateKey.PublicKey(),
).ServeHTTP(w, r)
var expect any = map[string]any{ var expect any = map[string]any{
"keys": []any{ "keys": []any{
@ -77,12 +70,6 @@ func TestJWKSHandler(t *testing.T) {
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()), "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()),
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()), "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()),
}, },
map[string]any{
"kty": "OKP",
"kid": "pomerium/hpke",
"crv": "X25519",
"x": hpkePrivateKey.PublicKey().String(),
},
}, },
} }
var actual any var actual any

View file

@ -13,6 +13,9 @@ import (
"github.com/pomerium/pomerium/pkg/hpke" "github.com/pomerium/pomerium/pkg/hpke"
) )
// HPKEPublicKeyPath is the well-known path to the HPKE public key
const HPKEPublicKeyPath = "/.well-known/pomerium/hpke-public-key"
// DefaultDeviceType is the default device type when none is specified. // DefaultDeviceType is the default device type when none is specified.
const DefaultDeviceType = "any" const DefaultDeviceType = "any"

View file

@ -97,6 +97,16 @@ type PublicKey struct {
key kem.PublicKey key kem.PublicKey
} }
// PublicKeyFromBytes converts raw bytes into a public key.
func PublicKeyFromBytes(raw []byte) (*PublicKey, error) {
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(raw)
if err != nil {
return nil, err
}
return &PublicKey{key: key}, nil
}
// PublicKeyFromString converts a string into a public key. // PublicKeyFromString converts a string into a public key.
func PublicKeyFromString(raw string) (*PublicKey, error) { func PublicKeyFromString(raw string) (*PublicKey, error) {
bs, err := decode(raw) bs, err := decode(raw)
@ -128,6 +138,20 @@ func (key *PublicKey) Equals(other *PublicKey) bool {
return key.key.Equal(other.key) return key.key.Equal(other.key)
} }
// Bytes returns the public key as raw bytes.
func (key *PublicKey) Bytes() []byte {
if key == nil || key.key == nil {
return nil
}
bs, err := key.key.MarshalBinary()
if err != nil {
// this should not happen
panic(fmt.Sprintf("failed to marshal public HPKE key: %v", err))
}
return bs
}
// MarshalJSON returns the JSON Web Key representation of the public key. // MarshalJSON returns the JSON Web Key representation of the public key.
func (key *PublicKey) MarshalJSON() ([]byte, error) { func (key *PublicKey) MarshalJSON() ([]byte, error) {
return json.Marshal(JWK{ return json.Marshal(JWK{

View file

@ -2,7 +2,6 @@ package hpke
import ( import (
"context" "context"
"encoding/json"
"fmt" "fmt"
"io" "io"
"net/http" "net/http"
@ -28,42 +27,28 @@ type JWK struct {
D string `json:"d,omitempty"` D string `json:"d,omitempty"`
} }
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint. // FetchPublicKey fetches the HPKE public key from the hpke-public-key endpoint.
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) { func FetchPublicKey(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
if err != nil { if err != nil {
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err) return nil, fmt.Errorf("hpke: error building hpke-public-key http request: %w", err)
} }
res, err := client.Do(req) res, err := client.Do(req)
if err != nil { if err != nil {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err) return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint: %w", err)
} }
defer res.Body.Close() defer res.Body.Close()
if res.StatusCode/100 != 2 { if res.StatusCode/100 != 2 {
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode) return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint, invalid status code: %d", res.StatusCode)
} }
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize)) bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
if err != nil { if err != nil {
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err) return nil, fmt.Errorf("hpke: error reading hpke-public-key endpoint: %w", err)
} }
return PublicKeyFromBytes(bs)
var jwks struct {
Keys []JWK `json:"keys"`
}
err = json.Unmarshal(bs, &jwks)
if err != nil {
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
}
for _, key := range jwks.Keys {
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
return PublicKeyFromString(key.X)
}
}
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
} }
// A KeyFetcher fetches public keys. // A KeyFetcher fetches public keys.
@ -71,18 +56,18 @@ type KeyFetcher interface {
FetchPublicKey(ctx context.Context) (*PublicKey, error) FetchPublicKey(ctx context.Context) (*PublicKey, error)
} }
type jwksKeyFetcher struct { type fetcher struct {
client *http.Client client *http.Client
endpoint string endpoint string
} }
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) { func (fetcher *fetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint) return FetchPublicKey(ctx, fetcher.client, fetcher.endpoint)
} }
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache. // NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher { func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher {
return &jwksKeyFetcher{ return &fetcher{
client: (&httpcache.Transport{ client: (&httpcache.Transport{
Transport: transport, Transport: transport,
Cache: httpcache.NewMemoryCache(), Cache: httpcache.NewMemoryCache(),

View file

@ -24,11 +24,11 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
})) }))
t.Cleanup(srv.Close) t.Cleanup(srv.Close)
publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL) publicKey, err := hpke.FetchPublicKey(ctx, http.DefaultClient, srv.URL)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String()) assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
} }

View file

@ -30,13 +30,10 @@ func testOptions(t *testing.T) *config.Options {
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw=" opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
htpkePrivateKey, err := opts.GetHPKEPrivateKey() hpkePrivateKey, err := opts.GetHPKEPrivateKey()
require.NoError(t, err) require.NoError(t, err)
signingKey, err := opts.GetSigningKey() authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
require.NoError(t, err)
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
t.Cleanup(authnSrv.Close) t.Cleanup(authnSrv.Close)
opts.AuthenticateURLString = authnSrv.URL opts.AuthenticateURLString = authnSrv.URL