// Package azure implements OpenID Connect for Microsoft Azure // // https://www.pomerium.com/docs/identity-providers/azure package azure import ( "bytes" "context" "encoding/json" "fmt" "io" "net/http" "slices" "strings" go_oidc "github.com/coreos/go-oidc/v3/oidc" "github.com/google/uuid" "golang.org/x/oauth2" "github.com/pomerium/pomerium/internal/jwtutil" "github.com/pomerium/pomerium/pkg/identity/oauth" pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc" ) // Name identifies the Azure identity provider const Name = "azure" // defaultProviderURL Users with both a personal Microsoft // account and a work or school account from Azure Active Directory (Azure AD) // an sign in to the application. const defaultProviderURL = "https://login.microsoftonline.com/common/v2.0" // nonSpecIssuerURL is the non-oidc spec issuer url which azure will reply with // if using the defaultProviderURL. // https://github.com/MicrosoftDocs/azure-docs/issues/38427#issuecomment-705892340 const nonSpecIssuerURL = "https://login.microsoftonline.com/{tenantid}/v2.0" // https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-authorization-code var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"} // Provider is an Azure implementation of the Authenticator interface. type Provider struct { *pom_oidc.Provider accessTokenAllowedAudiences *[]string } // New instantiates an OpenID Connect (OIDC) provider for Azure. func New(ctx context.Context, o *oauth.Options) (*Provider, error) { var p Provider p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences var err error if o.ProviderURL == "" { o.ProviderURL = defaultProviderURL } genericOidc, err := newProvider(ctx, o, pom_oidc.WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier { return provider.Verifier(&go_oidc.Config{ ClientID: o.ClientID, // If using the common endpoint, the verification provider URI will not match. // https://github.com/pomerium/pomerium/issues/1605 SkipIssuerCheck: o.ProviderURL == defaultProviderURL, }) })) if err != nil { return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err) } p.Provider = genericOidc p.AuthCodeOptions = defaultAuthCodeOptions if o.AuthCodeOptions != nil { p.AuthCodeOptions = o.AuthCodeOptions } return &p, nil } // Name returns the provider name. func (p *Provider) Name() string { return Name } // VerifyAccessToken verifies a raw access token. func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) { pp, err := p.GetProvider() if err != nil { return nil, fmt.Errorf("error getting oidc provider: %w", err) } // azure access tokens are JWTs signed with the same keys as identity tokens verifier := pp.Verifier(&go_oidc.Config{ SkipClientIDCheck: true, SkipIssuerCheck: true, // checked later }) token, err := verifier.Verify(ctx, rawAccessToken) if err != nil { return nil, fmt.Errorf("error verifying access token: %w", err) } claims = jwtutil.Claims(map[string]any{}) err = token.Claims(&claims) if err != nil { return nil, fmt.Errorf("error unmarshaling access token claims: %w", err) } // verify audience if p.accessTokenAllowedAudiences != nil { if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) { return nil, fmt.Errorf("error verifying access token audience claim, invalid audience") } } err = verifyIssuer(pp, claims) if err != nil { return nil, fmt.Errorf("error verifying access token issuer claim: %w", err) } if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") { userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{ TokenType: "Bearer", AccessToken: rawAccessToken, })) if err != nil { return nil, fmt.Errorf("error calling user info endpoint: %w", err) } err = userInfo.Claims(claims) if err != nil { return nil, fmt.Errorf("error unmarshaling user info claims: %w", err) } } return claims, nil } // newProvider overrides the default round tripper for well-known endpoint call that happens // on new provider registration. // By default, the "common" (both public and private domains) responds with // https://login.microsoftonline.com/{tenantid}/v2.0 for issuer which is not OIDC spec. // If {tenantid} is in the issuer string, we force the issuer to match the defaultURL. // // https://github.com/pomerium/pomerium/issues/1605 func newProvider(ctx context.Context, o *oauth.Options, options ...pom_oidc.Option) (*pom_oidc.Provider, error) { originalClient := http.DefaultClient if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok { originalClient = c } client := new(http.Client) *client = *originalClient client.Transport = &wellKnownConfiguration{underlying: client.Transport} ctx = context.WithValue(ctx, oauth2.HTTPClient, client) return pom_oidc.New(ctx, o, options...) } type wellKnownConfiguration struct { underlying http.RoundTripper } func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Response, error) { underlying := transport.underlying if underlying == nil { underlying = http.DefaultTransport } res, err := underlying.RoundTrip(req) if err != nil { return nil, err } defer res.Body.Close() bs, err := io.ReadAll(res.Body) if err != nil { return nil, err } var wk map[string]any if err := json.Unmarshal(bs, &wk); err == nil { if issuerVar, ok := wk["issuer"]; ok { if fmt.Sprint(issuerVar) == nonSpecIssuerURL { wk["issuer"] = defaultProviderURL } } bs, _ = json.Marshal(wk) } res.Body = io.NopCloser(bytes.NewReader(bs)) return res, nil } const ( v1IssuerPrefix = "https://sts.windows.net/" v1IssuerSuffix = "/" v2IssuerPrefix = "https://login.microsoftonline.com/" v2IssuerSuffix = "/v2.0" ) func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error { tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL) if !ok { return fmt.Errorf("failed to find tenant id") } iss, ok := claims["iss"].(string) if !ok { return fmt.Errorf("missing issuer claim") } if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) { return fmt.Errorf("invalid issuer: %s", iss) } return nil } func getTenantIDFromURL(rawTokenURL string) (string, bool) { // URLs look like: // - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0 // Or: // - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/ for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} { path, ok := strings.CutPrefix(rawTokenURL, prefix) if !ok { continue } idx := strings.Index(path, "/") if idx <= 0 { continue } rawTenantID := path[:idx] if _, err := uuid.Parse(rawTenantID); err != nil { continue } return rawTenantID, true } return "", false }