mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 16:30:17 +02:00
authorize: support authenticating with idp tokens (#5484)
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
This commit is contained in:
parent
6e22b7a19a
commit
b9fd926618
36 changed files with 2791 additions and 885 deletions
|
@ -10,10 +10,14 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/google/uuid"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/jwtutil"
|
||||
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
||||
)
|
||||
|
@ -37,11 +41,13 @@ var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
|
|||
// Provider is an Azure implementation of the Authenticator interface.
|
||||
type Provider struct {
|
||||
*pom_oidc.Provider
|
||||
accessTokenAllowedAudiences *[]string
|
||||
}
|
||||
|
||||
// New instantiates an OpenID Connect (OIDC) provider for Azure.
|
||||
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||
var p Provider
|
||||
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
|
||||
var err error
|
||||
if o.ProviderURL == "" {
|
||||
o.ProviderURL = defaultProviderURL
|
||||
|
@ -73,6 +79,59 @@ func (p *Provider) Name() string {
|
|||
return Name
|
||||
}
|
||||
|
||||
// VerifyAccessToken verifies a raw access token.
|
||||
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
||||
pp, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error getting oidc provider: %w", err)
|
||||
}
|
||||
|
||||
// azure access tokens are JWTs signed with the same keys as identity tokens
|
||||
verifier := pp.Verifier(&go_oidc.Config{
|
||||
SkipClientIDCheck: true,
|
||||
SkipIssuerCheck: true, // checked later
|
||||
})
|
||||
token, err := verifier.Verify(ctx, rawAccessToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token: %w", err)
|
||||
}
|
||||
|
||||
claims = jwtutil.Claims(map[string]any{})
|
||||
err = token.Claims(&claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
||||
}
|
||||
|
||||
// verify audience
|
||||
if p.accessTokenAllowedAudiences != nil {
|
||||
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
|
||||
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
|
||||
}
|
||||
}
|
||||
|
||||
err = verifyIssuer(pp, claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
|
||||
}
|
||||
|
||||
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
|
||||
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
||||
TokenType: "Bearer",
|
||||
AccessToken: rawAccessToken,
|
||||
}))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
|
||||
}
|
||||
|
||||
err = userInfo.Claims(claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
// newProvider overrides the default round tripper for well-known endpoint call that happens
|
||||
// on new provider registration.
|
||||
// By default, the "common" (both public and private domains) responds with
|
||||
|
@ -128,3 +187,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res
|
|||
res.Body = io.NopCloser(bytes.NewReader(bs))
|
||||
return res, nil
|
||||
}
|
||||
|
||||
const (
|
||||
v1IssuerPrefix = "https://sts.windows.net/"
|
||||
v1IssuerSuffix = "/"
|
||||
v2IssuerPrefix = "https://login.microsoftonline.com/"
|
||||
v2IssuerSuffix = "/v2.0"
|
||||
)
|
||||
|
||||
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
|
||||
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
|
||||
if !ok {
|
||||
return fmt.Errorf("failed to find tenant id")
|
||||
}
|
||||
|
||||
iss, ok := claims["iss"].(string)
|
||||
if !ok {
|
||||
return fmt.Errorf("missing issuer claim")
|
||||
}
|
||||
|
||||
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
|
||||
return fmt.Errorf("invalid issuer: %s", iss)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
|
||||
// URLs look like:
|
||||
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
|
||||
// Or:
|
||||
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
|
||||
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
|
||||
path, ok := strings.CutPrefix(rawTokenURL, prefix)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
idx := strings.Index(path, "/")
|
||||
if idx <= 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
rawTenantID := path[:idx]
|
||||
if _, err := uuid.Parse(rawTenantID); err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
return rawTenantID, true
|
||||
}
|
||||
|
||||
return "", false
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue