identity: add IdP access and identity token verification for OIDC (#5614)

## Summary
For the generic `oidc` provider, used by `auth0`, `cognito`, `gitlab`,
`google`, `oidc`, `okta`, `onelogin` and `ping`, add support for direct
access and identity token verification. Because Keycloak uses `oidc`
this also adds support for Keycloak.

Access tokens are verified by using the user info endpoint. If a call to
this endpoint succeeds using the access token, that access token is
considered valid and the user info claims will be returned.

Identity tokens are verified by using the jwks endpoint to retrieve the
signing key, and verifying that the identity token was signed with that
key. If the identity token is valid the claims in the JWT will be
returned.

## Related issues
-
[ENG-2312](https://linear.app/pomerium/issue/ENG-2312/core-implement-token-validation-for-keycloak)


## Checklist

- [x] reference any related issues
- [x] updated unit tests
- [x] add appropriate label (`enhancement`, `bug`, `breaking`,
`dependencies`, `ci`)
- [x] ready for review
This commit is contained in:
Caleb Doxsey 2025-05-12 13:45:25 -06:00 committed by GitHub
parent 93b8c93daa
commit f6b344fd9e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 177 additions and 40 deletions

View file

@ -16,7 +16,6 @@ import (
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/identity/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
)
@ -110,24 +109,3 @@ func TestVerifyAccessToken(t *testing.T) {
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
assert.ErrorContains(t, err, "invalid audience")
}
func TestVerifyIdentityToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
p, err := New(ctx, &oauth.Options{
ProviderName: Name,
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
assert.Nil(t, claims)
}

View file

@ -9,6 +9,7 @@ import (
"fmt"
"net/http"
"net/url"
"slices"
"sync"
go_oidc "github.com/coreos/go-oidc/v3/oidc"
@ -16,6 +17,7 @@ import (
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/pkg/identity/identity"
@ -49,6 +51,8 @@ type Provider struct {
// to the request flow signin url.
AuthCodeOptions map[string]string
accessTokenAllowedAudiences *[]string
mu sync.Mutex
provider *go_oidc.Provider
}
@ -94,6 +98,7 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e
return provider.Verifier(&go_oidc.Config{ClientID: o.ClientID})
}),
}, options...)...)
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
return p, nil
}
@ -362,11 +367,53 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()
if err != nil {
return nil, fmt.Errorf("error getting oauth provider: %w", err)
}
// use the access token to call the user info endpoint
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
TokenType: "Bearer",
AccessToken: rawAccessToken,
}))
if err != nil {
return nil, fmt.Errorf("error retrieving user info with access token: %w", err)
}
claims = jwtutil.Claims(map[string]any{})
err = userInfo.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
}
if p.accessTokenAllowedAudiences != nil {
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
}
}
return claims, nil
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
func (p *Provider) VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error) {
verifier, err := p.GetVerifier()
if err != nil {
return nil, fmt.Errorf("error getting verifier: %w", err)
}
identityToken, err := verifier.Verify(ctx, rawIdentityToken)
if err != nil {
return nil, fmt.Errorf("error verifying identity token: %w", err)
}
claims = jwtutil.Claims(map[string]any{})
err = identityToken.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling identity token claims: %w", err)
}
return claims, nil
}

View file

