mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-13 09:07:44 +02:00
hpke: move published public keys to a new endpoint (#4044)
This commit is contained in:
parent
74463c5468
commit
0f295d4a63
13 changed files with 136 additions and 71 deletions
|
@ -30,13 +30,10 @@ func TestAuthorize_handleResult(t *testing.T) {
|
||||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||||
|
|
||||||
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
|
hpkePrivateKey, err := opt.GetHPKEPrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
signingKey, err := opt.GetSigningKey()
|
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
|
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opt.AuthenticateURLString = authnSrv.URL
|
opt.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
|
@ -228,13 +225,10 @@ func TestRequireLogin(t *testing.T) {
|
||||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||||
opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
opt.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUJlMFRxbXJkSXBZWE03c3pSRERWYndXOS83RWJHVWhTdFFJalhsVHNXM1BvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFb0xaRDI2bEdYREhRQmhhZkdlbEVmRDdlNmYzaURjWVJPVjdUbFlIdHF1Y1BFL2hId2dmYQpNY3FBUEZsRmpueUpySXJhYTFlQ2xZRTJ6UktTQk5kNXBRPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||||
|
|
||||||
htpkePrivateKey, err := opt.GetHPKEPrivateKey()
|
hpkePrivateKey, err := opt.GetHPKEPrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
signingKey, err := opt.GetSigningKey()
|
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
|
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opt.AuthenticateURLString = authnSrv.URL
|
opt.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
|
|
|
@ -14,6 +14,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/derivecert"
|
"github.com/pomerium/pomerium/pkg/derivecert"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
|
@ -253,7 +254,7 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
jwksURL := authenticateURL.ResolveReference(&url.URL{
|
jwksURL := authenticateURL.ResolveReference(&url.URL{
|
||||||
Path: "/.well-known/pomerium/jwks.json",
|
Path: urlutil.HPKEPublicKeyPath,
|
||||||
}).String()
|
}).String()
|
||||||
return hpke.NewKeyFetcher(jwksURL, transport), nil
|
return hpke.NewKeyFetcher(jwksURL, transport), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
"github.com/pomerium/pomerium/internal/telemetry/requestid"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
|
func (srv *Server) addHTTPMiddleware(root *mux.Router, cfg *config.Config) {
|
||||||
|
@ -68,6 +69,7 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er
|
||||||
root.HandleFunc("/ping", handlers.HealthCheck)
|
root.HandleFunc("/ping", handlers.HealthCheck)
|
||||||
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL))
|
||||||
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL))
|
||||||
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey, hpkePublicKey))
|
root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey))
|
||||||
|
root.Path(urlutil.HPKEPublicKeyPath).Methods(http.MethodGet).Handler(handlers.HPKEPublicKeyHandler(hpkePublicKey))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
@ -77,14 +78,22 @@ func TestServerHTTP(t *testing.T) {
|
||||||
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
|
"x": "UG5xCP0JTT1H6Iol8jKuTIPVLM04CgW9PlEypNRmWlo",
|
||||||
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
|
"y": "KChF0fR09zm884ymInM29PtSsFdnzExNfLsP-ta1AgQ",
|
||||||
},
|
},
|
||||||
map[string]any{
|
|
||||||
"kty": "OKP",
|
|
||||||
"kid": "pomerium/hpke",
|
|
||||||
"crv": "X25519",
|
|
||||||
"x": "T0cbNrJbO9in-FgowKAP-HX6Ci8q50gopOt52sdheHg",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
})
|
})
|
||||||
|
t.Run("hpke-public-key", func(t *testing.T) {
|
||||||
|
res, err := http.Get(fmt.Sprintf("http://localhost:%s/.well-known/pomerium/hpke-public-key", src.GetConfig().HTTPPort))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer res.Body.Close()
|
||||||
|
|
||||||
|
bs, err := io.ReadAll(res.Body)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, []byte{
|
||||||
|
0x4f, 0x47, 0x1b, 0x36, 0xb2, 0x5b, 0x3b, 0xd8,
|
||||||
|
0xa7, 0xf8, 0x58, 0x28, 0xc0, 0xa0, 0x0f, 0xf8,
|
||||||
|
0x75, 0xfa, 0x0a, 0x2f, 0x2a, 0xe7, 0x48, 0x28,
|
||||||
|
0xa4, 0xeb, 0x79, 0xda, 0xc7, 0x61, 0x78, 0x78,
|
||||||
|
}, bs)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
33
internal/handlers/hpke_public_key.go
Normal file
33
internal/handlers/hpke_public_key.go
Normal file
|
@ -0,0 +1,33 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"hash/fnv"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rs/cors"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HPKEPublicKeyHandler returns a handler which returns the HPKE public key.
|
||||||
|
func HPKEPublicKeyHandler(publicKey *hpke.PublicKey) http.Handler {
|
||||||
|
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
bs := publicKey.Bytes()
|
||||||
|
|
||||||
|
hasher := fnv.New64()
|
||||||
|
_, _ = hasher.Write(bs)
|
||||||
|
h := hasher.Sum64()
|
||||||
|
|
||||||
|
w.Header().Set("Cache-Control", "max-age=60")
|
||||||
|
w.Header().Set("Content-Type", "application/octet-stream")
|
||||||
|
w.Header().Set("Content-Length", strconv.Itoa(len(bs)))
|
||||||
|
w.Header().Set("ETag", fmt.Sprintf(`"%x"`, h))
|
||||||
|
http.ServeContent(w, r, "hpke-public-key", time.Time{}, bytes.NewReader(bs))
|
||||||
|
return nil
|
||||||
|
}))
|
||||||
|
}
|
34
internal/handlers/hpke_public_key_test.go
Normal file
34
internal/handlers/hpke_public_key_test.go
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
package handlers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestHPKEPublicKeyHandler(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
k1 := hpke.DerivePrivateKey([]byte("TEST"))
|
||||||
|
|
||||||
|
t.Run("cors", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
|
r.Header.Set("Origin", "https://www.example.com")
|
||||||
|
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
|
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
|
||||||
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
|
})
|
||||||
|
t.Run("keys", func(t *testing.T) {
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
handlers.HPKEPublicKeyHandler(k1.PublicKey()).ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, k1.PublicKey().Bytes(), w.Body.Bytes())
|
||||||
|
})
|
||||||
|
}
|
|
@ -17,10 +17,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
// JWKSHandler returns the /.well-known/pomerium/jwks.json handler.
|
||||||
func JWKSHandler(
|
func JWKSHandler(signingKey []byte) http.Handler {
|
||||||
signingKey []byte,
|
|
||||||
additionalKeys ...any,
|
|
||||||
) http.Handler {
|
|
||||||
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
var jwks struct {
|
var jwks struct {
|
||||||
Keys []any `json:"keys"`
|
Keys []any `json:"keys"`
|
||||||
|
@ -34,7 +31,6 @@ func JWKSHandler(
|
||||||
jwks.Keys = append(jwks.Keys, *k)
|
jwks.Keys = append(jwks.Keys, *k)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
jwks.Keys = append(jwks.Keys, additionalKeys...)
|
|
||||||
|
|
||||||
bs, err := json.Marshal(jwks)
|
bs, err := json.Marshal(jwks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -16,7 +16,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/deterministicecdsa"
|
"github.com/pomerium/pomerium/internal/deterministicecdsa"
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWKSHandler(t *testing.T) {
|
func TestJWKSHandler(t *testing.T) {
|
||||||
|
@ -38,24 +37,18 @@ func TestJWKSHandler(t *testing.T) {
|
||||||
jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2)
|
jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
hpkePrivateKey, err := hpke.GeneratePrivateKey()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
t.Run("cors", func(t *testing.T) {
|
t.Run("cors", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
r := httptest.NewRequest(http.MethodOptions, "/", nil)
|
||||||
r.Header.Set("Origin", "https://www.example.com")
|
r.Header.Set("Origin", "https://www.example.com")
|
||||||
r.Header.Set("Access-Control-Request-Method", "GET")
|
r.Header.Set("Access-Control-Request-Method", "GET")
|
||||||
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
handlers.JWKSHandler(nil).ServeHTTP(w, r)
|
||||||
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
assert.Equal(t, http.StatusNoContent, w.Result().StatusCode)
|
||||||
})
|
})
|
||||||
t.Run("keys", func(t *testing.T) {
|
t.Run("keys", func(t *testing.T) {
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
handlers.JWKSHandler(
|
handlers.JWKSHandler(append(rawSigningKey1, rawSigningKey2...)).ServeHTTP(w, r)
|
||||||
append(rawSigningKey1, rawSigningKey2...),
|
|
||||||
hpkePrivateKey.PublicKey(),
|
|
||||||
).ServeHTTP(w, r)
|
|
||||||
|
|
||||||
var expect any = map[string]any{
|
var expect any = map[string]any{
|
||||||
"keys": []any{
|
"keys": []any{
|
||||||
|
@ -77,12 +70,6 @@ func TestJWKSHandler(t *testing.T) {
|
||||||
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()),
|
"x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()),
|
||||||
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()),
|
"y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()),
|
||||||
},
|
},
|
||||||
map[string]any{
|
|
||||||
"kty": "OKP",
|
|
||||||
"kid": "pomerium/hpke",
|
|
||||||
"crv": "X25519",
|
|
||||||
"x": hpkePrivateKey.PublicKey().String(),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
var actual any
|
var actual any
|
||||||
|
|
|
@ -13,6 +13,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// HPKEPublicKeyPath is the well-known path to the HPKE public key
|
||||||
|
const HPKEPublicKeyPath = "/.well-known/pomerium/hpke-public-key"
|
||||||
|
|
||||||
// DefaultDeviceType is the default device type when none is specified.
|
// DefaultDeviceType is the default device type when none is specified.
|
||||||
const DefaultDeviceType = "any"
|
const DefaultDeviceType = "any"
|
||||||
|
|
||||||
|
|
|
@ -97,6 +97,16 @@ type PublicKey struct {
|
||||||
key kem.PublicKey
|
key kem.PublicKey
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PublicKeyFromBytes converts raw bytes into a public key.
|
||||||
|
func PublicKeyFromBytes(raw []byte) (*PublicKey, error) {
|
||||||
|
key, err := kemID.Scheme().UnmarshalBinaryPublicKey(raw)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return &PublicKey{key: key}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// PublicKeyFromString converts a string into a public key.
|
// PublicKeyFromString converts a string into a public key.
|
||||||
func PublicKeyFromString(raw string) (*PublicKey, error) {
|
func PublicKeyFromString(raw string) (*PublicKey, error) {
|
||||||
bs, err := decode(raw)
|
bs, err := decode(raw)
|
||||||
|
@ -128,6 +138,20 @@ func (key *PublicKey) Equals(other *PublicKey) bool {
|
||||||
return key.key.Equal(other.key)
|
return key.key.Equal(other.key)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Bytes returns the public key as raw bytes.
|
||||||
|
func (key *PublicKey) Bytes() []byte {
|
||||||
|
if key == nil || key.key == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bs, err := key.key.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
// this should not happen
|
||||||
|
panic(fmt.Sprintf("failed to marshal public HPKE key: %v", err))
|
||||||
|
}
|
||||||
|
return bs
|
||||||
|
}
|
||||||
|
|
||||||
// MarshalJSON returns the JSON Web Key representation of the public key.
|
// MarshalJSON returns the JSON Web Key representation of the public key.
|
||||||
func (key *PublicKey) MarshalJSON() ([]byte, error) {
|
func (key *PublicKey) MarshalJSON() ([]byte, error) {
|
||||||
return json.Marshal(JWK{
|
return json.Marshal(JWK{
|
||||||
|
|
|
@ -2,7 +2,6 @@ package hpke
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -28,42 +27,28 @@ type JWK struct {
|
||||||
D string `json:"d,omitempty"`
|
D string `json:"d,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// FetchPublicKeyFromJWKS fetches the HPKE public key from the JWKS endpoint.
|
// FetchPublicKey fetches the HPKE public key from the hpke-public-key endpoint.
|
||||||
func FetchPublicKeyFromJWKS(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
|
func FetchPublicKey(ctx context.Context, client *http.Client, endpoint string) (*PublicKey, error) {
|
||||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("hpke: error building jwks http request: %w", err)
|
return nil, fmt.Errorf("hpke: error building hpke-public-key http request: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := client.Do(req)
|
res, err := client.Do(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint: %w", err)
|
return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint: %w", err)
|
||||||
}
|
}
|
||||||
defer res.Body.Close()
|
defer res.Body.Close()
|
||||||
|
|
||||||
if res.StatusCode/100 != 2 {
|
if res.StatusCode/100 != 2 {
|
||||||
return nil, fmt.Errorf("hpke: error requesting jwks endpoint, invalid status code: %d", res.StatusCode)
|
return nil, fmt.Errorf("hpke: error requesting hpke-public-key endpoint, invalid status code: %d", res.StatusCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
|
bs, err := io.ReadAll(io.LimitReader(res.Body, defaultMaxBodySize))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("hpke: error reading jwks endpoint: %w", err)
|
return nil, fmt.Errorf("hpke: error reading hpke-public-key endpoint: %w", err)
|
||||||
}
|
}
|
||||||
|
return PublicKeyFromBytes(bs)
|
||||||
var jwks struct {
|
|
||||||
Keys []JWK `json:"keys"`
|
|
||||||
}
|
|
||||||
err = json.Unmarshal(bs, &jwks)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("hpke: error unmarshaling jwks endpoint: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range jwks.Keys {
|
|
||||||
if key.ID == jwkID && key.Type == jwkType && key.Curve == jwkCurve {
|
|
||||||
return PublicKeyFromString(key.X)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("hpke key not found in JWKS endpoint")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// A KeyFetcher fetches public keys.
|
// A KeyFetcher fetches public keys.
|
||||||
|
@ -71,18 +56,18 @@ type KeyFetcher interface {
|
||||||
FetchPublicKey(ctx context.Context) (*PublicKey, error)
|
FetchPublicKey(ctx context.Context) (*PublicKey, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type jwksKeyFetcher struct {
|
type fetcher struct {
|
||||||
client *http.Client
|
client *http.Client
|
||||||
endpoint string
|
endpoint string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (fetcher *jwksKeyFetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
|
func (fetcher *fetcher) FetchPublicKey(ctx context.Context) (*PublicKey, error) {
|
||||||
return FetchPublicKeyFromJWKS(ctx, fetcher.client, fetcher.endpoint)
|
return FetchPublicKey(ctx, fetcher.client, fetcher.endpoint)
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
|
// NewKeyFetcher returns a new KeyFetcher which fetches keys using an in-memory HTTP cache.
|
||||||
func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher {
|
func NewKeyFetcher(endpoint string, transport http.RoundTripper) KeyFetcher {
|
||||||
return &jwksKeyFetcher{
|
return &fetcher{
|
||||||
client: (&httpcache.Transport{
|
client: (&httpcache.Transport{
|
||||||
Transport: transport,
|
Transport: transport,
|
||||||
Cache: httpcache.NewMemoryCache(),
|
Cache: httpcache.NewMemoryCache(),
|
|
@ -24,11 +24,11 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()).ServeHTTP(w, r)
|
||||||
}))
|
}))
|
||||||
t.Cleanup(srv.Close)
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
publicKey, err := hpke.FetchPublicKeyFromJWKS(ctx, http.DefaultClient, srv.URL)
|
publicKey, err := hpke.FetchPublicKey(ctx, http.DefaultClient, srv.URL)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
|
assert.Equal(t, hpkePrivateKey.PublicKey().String(), publicKey.String())
|
||||||
}
|
}
|
|
@ -30,13 +30,10 @@ func testOptions(t *testing.T) *config.Options {
|
||||||
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
opts.SharedKey = "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="
|
||||||
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
|
opts.CookieSecret = "OromP1gurwGWjQPYb1nNgSxtbVB5NnLzX6z5WOKr0Yw="
|
||||||
|
|
||||||
htpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
hpkePrivateKey, err := opts.GetHPKEPrivateKey()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
signingKey, err := opts.GetSigningKey()
|
authnSrv := httptest.NewServer(handlers.HPKEPublicKeyHandler(hpkePrivateKey.PublicKey()))
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey()))
|
|
||||||
t.Cleanup(authnSrv.Close)
|
t.Cleanup(authnSrv.Close)
|
||||||
opts.AuthenticateURLString = authnSrv.URL
|
opts.AuthenticateURLString = authnSrv.URL
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue