authorize: support authenticating with idp tokens (#5484)

* identity: add support for verifying access and identity tokens

* allow overriding with policy option

* authenticate: add verify endpoints

* wip

* implement session creation

* add verify test

* implement idp token login

* fix tests

* add pr permission

* make session ids route-specific

* rename method

* add test

* add access token test

* test for newUserFromIDPClaims

* more tests

* make the session id per-idp

* use type for

* add test

* remove nil checks
This commit is contained in:
Caleb Doxsey 2025-02-18 13:02:06 -07:00 committed by GitHub
parent 6e22b7a19a
commit b9fd926618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 2791 additions and 885 deletions

9
pkg/identity/errors.go Normal file
View file

@ -0,0 +1,9 @@
package identity
import "github.com/pomerium/pomerium/pkg/identity/identity"
// re-exported errors
var (
ErrVerifyAccessTokenNotSupported = identity.ErrVerifyAccessTokenNotSupported
ErrVerifyIdentityTokenNotSupported = identity.ErrVerifyIdentityTokenNotSupported
)

View file

@ -0,0 +1,9 @@
package identity
import "errors"
// well known errors
var (
ErrVerifyAccessTokenNotSupported = errors.New("identity: access token verification not supported")
ErrVerifyIdentityTokenNotSupported = errors.New("identity: identity token verification not supported")
)

View file

@ -2,6 +2,7 @@ package identity
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
@ -55,3 +56,13 @@ func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ s
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
return mp.SignInError
}
// VerifyAccessToken verifies an access token.
func (mp MockProvider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyAccessToken not implemented")
}
// VerifyIdentityToken verifies an identity token.
func (mp MockProvider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, fmt.Errorf("VerifyIdentityToken not implemented")
}

View file

@ -182,3 +182,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -256,3 +256,13 @@ func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string)
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -3,7 +3,9 @@
// authorization with Bearer JWT.
package oauth
import "net/url"
import (
"net/url"
)
// Options contains the fields required for an OAuth 2.0 (inc. OIDC) auth flow.
//
@ -29,4 +31,7 @@ type Options struct {
// AuthCodeOptions specifies additional key value pairs query params to add
// to the request flow signin url.
AuthCodeOptions map[string]string
// When set validates the audience in access tokens.
AccessTokenAllowedAudiences *[]string
}

View file

@ -10,10 +10,14 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strings"
go_oidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/pkg/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
)
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
// Provider is an Azure implementation of the Authenticator interface.
type Provider struct {
*pom_oidc.Provider
accessTokenAllowedAudiences *[]string
}
// New instantiates an OpenID Connect (OIDC) provider for Azure.
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
var p Provider
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
var err error
if o.ProviderURL == "" {
o.ProviderURL = defaultProviderURL
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
return Name
}
// VerifyAccessToken verifies a raw access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()
if err != nil {
return nil, fmt.Errorf("error getting oidc provider: %w", err)
}
// azure access tokens are JWTs signed with the same keys as identity tokens
verifier := pp.Verifier(&go_oidc.Config{
SkipClientIDCheck: true,
SkipIssuerCheck: true, // checked later
})
token, err := verifier.Verify(ctx, rawAccessToken)
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
}
claims = jwtutil.Claims(map[string]any{})
err = token.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
}
// verify audience
if p.accessTokenAllowedAudiences != nil {
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
}
}
err = verifyIssuer(pp, claims)
if err != nil {
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
}
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
TokenType: "Bearer",
AccessToken: rawAccessToken,
}))
if err != nil {
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
}
err = userInfo.Claims(claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
}
}
return claims, nil
}
// newProvider overrides the default round tripper for well-known endpoint call that happens
// on new provider registration.
// By default, the "common" (both public and private domains) responds with
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
res.Body = io.NopCloser(bytes.NewReader(bs))
return res, nil
}
const (
v1IssuerPrefix = "https://sts.windows.net/"
v1IssuerSuffix = "/"
v2IssuerPrefix = "https://login.microsoftonline.com/"
v2IssuerSuffix = "/v2.0"
)
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
if !ok {
return fmt.Errorf("failed to find tenant id")
}
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing issuer claim")
}
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
return fmt.Errorf("invalid issuer: %s", iss)
}
return nil
}
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
// URLs look like:
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
// Or:
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
path, ok := strings.CutPrefix(rawTokenURL, prefix)
if !ok {
continue
}
idx := strings.Index(path, "/")
if idx <= 0 {
continue
}
rawTenantID := path[:idx]
if _, err := uuid.Parse(rawTenantID); err != nil {
continue
}
return rawTenantID, true
}
return "", false
}

