diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 124e8c2a4..5583b645a 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -89,10 +89,7 @@ func TestNew(t *testing.T) { goodSigningKey.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" badSigningKey := newTestOptions(t) - badSigningKey.SigningKey = "%" - - badSigninKeyPublic := newTestOptions(t) - badSigninKeyPublic.SigningKey = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUJFakNCdWdJSkFNWUdtVzhpYWd1TU1Bb0dDQ3FHU000OUJBTUNNQkV4RHpBTkJnTlZCQU1NQm5WdWRYTmwKWkRBZ0Z3MHlNREExTWpJeU1EUTFNalJhR0E4ME56VTRNRFF4T1RJd05EVXlORm93RVRFUE1BMEdBMVVFQXd3RwpkVzUxYzJWa01Ga3dFd1lIS29aSXpqMENBUVlJS29aSXpqMERBUWNEUWdBRVVHNXhDUDBKVFQxSDZJb2w4akt1ClRJUFZMTTA0Q2dXOVBsRXlwTlJtV2xvb0tFWFI5SFQzT2J6empLWWljemIwKzFLd1YyZk1URTE4dXcvNjFyVUMKQkRBS0JnZ3Foa2pPUFFRREFnTkhBREJFQWlBSFFDUFh2WG5oeHlDTGNhZ3N3eWt4RUM1NFV5RmdyUVJVRmVCYwpPUzVCSFFJZ1Y3T2FXY2pMeHdsRlIrWDZTQ2daZDI5bXBtOVZKNnpXQURhWGdEN3FURW89Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K" + badSigningKey.SigningKey = "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tCk1JSUJFakNCdWdJSkFNWUdtVzhpYWd1TU1Bb0dDQ3FHU000OUJBTUNNQkV4RHpBTkJnTlZCQU1NQm5WdWRYTmwKWkRBZ0Z3MHlNREExTWpJeU1EUTFNalJhR0E4ME56VTRNRFF4T1RJd05EVXlORm93RVRFUE1BMEdBMVVFQXd3RwpkVzUxYzJWa01Ga3dFd1lIS29aSXpqMENBUVlJS29aSXpqMERBUWNEUWdBRVVHNXhDUDBKVFQxSDZJb2w4akt1ClRJUFZMTTA0Q2dXOVBsRXlwTlJtV2xvb0tFWFI5SFQzT2J6empLWWljemIwKzFLd1YyZk1URTE4dXcvNjFyVUMKQkRBS0JnZ3Foa2pPUFFRREFnTkhBREJFQWlBSFFDUFh2WG5oeHlDTGNhZ3N3eWt4RUM1NFV5RmdyUVJVRmVCYwpPUzVCSFFJZ1Y3T2FXY2pMeHdsRlIrWDZTQ2daZDI5bXBtOVZKNnpXQURhWGdEN3FURW89Ci0tLS0tRU5EIENFUlRJRklDQVRFLS0tLS0K" tests := []struct { name string @@ -105,7 +102,6 @@ func TestNew(t *testing.T) { {"fails to validate", badRedirectURL, true}, {"good signing key", goodSigningKey, false}, {"bad signing key", badSigningKey, true}, - {"bad public signing key", badSigninKeyPublic, true}, } for _, tt := range tests { tt := tt diff --git a/authenticate/state.go b/authenticate/state.go index 8813aba65..200e68976 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -2,7 +2,6 @@ package authenticate import ( "crypto/cipher" - "encoding/base64" "fmt" "net/url" @@ -115,16 +114,14 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err if err != nil { return nil, err } - if signingKey != "" { - decodedCert, err := base64.StdEncoding.DecodeString(cfg.Options.SigningKey) - if err != nil { - return nil, fmt.Errorf("authenticate: failed to decode signing key: %w", err) - } - jwk, err := cryptutil.PublicJWKFromBytes(decodedCert) + if len(signingKey) > 0 { + ks, err := cryptutil.PublicJWKsFromBytes(signingKey) if err != nil { return nil, fmt.Errorf("authenticate: failed to convert jwks: %w", err) } - state.jwk.Keys = append(state.jwk.Keys, *jwk) + for _, k := range ks { + state.jwk.Keys = append(state.jwk.Keys, *k) + } } sharedKey, err := cfg.Options.GetSharedKey() diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 0b9b0fa85..4f9e93695 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -33,7 +33,10 @@ func TestAuthorize_handleResult(t *testing.T) { htpkePrivateKey, err := opt.GetHPKEPrivateKey() require.NoError(t, err) - authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey())) + signingKey, err := opt.GetSigningKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL @@ -198,7 +201,10 @@ func TestRequireLogin(t *testing.T) { htpkePrivateKey, err := opt.GetHPKEPrivateKey() require.NoError(t, err) - authnSrv := httptest.NewServer(handlers.JWKSHandler(opt.SigningKey, htpkePrivateKey.PublicKey())) + signingKey, err := opt.GetSigningKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opt.AuthenticateURLString = authnSrv.URL diff --git a/authorize/evaluator/config.go b/authorize/evaluator/config.go index 4230a6c26..9c7ecab2d 100644 --- a/authorize/evaluator/config.go +++ b/authorize/evaluator/config.go @@ -7,7 +7,7 @@ import ( type evaluatorConfig struct { policies []config.Policy clientCA []byte - signingKey string + signingKey []byte authenticateURL string googleCloudServerlessAuthenticationServiceAccount string jwtClaimsHeaders config.JWTClaimHeaders @@ -39,7 +39,7 @@ func WithClientCA(clientCA []byte) Option { } // WithSigningKey sets the signing key and algorithm in the config. -func WithSigningKey(signingKey string) Option { +func WithSigningKey(signingKey []byte) Option { return func(cfg *evaluatorConfig) { cfg.signingKey = signingKey } diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 6d3d4d6e3..36dd8d453 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -223,7 +223,7 @@ func (e *Evaluator) updateStore(cfg *evaluatorConfig) error { func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) { var decodedCert []byte // if we don't have a signing key, generate one - if cfg.signingKey == "" { + if len(cfg.signingKey) == 0 { key, err := cryptutil.NewSigningKey() if err != nil { return nil, fmt.Errorf("couldn't generate signing key: %w", err) @@ -233,11 +233,7 @@ func getJWK(cfg *evaluatorConfig) (*jose.JSONWebKey, error) { return nil, fmt.Errorf("bad signing key: %w", err) } } else { - var err error - decodedCert, err = base64.StdEncoding.DecodeString(cfg.signingKey) - if err != nil { - return nil, fmt.Errorf("bad signing key: %w", err) - } + decodedCert = cfg.signingKey } jwk, err := cryptutil.PrivateJWKFromBytes(decodedCert) diff --git a/config/options.go b/config/options.go index 33442fa74..d912e18b2 100644 --- a/config/options.go +++ b/config/options.go @@ -1176,18 +1176,27 @@ func (o *Options) GetCookieSecret() ([]byte, error) { } // GetSigningKey gets the signing key. -func (o *Options) GetSigningKey() (string, error) { +func (o *Options) GetSigningKey() ([]byte, error) { if o == nil { - return "", nil + return nil, nil } + + rawSigningKey := o.SigningKey if o.SigningKeyFile != "" { bs, err := os.ReadFile(o.SigningKeyFile) if err != nil { - return "", err + return nil, err } - return string(bs), nil + rawSigningKey = string(bs) } - return o.SigningKey, nil + + rawSigningKey = strings.TrimSpace(rawSigningKey) + + if bs, err := base64.StdEncoding.DecodeString(rawSigningKey); err == nil { + return bs, nil + } + + return []byte(rawSigningKey), nil } // Checksum returns the checksum of the current options struct diff --git a/config/options_test.go b/config/options_test.go index f0db8d017..37f6d2e41 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -8,6 +8,7 @@ import ( "fmt" "net/url" "os" + "path/filepath" "sync" "testing" "time" @@ -780,6 +781,110 @@ func TestOptions_GetSharedKey(t *testing.T) { }) } +func TestOptions_GetSigningKey(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + input string + output []byte + err error + }{ + {"missing", "", []byte{}, nil}, + {"pem", ` +-----BEGIN EC PRIVATE KEY----- +MHQCAQEEIGGh6FlBe8yy9dRJgm+35lj3naGFtDODOf6leCW1bRGwoAcGBSuBBAAK +oUQDQgAE7UlKcFatc9m3GinCrhhT2oRQZ/bEwS98iEUXr0DR8GdxH3e4fhnicsNB +jHOCur7NYTgf5VaPJwIqLGBmTwM0ew== +-----END EC PRIVATE KEY----- + +-----BEGIN EC PRIVATE KEY----- +MHQCAQEEIBo4wSjkFqQrzf2APNnPol8EDZzkhpcMSaEWXg8iOkbOoAcGBSuBBAAK +oUQDQgAEr+bGqssRv8RxPV2jJbDpMw81AVXr5+Q2pIF4u6xD9r56lst8uHYThPsw +ypaqswFIkSzQSW8awdWJ5d+1DEJRUQ== +-----END EC PRIVATE KEY----- + `, []byte{ + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, + 0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65, + 0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e, + 0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62, + 0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a, + 0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74, + 0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51, + 0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52, + 0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42, + 0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61, + 0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d, + 0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, + 0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46, + 0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44, + 0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f, + 0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a, + 0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52, + 0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31, + 0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44, + 0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77, + 0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38, + 0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d, + 0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, + }, nil}, + {"base64", ` +LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IUUNBUUVFSUdHaDZGbEJlOHl5OWRSSmdtKzM1bGozbmFHRnRET0RPZjZsZUNXMWJSR3dvQWNHQlN1QkJBQUsKb1VRRFFnQUU3VWxLY0ZhdGM5bTNHaW5DcmhoVDJvUlFaL2JFd1M5OGlFVVhyMERSOEdkeEgzZTRmaG5pY3NOQgpqSE9DdXI3TllUZ2Y1VmFQSndJcUxHQm1Ud00wZXc9PQotLS0tLUVORCBFQyBQUklWQVRFIEtFWS0tLS0tCgotLS0tLUJFR0lOIEVDIFBSSVZBVEUgS0VZLS0tLS0KTUhRQ0FRRUVJQm80d1Nqa0ZxUXJ6ZjJBUE5uUG9sOEVEWnpraHBjTVNhRVdYZzhpT2tiT29BY0dCU3VCQkFBSwpvVVFEUWdBRXIrYkdxc3NSdjhSeFBWMmpKYkRwTXc4MUFWWHI1K1EycElGNHU2eEQ5cjU2bHN0OHVIWVRoUHN3CnlwYXFzd0ZJa1N6UVNXOGF3ZFdKNWQrMURFSlJVUT09Ci0tLS0tRU5EIEVDIFBSSVZBVEUgS0VZLS0tLS0= + `, []byte{ + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, + 0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65, + 0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e, + 0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62, + 0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a, + 0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74, + 0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51, + 0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52, + 0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42, + 0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61, + 0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d, + 0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a, + 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d, + 0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46, + 0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44, + 0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f, + 0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a, + 0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52, + 0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31, + 0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44, + 0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77, + 0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38, + 0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d, + 0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52, + 0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, + }, nil}, + } { + tc := tc + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + output, err := (&Options{SigningKey: tc.input}).GetSigningKey() + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.output, output) + + dir := t.TempDir() + err = os.WriteFile(filepath.Join(dir, "cert"), []byte(tc.input), 0o0666) + assert.NoError(t, err) + + output, err = (&Options{SigningKeyFile: filepath.Join(dir, "cert")}).GetSigningKey() + assert.Equal(t, tc.err, err) + assert.Equal(t, tc.output, output) + }) + } +} + func TestOptions_GetCookieSecret(t *testing.T) { t.Run("default", func(t *testing.T) { o := NewDefaultOptions() diff --git a/internal/controlplane/http.go b/internal/controlplane/http.go index 90c693509..85ce1b4a0 100644 --- a/internal/controlplane/http.go +++ b/internal/controlplane/http.go @@ -53,7 +53,7 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er return fmt.Errorf("invalid authenticate URL: %w", err) } - rawSigningKey, err := cfg.Options.GetSigningKey() + signingKey, err := cfg.Options.GetSigningKey() if err != nil { return fmt.Errorf("invalid signing key: %w", err) } @@ -68,6 +68,6 @@ func (srv *Server) mountCommonEndpoints(root *mux.Router, cfg *config.Config) er root.HandleFunc("/ping", handlers.HealthCheck) root.Handle("/.well-known/pomerium", handlers.WellKnownPomerium(authenticateURL)) root.Handle("/.well-known/pomerium/", handlers.WellKnownPomerium(authenticateURL)) - root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(rawSigningKey, hpkePublicKey)) + root.Path("/.well-known/pomerium/jwks.json").Methods(http.MethodGet).Handler(handlers.JWKSHandler(signingKey, hpkePublicKey)) return nil } diff --git a/internal/handlers/jwks.go b/internal/handlers/jwks.go index 149a73b4f..68118b13f 100644 --- a/internal/handlers/jwks.go +++ b/internal/handlers/jwks.go @@ -2,7 +2,6 @@ package handlers import ( "bytes" - "encoding/base64" "encoding/json" "errors" "fmt" @@ -19,23 +18,21 @@ import ( // JWKSHandler returns the /.well-known/pomerium/jwks.json handler. func JWKSHandler( - rawSigningKey string, + signingKey []byte, additionalKeys ...any, ) http.Handler { return cors.AllowAll().Handler(httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { var jwks struct { Keys []any `json:"keys"` } - if rawSigningKey != "" { - decodedCert, err := base64.StdEncoding.DecodeString(rawSigningKey) - if err != nil { - return httputil.NewError(http.StatusInternalServerError, errors.New("bad base64 encoding for signing key")) - } - jwk, err := cryptutil.PublicJWKFromBytes(decodedCert) + if len(signingKey) > 0 { + ks, err := cryptutil.PublicJWKsFromBytes(signingKey) if err != nil { return httputil.NewError(http.StatusInternalServerError, errors.New("bad signing key")) } - jwks.Keys = append(jwks.Keys, *jwk) + for _, k := range ks { + jwks.Keys = append(jwks.Keys, *k) + } } jwks.Keys = append(jwks.Keys, additionalKeys...) diff --git a/internal/handlers/jwks_test.go b/internal/handlers/jwks_test.go index 335cd77eb..d05b0f78b 100644 --- a/internal/handlers/jwks_test.go +++ b/internal/handlers/jwks_test.go @@ -19,13 +19,19 @@ import ( func TestJWKSHandler(t *testing.T) { t.Parallel() - signingKey, err := cryptutil.NewSigningKey() + signingKey1, err := cryptutil.NewSigningKey() + require.NoError(t, err) + signingKey2, err := cryptutil.NewSigningKey() require.NoError(t, err) - rawSigningKey, err := cryptutil.EncodePrivateKey(signingKey) + rawSigningKey1, err := cryptutil.EncodePrivateKey(signingKey1) + require.NoError(t, err) + rawSigningKey2, err := cryptutil.EncodePrivateKey(signingKey2) require.NoError(t, err) - jwkSigningKey, err := cryptutil.PublicJWKFromBytes(rawSigningKey) + jwkSigningKey1, err := cryptutil.PublicJWKFromBytes(rawSigningKey1) + require.NoError(t, err) + jwkSigningKey2, err := cryptutil.PublicJWKFromBytes(rawSigningKey2) require.NoError(t, err) hpkePrivateKey, err := hpke.GeneratePrivateKey() @@ -36,24 +42,36 @@ func TestJWKSHandler(t *testing.T) { r := httptest.NewRequest(http.MethodOptions, "/", nil) r.Header.Set("Origin", "https://www.example.com") r.Header.Set("Access-Control-Request-Method", "GET") - handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) assert.Equal(t, http.StatusNoContent, w.Result().StatusCode) }) t.Run("keys", func(t *testing.T) { w := httptest.NewRecorder() r := httptest.NewRequest(http.MethodGet, "/", nil) - handlers.JWKSHandler(base64.StdEncoding.EncodeToString(rawSigningKey), hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + handlers.JWKSHandler( + append(rawSigningKey1, rawSigningKey2...), + hpkePrivateKey.PublicKey(), + ).ServeHTTP(w, r) var expect any = map[string]any{ "keys": []any{ map[string]any{ "kty": "EC", - "kid": jwkSigningKey.KeyID, + "kid": jwkSigningKey1.KeyID, "crv": "P-256", "alg": "ES256", "use": "sig", - "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).X.Bytes()), - "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey.Key.(*ecdsa.PublicKey).Y.Bytes()), + "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey1.Key.(*ecdsa.PublicKey).X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey1.Key.(*ecdsa.PublicKey).Y.Bytes()), + }, + map[string]any{ + "kty": "EC", + "kid": jwkSigningKey2.KeyID, + "crv": "P-256", + "alg": "ES256", + "use": "sig", + "x": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).X.Bytes()), + "y": base64.RawURLEncoding.EncodeToString(jwkSigningKey2.Key.(*ecdsa.PublicKey).Y.Bytes()), }, map[string]any{ "kty": "OKP", diff --git a/pkg/cryptutil/jose.go b/pkg/cryptutil/jose.go index 8d096d84c..3334050d7 100644 --- a/pkg/cryptutil/jose.go +++ b/pkg/cryptutil/jose.go @@ -15,52 +15,82 @@ import ( // PrivateJWKFromBytes returns a jose JSON Web _Private_ Key from bytes. func PrivateJWKFromBytes(data []byte) (*jose.JSONWebKey, error) { - return loadKey(data, loadPrivateKey) + jwks, err := loadKeys(data, loadPrivateKey) + if err != nil { + return nil, err + } else if len(jwks) == 0 { + return nil, fmt.Errorf("invalid pem data") + } + return jwks[0], nil +} + +// PrivateJWKsFromBytes returns jose JSON Web _Private_ Keys from bytes. +func PrivateJWKsFromBytes(data []byte) ([]*jose.JSONWebKey, error) { + return loadKeys(data, loadPrivateKey) } // PublicJWKFromBytes returns a jose JSON Web _Public_ Key from bytes. func PublicJWKFromBytes(data []byte) (*jose.JSONWebKey, error) { - return loadKey(data, loadPublicKey) -} - -func loadKey(data []byte, unmarshal func([]byte) (interface{}, error)) (*jose.JSONWebKey, error) { - block, _ := pem.Decode(data) - if block == nil { - return nil, fmt.Errorf("file contained no PEM encoded data") - } - priv, err := unmarshal(block.Bytes) - if err != nil { - return nil, fmt.Errorf("unmarshal key: %w", err) - } - alg, err := SignatureAlgorithmForKey(priv) + jwks, err := loadKeys(data, loadPublicKey) if err != nil { return nil, err + } else if len(jwks) == 0 { + return nil, fmt.Errorf("invalid pem data") } + return jwks[0], nil +} - key := &jose.JSONWebKey{Key: priv, Use: "sig", Algorithm: string(alg)} - thumbprint, err := key.Thumbprint(crypto.SHA256) - if err != nil { - return nil, fmt.Errorf("computing thumbprint: %w", err) +// PublicJWKsFromBytes returns jose JSON Web _Public_ Keys from bytes. +func PublicJWKsFromBytes(data []byte) ([]*jose.JSONWebKey, error) { + return loadKeys(data, loadPublicKey) +} + +func loadKeys(data []byte, unmarshal func([]byte) (any, error)) ([]*jose.JSONWebKey, error) { + var jwks []*jose.JSONWebKey + for { + var block *pem.Block + block, data = pem.Decode(data) + if block == nil { + break + } + + key, err := unmarshal(block.Bytes) + if err != nil { + return nil, fmt.Errorf("unmarshal key: %w", err) + } + + alg, err := SignatureAlgorithmForKey(key) + if err != nil { + return nil, err + } + + jwk := &jose.JSONWebKey{Key: key, Use: "sig", Algorithm: string(alg)} + thumbprint, err := jwk.Thumbprint(crypto.SHA256) + if err != nil { + return nil, fmt.Errorf("computing thumbprint: %w", err) + } + jwk.KeyID = hex.EncodeToString(thumbprint) + jwks = append(jwks, jwk) } - key.KeyID = hex.EncodeToString(thumbprint) - return key, nil + return jwks, nil } func loadPrivateKey(b []byte) (interface{}, error) { var wrappedErr error var err error + var key any - if key, err := x509.ParseECPrivateKey(b); err == nil { + if key, err = x509.ParseECPrivateKey(b); err == nil { return key, nil } wrappedErr = multierror.Append(wrappedErr, err) - if key, err := x509.ParsePKCS1PrivateKey(b); err == nil { + if key, err = x509.ParsePKCS1PrivateKey(b); err == nil { return key, nil } wrappedErr = multierror.Append(wrappedErr, err) - if key, err := x509.ParsePKCS8PrivateKey(b); err == nil { + if key, err = x509.ParsePKCS8PrivateKey(b); err == nil { return key, nil } wrappedErr = multierror.Append(wrappedErr, err) @@ -72,8 +102,9 @@ func loadPrivateKey(b []byte) (interface{}, error) { func loadPublicKey(b []byte) (interface{}, error) { var wrappedErr error var err error + var key any - if key, err := loadPrivateKey(b); err == nil { + if key, err = loadPrivateKey(b); err == nil { switch k := key.(type) { case *rsa.PrivateKey: return k.Public(), nil @@ -85,12 +116,12 @@ func loadPublicKey(b []byte) (interface{}, error) { } wrappedErr = multierror.Append(wrappedErr, err) - if key, err := x509.ParsePKIXPublicKey(b); err == nil { + if key, err = x509.ParsePKIXPublicKey(b); err == nil { return key, nil } wrappedErr = multierror.Append(wrappedErr, err) - if key, err := x509.ParseCertificate(b); err == nil { + if key, err = x509.ParseCertificate(b); err == nil { return key, nil } wrappedErr = multierror.Append(wrappedErr, err) diff --git a/pkg/hpke/jwks_test.go b/pkg/hpke/jwks_test.go index 82dac60d3..0bcf25d27 100644 --- a/pkg/hpke/jwks_test.go +++ b/pkg/hpke/jwks_test.go @@ -24,7 +24,7 @@ func TestFetchPublicKeyFromJWKS(t *testing.T) { require.NoError(t, err) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - handlers.JWKSHandler("", hpkePrivateKey.PublicKey()).ServeHTTP(w, r) + handlers.JWKSHandler(nil, hpkePrivateKey.PublicKey()).ServeHTTP(w, r) })) t.Cleanup(srv.Close) diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 5c2e5d86e..20057ae53 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -33,7 +33,10 @@ func testOptions(t *testing.T) *config.Options { htpkePrivateKey, err := opts.GetHPKEPrivateKey() require.NoError(t, err) - authnSrv := httptest.NewServer(handlers.JWKSHandler(opts.SigningKey, htpkePrivateKey.PublicKey())) + signingKey, err := opts.GetSigningKey() + require.NoError(t, err) + + authnSrv := httptest.NewServer(handlers.JWKSHandler(signingKey, htpkePrivateKey.PublicKey())) t.Cleanup(authnSrv.Close) opts.AuthenticateURLString = authnSrv.URL