authorize: support authenticating with idp tokens (#5484)

* identity: add support for verifying access and identity tokens

* allow overriding with policy option

* authenticate: add verify endpoints

* wip

* implement session creation

* add verify test

* implement idp token login

* fix tests

* add pr permission

* make session ids route-specific

* rename method

* add test

* add access token test

* test for newUserFromIDPClaims

* more tests

* make the session id per-idp

* use type for

* add test

* remove nil checks
This commit is contained in:
Caleb Doxsey 2025-02-18 13:02:06 -07:00 committed by GitHub
parent 6e22b7a19a
commit b9fd926618
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
36 changed files with 2791 additions and 885 deletions

View file

@ -10,10 +10,14 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strings"
go_oidc "github.com/coreos/go-oidc/v3/oidc"
"github.com/google/uuid"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/internal/jwtutil"
"github.com/pomerium/pomerium/pkg/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
)
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
// Provider is an Azure implementation of the Authenticator interface.
type Provider struct {
*pom_oidc.Provider
accessTokenAllowedAudiences *[]string
}
// New instantiates an OpenID Connect (OIDC) provider for Azure.
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
var p Provider
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
var err error
if o.ProviderURL == "" {
o.ProviderURL = defaultProviderURL
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
return Name
}
// VerifyAccessToken verifies a raw access token.
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
pp, err := p.GetProvider()
if err != nil {
return nil, fmt.Errorf("error getting oidc provider: %w", err)
}
// azure access tokens are JWTs signed with the same keys as identity tokens
verifier := pp.Verifier(&go_oidc.Config{
SkipClientIDCheck: true,
SkipIssuerCheck: true, // checked later
})
token, err := verifier.Verify(ctx, rawAccessToken)
if err != nil {
return nil, fmt.Errorf("error verifying access token: %w", err)
}
claims = jwtutil.Claims(map[string]any{})
err = token.Claims(&claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
}
// verify audience
if p.accessTokenAllowedAudiences != nil {
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
}
}
err = verifyIssuer(pp, claims)
if err != nil {
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
}
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
TokenType: "Bearer",
AccessToken: rawAccessToken,
}))
if err != nil {
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
}
err = userInfo.Claims(claims)
if err != nil {
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
}
}
return claims, nil
}
// newProvider overrides the default round tripper for well-known endpoint call that happens
// on new provider registration.
// By default, the "common" (both public and private domains) responds with
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
res.Body = io.NopCloser(bytes.NewReader(bs))
return res, nil
}
const (
v1IssuerPrefix = "https://sts.windows.net/"
v1IssuerSuffix = "/"
v2IssuerPrefix = "https://login.microsoftonline.com/"
v2IssuerSuffix = "/v2.0"
)
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
if !ok {
return fmt.Errorf("failed to find tenant id")
}
iss, ok := claims["iss"].(string)
if !ok {
return fmt.Errorf("missing issuer claim")
}
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
return fmt.Errorf("invalid issuer: %s", iss)
}
return nil
}
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
// URLs look like:
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
// Or:
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
path, ok := strings.CutPrefix(rawTokenURL, prefix)
if !ok {
continue
}
idx := strings.Index(path, "/")
if idx <= 0 {
continue
}
rawTenantID := path[:idx]
if _, err := uuid.Parse(rawTenantID); err != nil {
continue
}
return rawTenantID, true
}
return "", false
}