View file

@ -2,15 +2,27 @@ package azure
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/identity/identity"
"github.com/pomerium/pomerium/pkg/identity/oauth"
)
func TestAuthCodeOptions(t *testing.T) {
t.Parallel()
var options oauth.Options
p, err := New(context.Background(), &options)
require.NoError(t, err)
@ -21,3 +33,101 @@ func TestAuthCodeOptions(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, map[string]string{}, p.AuthCodeOptions)
}
func TestVerifyAccessToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
jwtSigner, err := jose.NewSigner(jose.SigningKey{Algorithm: jose.RS256, Key: privateKey}, nil)
require.NoError(t, err)
iat := time.Now().Unix()
exp := iat + 3600
rawAccessToken1, err := jwt.Signed(jwtSigner).Claims(map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://client.example.com",
"sub": "subject",
"exp": exp,
"iat": iat,
}).CompactSerialize()
require.NoError(t, err)
rawAccessToken2, err := jwt.Signed(jwtSigner).Claims(map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://unexpected.example.com",
"sub": "subject",
"exp": exp,
"iat": iat,
}).CompactSerialize()
require.NoError(t, err)
var srvURL string
mux := http.NewServeMux()
mux.HandleFunc("GET /.well-known/openid-configuration", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(map[string]any{
"issuer": srvURL,
"authorization_endpoint": srvURL + "/auth",
"token_endpoint": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/token",
"jwks_uri": srvURL + "/keys",
"id_token_signing_alg_values_supported": []any{"RS256"},
})
})
mux.HandleFunc("GET /keys", func(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
json.NewEncoder(w).Encode(jose.JSONWebKeySet{
Keys: []jose.JSONWebKey{
{Key: privateKey.Public(), Use: "sig", Algorithm: "RS256"},
},
})
})
srv := httptest.NewServer(mux)
srvURL = srv.URL
audiences := []string{"https://other.example.com", "https://client.example.com"}
p, err := New(ctx, &oauth.Options{
ProviderName: Name,
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
AccessTokenAllowedAudiences: &audiences,
})
require.NoError(t, err)
claims, err := p.VerifyAccessToken(ctx, rawAccessToken1)
require.NoError(t, err)
delete(claims, "iat")
delete(claims, "exp")
assert.Equal(t, map[string]any{
"iss": "https://sts.windows.net/323b4000-7ad7-4ed3-9f4e-adee06ee8bbe/",
"aud": "https://client.example.com",
"sub": "subject",
}, claims)
_, err = p.VerifyAccessToken(ctx, rawAccessToken2)
assert.ErrorContains(t, err, "invalid audience")
}
func TestVerifyIdentityToken(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute)
mux := http.NewServeMux()
srv := httptest.NewServer(mux)
p, err := New(ctx, &oauth.Options{
ProviderName: Name,
ProviderURL: srv.URL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
claims, err := p.VerifyIdentityToken(ctx, "RAW IDENTITY TOKEN")
assert.ErrorIs(t, identity.ErrVerifyIdentityTokenNotSupported, err)
assert.Nil(t, claims)
}

View file

@ -360,3 +360,13 @@ func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint,
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
return nil
}
// VerifyAccessToken verifies an access token.
func (p *Provider) VerifyAccessToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyAccessTokenNotSupported
}
// VerifyIdentityToken verifies an identity token.
func (p *Provider) VerifyIdentityToken(_ context.Context, _ string) (claims map[string]any, err error) {
return nil, identity.ErrVerifyIdentityTokenNotSupported
}

View file

@ -8,6 +8,7 @@ import (
"net/http"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
oteltrace "go.opentelemetry.io/otel/trace"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/pkg/identity/identity"
@ -23,7 +24,6 @@ import (
"github.com/pomerium/pomerium/pkg/identity/oidc/okta"
"github.com/pomerium/pomerium/pkg/identity/oidc/onelogin"
"github.com/pomerium/pomerium/pkg/identity/oidc/ping"
oteltrace "go.opentelemetry.io/otel/trace"
)
// State is the identity state.
@ -36,6 +36,8 @@ type Authenticator interface {
Revoke(context.Context, *oauth2.Token) error
Name() string
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v any) error
VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error)
VerifyIdentityToken(ctx context.Context, rawIdentityToken string) (claims map[string]any, err error)
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error