mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 04:42:56 +02:00
identity: add IdP access and identity token verification for OIDC (#5614)
## Summary For the generic `oidc` provider, used by `auth0`, `cognito`, `gitlab`, `google`, `oidc`, `okta`, `onelogin` and `ping`, add support for direct access and identity token verification. Because Keycloak uses `oidc` this also adds support for Keycloak. Access tokens are verified by using the user info endpoint. If a call to this endpoint succeeds using the access token, that access token is considered valid and the user info claims will be returned. Identity tokens are verified by using the jwks endpoint to retrieve the signing key, and verifying that the identity token was signed with that key. If the identity token is valid the claims in the JWT will be returned. ## Related issues - [ENG-2312](https://linear.app/pomerium/issue/ENG-2312/core-implement-token-validation-for-keycloak) ## Checklist - [x] reference any related issues - [x] updated unit tests - [x] add appropriate label (`enhancement`, `bug`, `breaking`, `dependencies`, `ci`) - [x] ready for review
This commit is contained in:
parent
93b8c93daa
commit
f6b344fd9e
3 changed files with 177 additions and 40 deletions
|
@ -16,7 +16,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
)
|
||||
|
||||
|
@ -110,24 +109,3 @@ func TestVerifyAccessToken(t *testing.T) {
|
|||
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
|
||||
assert.ErrorContains(t, err, "invalid audience")
|
||||
}
|
||||
|
||||
func TestVerifyIdentityToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
mux := http.NewServeMux()
|
||||
srv := httptest.NewServer(mux)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderName: Name,
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
|
||||
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
|
||||
assert.Nil(t, claims)
|
||||
}
|
||||
|
|
|
@ -9,6 +9,7 @@ import (
|
|||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"slices"
|
||||
"sync"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
|
@ -16,6 +17,7 @@ import (
|
|||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
"github.com/pomerium/pomerium/pkg/identity/identity"
|
||||
|
@ -49,6 +51,8 @@ type Provider struct {
|
|||
// to the request flow signin url.
|
||||
AuthCodeOptions map[string]string
|
||||
|
||||
accessTokenAllowedAudiences *[]string
|
||||
|
||||
mu sync.Mutex
|
||||
provider *go_oidc.Provider
|
||||
}
|
||||
|
@ -94,6 +98,7 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e
|
|||
return provider.Verifier(&go_oidc.Config{ClientID: o.ClientID})
|
||||
}),
|
||||
}, options...)...)
|
||||
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
|
||||
return p, nil
|
||||
}
|
||||
|
||||
|
@ -362,11 +367,53 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
|
|||
}
|
||||
|
||||
// VerifyAccessToken verifies an access token.
|
||||
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyAccessTokenNotSupported
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting oauth provider: %w", err)
|
||||
}
|
||||
|
||||
// use the access token to call the user info endpoint
|
||||
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error retrieving user info with access token: %w", err)
|
||||
}
|
||||
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = userInfo.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
}
|
||||
|
||||
if p.accessTokenAllowedAudiences != nil {
|
||||
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
|
||||
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// VerifyIdentityToken verifies an identity token.
|
||||
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
|
||||
return nil, identity.ErrVerifyIdentityTokenNotSupported
|
||||
func (p *Provider) VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error) {
|
||||
verifier, err := p.GetVerifier()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting verifier: %w", err)
|
||||
}
|
||||
|
||||
identityToken, err := verifier.Verify(ctx, rawIdentityToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying identity token: %w", err)
|
||||
}
|
||||
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = identityToken.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling identity token claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
package oidc
|
||||
package oidc_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
@ -18,7 +18,10 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||
)
|
||||
|
||||
// Claims implements identity.State. (We can't use identity.Claims directly
|
||||
|
@ -59,7 +62,7 @@ func TestSignIn(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -118,7 +121,7 @@ func TestSignOut(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -219,7 +222,7 @@ func TestAuthenticate(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -313,7 +316,7 @@ func TestRefresh_WithIDToken(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -384,7 +387,7 @@ func TestRefresh_WithoutIDToken(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -439,7 +442,7 @@ func TestRevoke(t *testing.T) {
|
|||
redirectURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -452,7 +455,7 @@ func TestRevoke(t *testing.T) {
|
|||
AccessToken: "ACCESS_TOKEN",
|
||||
}))
|
||||
|
||||
assert.Equal(t, ErrMissingAccessToken, p.Revoke(ctx, nil))
|
||||
assert.Equal(t, oidc.ErrMissingAccessToken, p.Revoke(ctx, nil))
|
||||
}
|
||||
|
||||
func TestUnsupportedFeatures(t *testing.T) {
|
||||
|
@ -479,7 +482,7 @@ func TestUnsupportedFeatures(t *testing.T) {
|
|||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
|
@ -490,19 +493,19 @@ func TestUnsupportedFeatures(t *testing.T) {
|
|||
|
||||
rec := httptest.NewRecorder()
|
||||
err = p.SignOut(rec, httptest.NewRequest(http.MethodGet, "/", nil), "ID_TOKEN", "", "")
|
||||
assert.Equal(t, ErrSignoutNotImplemented, err)
|
||||
assert.Equal(t, oidc.ErrSignoutNotImplemented, err)
|
||||
|
||||
err = p.Revoke(ctx, &oauth2.Token{
|
||||
AccessToken: "ACCESS_TOKEN",
|
||||
})
|
||||
assert.Equal(t, ErrRevokeNotImplemented, err)
|
||||
assert.Equal(t, oidc.ErrRevokeNotImplemented, err)
|
||||
|
||||
_, err = New(ctx, &oauth.Options{})
|
||||
assert.Equal(t, ErrMissingProviderURL, err)
|
||||
_, err = oidc.New(ctx, &oauth.Options{})
|
||||
assert.Equal(t, oidc.ErrMissingProviderURL, err)
|
||||
}
|
||||
|
||||
func TestName(t *testing.T) {
|
||||
assert.Equal(t, "oidc", (*Provider)(nil).Name())
|
||||
assert.Equal(t, "oidc", (*oidc.Provider)(nil).Name())
|
||||
}
|
||||
|
||||
// setupJWTSigning returns a JWT signer and a corresponding JWKS for signature verification.
|
||||
|
@ -523,3 +526,112 @@ func setupJWTSigning(t *testing.T) (jose.Signer, jose.JSONWebKeySet) {
|
|||
}
|
||||
return jwtSigner, jwks
|
||||
}
|
||||
|
||||
func TestVerifyAccessToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
var srv *httptest.Server
|
||||
m := http.NewServeMux()
|
||||
m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
|
||||
baseURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": baseURL.String(),
|
||||
"userinfo_endpoint": baseURL.ResolveReference(&url.URL{
|
||||
Path: "/userinfo",
|
||||
}).String(),
|
||||
})
|
||||
})
|
||||
m.HandleFunc("GET /userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
assert.Equal(t, "Bearer ACCESS_TOKEN", r.Header.Get("Authorization"))
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"aud": "AUDIENCE",
|
||||
"sub": "SUBJECT",
|
||||
})
|
||||
})
|
||||
srv = httptest.NewServer(m)
|
||||
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyAccessToken(ctx, "ACCESS_TOKEN")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, map[string]any{
|
||||
"aud": "AUDIENCE",
|
||||
"sub": "SUBJECT",
|
||||
}, claims)
|
||||
}
|
||||
|
||||
func TestVerifyIdentityToken(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := testutil.GetContext(t, time.Minute)
|
||||
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
|
||||
require.NoError(t, err)
|
||||
iat := time.Now().Unix()
|
||||
exp := iat + 3600
|
||||
|
||||
var srv *httptest.Server
|
||||
m := http.NewServeMux()
|
||||
m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
|
||||
baseURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": baseURL.String(),
|
||||
"jwks_uri": baseURL.ResolveReference(&url.URL{
|
||||
Path: "/jwks",
|
||||
}).String(),
|
||||
})
|
||||
})
|
||||
m.HandleFunc("GET /jwks", func(w http.ResponseWriter, _ *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
|
||||
Keys: []jose.JSONWebKey{
|
||||
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
|
||||
},
|
||||
})
|
||||
})
|
||||
srv = httptest.NewServer(m)
|
||||
|
||||
rawIdentityToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
|
||||
"iss": srv.URL,
|
||||
"aud": "CLIENT_ID",
|
||||
"sub": "subject",
|
||||
"exp": exp,
|
||||
"iat": iat,
|
||||
}).CompactSerialize()
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := oidc.New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := p.VerifyIdentityToken(ctx, rawIdentityToken1)
|
||||
require.NoError(t, err)
|
||||
delete(claims, "iat")
|
||||
delete(claims, "exp")
|
||||
assert.Equal(t, map[string]any{
|
||||
"aud": "CLIENT_ID",
|
||||
"iss": srv.URL,
|
||||
"sub": "subject",
|
||||
}, claims)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue