diff --git a/pkg/identity/oidc/azure/microsoft_test.go b/pkg/identity/oidc/azure/microsoft_test.go index 4fcf6559f..3ac641975 100644 --- a/pkg/identity/oidc/azure/microsoft_test.go +++ b/pkg/identity/oidc/azure/microsoft_test.go @@ -16,7 +16,6 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/pkg/identity/identity" "github.com/pomerium/pomerium/pkg/identity/oauth" ) @@ -110,24 +109,3 @@ func TestVerifyAccessToken(t *testing.T) { _, err = p.VerifyAccessToken(ctx, rawAccessToken2) assert.ErrorContains(t, err, "invalid audience") } - -func TestVerifyIdentityToken(t *testing.T) { - t.Parallel() - - ctx := testutil.GetContext(t, time.Minute) - - mux := http.NewServeMux() - srv := httptest.NewServer(mux) - - p, err := New(ctx, &oauth.Options{ - ProviderName: Name, - ProviderURL: srv.URL, - ClientID: "CLIENT_ID", - ClientSecret: "CLIENT_SECRET", - }) - require.NoError(t, err) - - claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN") - assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err) - assert.Nil(t, claims) -} diff --git a/pkg/identity/oidc/oidc.go b/pkg/identity/oidc/oidc.go index 75950d013..d78e4c6a8 100644 --- a/pkg/identity/oidc/oidc.go +++ b/pkg/identity/oidc/oidc.go @@ -9,6 +9,7 @@ import ( "fmt" "net/http" "net/url" + "slices" "sync" go_oidc "github.com/coreos/go-oidc/v3/oidc" @@ -16,6 +17,7 @@ import ( "golang.org/x/oauth2" "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/jwtutil" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/version" "github.com/pomerium/pomerium/pkg/identity/identity" @@ -49,6 +51,8 @@ type Provider struct { // to the request flow signin url. AuthCodeOptions map[string]string + accessTokenAllowedAudiences *[]string + mu sync.Mutex provider *go_oidc.Provider } @@ -94,6 +98,7 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e return provider.Verifier(&go_oidc.Config{ClientID: o.ClientID}) }), }, options...)...) + p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences return p, nil } @@ -362,11 +367,53 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, } // VerifyAccessToken verifies an access token. -func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) { - return nil, identity.ErrVerifyAccessTokenNotSupported +func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) { + pp, err := p.GetProvider() + if err != nil { + return nil, fmt.Errorf("error getting oauth provider: %w", err) + } + + // use the access token to call the user info endpoint + userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{ + TokenType: "Bearer", + AccessToken: rawAccessToken, + })) + if err != nil { + return nil, fmt.Errorf("error retrieving user info with access token: %w", err) + } + + claims = jwtutil.Claims(map[string]any{}) + err = userInfo.Claims(&claims) + if err != nil { + return nil, fmt.Errorf("error unmarshaling access token claims: %w", err) + } + + if p.accessTokenAllowedAudiences != nil { + if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) { + return nil, fmt.Errorf("error verifying access token audience claim, invalid audience") + } + } + + return claims, nil } // VerifyIdentityToken verifies an identity token. -func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) { - return nil, identity.ErrVerifyIdentityTokenNotSupported +func (p *Provider) VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error) { + verifier, err := p.GetVerifier() + if err != nil { + return nil, fmt.Errorf("error getting verifier: %w", err) + } + + identityToken, err := verifier.Verify(ctx, rawIdentityToken) + if err != nil { + return nil, fmt.Errorf("error verifying identity token: %w", err) + } + + claims = jwtutil.Claims(map[string]any{}) + err = identityToken.Claims(&claims) + if err != nil { + return nil, fmt.Errorf("error unmarshaling identity token claims: %w", err) + } + + return claims, nil } diff --git a/pkg/identity/oidc/oidc_test.go b/pkg/identity/oidc/oidc_test.go index 42e4c716e..54d8177a6 100644 --- a/pkg/identity/oidc/oidc_test.go +++ b/pkg/identity/oidc/oidc_test.go @@ -1,4 +1,4 @@ -package oidc +package oidc_test import ( "context" @@ -18,7 +18,10 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/oauth2" + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/identity/oauth" + "github.com/pomerium/pomerium/pkg/identity/oidc" ) // Claims implements identity.State. (We can't use identity.Claims directly @@ -59,7 +62,7 @@ func TestSignIn(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -118,7 +121,7 @@ func TestSignOut(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -219,7 +222,7 @@ func TestAuthenticate(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -313,7 +316,7 @@ func TestRefresh_WithIDToken(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -384,7 +387,7 @@ func TestRefresh_WithoutIDToken(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -439,7 +442,7 @@ func TestRevoke(t *testing.T) { redirectURL, err := url.Parse(srv.URL) require.NoError(t, err) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -452,7 +455,7 @@ func TestRevoke(t *testing.T) { AccessToken: "ACCESS_TOKEN", })) - assert.Equal(t, ErrMissingAccessToken, p.Revoke(ctx, nil)) + assert.Equal(t, oidc.ErrMissingAccessToken, p.Revoke(ctx, nil)) } func TestUnsupportedFeatures(t *testing.T) { @@ -479,7 +482,7 @@ func TestUnsupportedFeatures(t *testing.T) { srv = httptest.NewServer(handler) t.Cleanup(srv.Close) - p, err := New(ctx, &oauth.Options{ + p, err := oidc.New(ctx, &oauth.Options{ ProviderURL: srv.URL, RedirectURL: redirectURL, ClientID: "CLIENT_ID", @@ -490,19 +493,19 @@ func TestUnsupportedFeatures(t *testing.T) { rec := httptest.NewRecorder() err = p.SignOut(rec, httptest.NewRequest(http.MethodGet, "/", nil), "ID_TOKEN", "", "") - assert.Equal(t, ErrSignoutNotImplemented, err) + assert.Equal(t, oidc.ErrSignoutNotImplemented, err) err = p.Revoke(ctx, &oauth2.Token{ AccessToken: "ACCESS_TOKEN", }) - assert.Equal(t, ErrRevokeNotImplemented, err) + assert.Equal(t, oidc.ErrRevokeNotImplemented, err) - _, err = New(ctx, &oauth.Options{}) - assert.Equal(t, ErrMissingProviderURL, err) + _, err = oidc.New(ctx, &oauth.Options{}) + assert.Equal(t, oidc.ErrMissingProviderURL, err) } func TestName(t *testing.T) { - assert.Equal(t, "oidc", (*Provider)(nil).Name()) + assert.Equal(t, "oidc", (*oidc.Provider)(nil).Name()) } // setupJWTSigning returns a JWT signer and a corresponding JWKS for signature verification. @@ -523,3 +526,112 @@ func setupJWTSigning(t *testing.T) (jose.Signer, jose.JSONWebKeySet) { } return jwtSigner, jwks } + +func TestVerifyAccessToken(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + var srv *httptest.Server + m := http.NewServeMux() + m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + baseURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL.String(), + "userinfo_endpoint": baseURL.ResolveReference(&url.URL{ + Path: "/userinfo", + }).String(), + }) + }) + m.HandleFunc("GET /userinfo", func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "Bearer ACCESS_TOKEN", r.Header.Get("Authorization")) + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]any{ + "aud": "AUDIENCE", + "sub": "SUBJECT", + }) + }) + srv = httptest.NewServer(m) + + p, err := oidc.New(ctx, &oauth.Options{ + ProviderURL: srv.URL, + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"), + }) + require.NoError(t, err) + + claims, err := p.VerifyAccessToken(ctx, "ACCESS_TOKEN") + require.NoError(t, err) + assert.Equal(t, map[string]any{ + "aud": "AUDIENCE", + "sub": "SUBJECT", + }, claims) +} + +func TestVerifyIdentityToken(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute) + + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + require.NoError(t, err) + jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil) + require.NoError(t, err) + iat := time.Now().Unix() + exp := iat + 3600 + + var srv *httptest.Server + m := http.NewServeMux() + m.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) { + baseURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL.String(), + "jwks_uri": baseURL.ResolveReference(&url.URL{ + Path: "/jwks", + }).String(), + }) + }) + m.HandleFunc("GET /jwks", func(w http.ResponseWriter, _ *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + json.NewEncoder(w).Encode(jose.JSONWebKeySet{ + Keys: []jose.JSONWebKey{ + {Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"}, + }, + }) + }) + srv = httptest.NewServer(m) + + rawIdentityToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{ + "iss": srv.URL, + "aud": "CLIENT_ID", + "sub": "subject", + "exp": exp, + "iat": iat, + }).CompactSerialize() + require.NoError(t, err) + + p, err := oidc.New(ctx, &oauth.Options{ + ProviderURL: srv.URL, + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + RedirectURL: urlutil.MustParseAndValidateURL("https://www.example.com"), + }) + require.NoError(t, err) + + claims, err := p.VerifyIdentityToken(ctx, rawIdentityToken1) + require.NoError(t, err) + delete(claims, "iat") + delete(claims, "exp") + assert.Equal(t, map[string]any{ + "aud": "CLIENT_ID", + "iss": srv.URL, + "sub": "subject", + }, claims) +}