@ -1,4 +1,4 @@
package oidc
package oidc_test
import (
"context"
@ -18,7 +18,10 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/identity/oidc"
)
// Claims implements identity.State. (We can't use identity.Claims directly
@ -59,7 +62,7 @@ func TestSignIn(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -118,7 +121,7 @@ func TestSignOut(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -219,7 +222,7 @@ func TestAuthenticate(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -313,7 +316,7 @@ func TestRefresh_WithIDToken(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -384,7 +387,7 @@ func TestRefresh_WithoutIDToken(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -439,7 +442,7 @@ func TestRevoke(t *testing.T) {
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -452,7 +455,7 @@ func TestRevoke(t *testing.T) {
AccessToken: "ACCESS_TOKEN",
}))
assert.Equal(t, ErrMissingAccessToken, p.Revoke(ctx, nil))
assert.Equal(t, oidc.ErrMissingAccessToken, p.Revoke(ctx, nil))
}
func TestUnsupportedFeatures(t *testing.T) {
@ -479,7 +482,7 @@ func TestUnsupportedFeatures(t *testing.T) {
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
p, err := New(ctx, &oauth.Options{
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
@ -490,19 +493,19 @@ func TestUnsupportedFeatures(t *testing.T) {
rec := httptest.NewRecorder()
err = p.SignOut(rec, httptest.NewRequest(http.MethodGet, "/", nil), "ID_TOKEN", "", "")
assert.Equal(t, ErrSignoutNotImplemented, err)
assert.Equal(t, oidc.ErrSignoutNotImplemented, err)
err = p.Revoke(ctx, &oauth2.Token{
AccessToken: "ACCESS_TOKEN",
})
assert.Equal(t, ErrRevokeNotImplemented, err)
assert.Equal(t, oidc.ErrRevokeNotImplemented, err)
_, err = New(ctx, &oauth.Options{})
assert.Equal(t, ErrMissingProviderURL, err)
_, err = oidc.New(ctx, &oauth.Options{})
assert.Equal(t, oidc.ErrMissingProviderURL, err)
}
func TestName(t *testing.T) {
assert.Equal(t, "oidc", (*Provider)(nil).Name())
assert.Equal(t, "oidc", (*oidc.Provider)(nil).Name())
}
// setupJWTSigning returns a JWT signer and a corresponding JWKS for signature verification.
@ -523,3 +526,112 @@ func setupJWTSigning(t *testing.T) (jose.Signer, jose.JSONWebKeySet) {
}
return jwtSigner, jwks
}
func TestVerifyAccessToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
var srv *httptest.Server
m := http.NewServeMux()
m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
baseURL, err := url.Parse(srv.URL)
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"userinfo_endpoint": baseURL.ResolveReference(&url.URL{
Path: "/userinfo",
}).String(),
})
})
m.HandleFunc("GET /userinfo", func(w http.ResponseWriter, r *http.Request) {
assert.Equal(t, "Bearer ACCESS_TOKEN", r.Header.Get("Authorization"))
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]any{
"aud": "AUDIENCE",
"sub": "SUBJECT",
})
})
srv = httptest.NewServer(m)
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"),
})
require.NoError(t, err)
claims, err := p.VerifyAccessToken(ctx, "ACCESS_TOKEN")
require.NoError(t, err)
assert.Equal(t, map[string]any{
"aud": "AUDIENCE",
"sub": "SUBJECT",
}, claims)
}
func TestVerifyIdentityToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
require.NoError(t, err)
iat := time.Now().Unix()
exp := iat + 3600
var srv *httptest.Server
m := http.NewServeMux()
m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
baseURL, err := url.Parse(srv.URL)
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"jwks_uri": baseURL.ResolveReference(&url.URL{
Path: "/jwks",
}).String(),
})
})
m.HandleFunc("GET /jwks", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
},
})
})
srv = httptest.NewServer(m)
rawIdentityToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
"iss": srv.URL,
"aud": "CLIENT_ID",
"sub": "subject",
"exp": exp,
"iat": iat,
}).CompactSerialize()
require.NoError(t, err)
p, err := oidc.New(ctx, &oauth.Options{
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"),
})
require.NoError(t, err)
claims, err := p.VerifyIdentityToken(ctx, rawIdentityToken1)
require.NoError(t, err)
delete(claims, "iat")
delete(claims, "exp")
assert.Equal(t, map[string]any{
"aud": "CLIENT_ID",
"iss": srv.URL,
"sub": "subject",
}, claims)
}