mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
* identity: add support for verifying access and identity tokens * allow overriding with policy option * authenticate: add verify endpoints * wip * implement session creation * add verify test * implement idp token login * fix tests * add pr permission * make session ids route-specific * rename method * add test * add access token test * test for newUserFromIDPClaims * more tests * make the session id per-idp * use type for * add test * remove nil checks
241 lines
6.9 KiB
Go
241 lines
6.9 KiB
Go
// Package azure implements OpenID Connect for Microsoft Azure
|
|
//
|
|
// https://www.pomerium.com/docs/identity-providers/azure
|
|
package azure
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"slices"
|
|
"strings"
|
|
|
|
go_oidc "github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/google/uuid"
|
|
"golang.org/x/oauth2"
|
|
|
|
"github.com/pomerium/pomerium/internal/jwtutil"
|
|
"github.com/pomerium/pomerium/pkg/identity/oauth"
|
|
pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc"
|
|
)
|
|
|
|
// Name identifies the Azure identity provider
|
|
const Name = "azure"
|
|
|
|
// defaultProviderURL Users with both a personal Microsoft
|
|
// account and a work or school account from Azure Active Directory (Azure AD)
|
|
// an sign in to the application.
|
|
const defaultProviderURL = "https://login.microsoftonline.com/common/v2.0"
|
|
|
|
// nonSpecIssuerURL is the non-oidc spec issuer url which azure will reply with
|
|
// if using the defaultProviderURL.
|
|
// https://github.com/MicrosoftDocs/azure-docs/issues/38427#issuecomment-705892340
|
|
const nonSpecIssuerURL = "https://login.microsoftonline.com/{tenantid}/v2.0"
|
|
|
|
// https://docs.microsoft.com/en-us/azure/active-directory/develop/v2-oauth2-auth-code-flow#request-an-authorization-code
|
|
var defaultAuthCodeOptions = map[string]string{"prompt": "select_account"}
|
|
|
|
// Provider is an Azure implementation of the Authenticator interface.
|
|
type Provider struct {
|
|
*pom_oidc.Provider
|
|
accessTokenAllowedAudiences *[]string
|
|
}
|
|
|
|
// New instantiates an OpenID Connect (OIDC) provider for Azure.
|
|
func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|
var p Provider
|
|
p.accessTokenAllowedAudiences = o.AccessTokenAllowedAudiences
|
|
var err error
|
|
if o.ProviderURL == "" {
|
|
o.ProviderURL = defaultProviderURL
|
|
}
|
|
genericOidc, err := newProvider(ctx, o,
|
|
pom_oidc.WithGetVerifier(func(provider *go_oidc.Provider) *go_oidc.IDTokenVerifier {
|
|
return provider.Verifier(&go_oidc.Config{
|
|
ClientID: o.ClientID,
|
|
// If using the common endpoint, the verification provider URI will not match.
|
|
// https://github.com/pomerium/pomerium/issues/1605
|
|
SkipIssuerCheck: o.ProviderURL == defaultProviderURL,
|
|
})
|
|
}))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("%s: failed creating oidc provider: %w", Name, err)
|
|
}
|
|
p.Provider = genericOidc
|
|
|
|
p.AuthCodeOptions = defaultAuthCodeOptions
|
|
if o.AuthCodeOptions != nil {
|
|
p.AuthCodeOptions = o.AuthCodeOptions
|
|
}
|
|
|
|
return &p, nil
|
|
}
|
|
|
|
// Name returns the provider name.
|
|
func (p *Provider) Name() string {
|
|
return Name
|
|
}
|
|
|
|
// VerifyAccessToken verifies a raw access token.
|
|
func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) {
|
|
pp, err := p.GetProvider()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error getting oidc provider: %w", err)
|
|
}
|
|
|
|
// azure access tokens are JWTs signed with the same keys as identity tokens
|
|
verifier := pp.Verifier(&go_oidc.Config{
|
|
SkipClientIDCheck: true,
|
|
SkipIssuerCheck: true, // checked later
|
|
})
|
|
token, err := verifier.Verify(ctx, rawAccessToken)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error verifying access token: %w", err)
|
|
}
|
|
|
|
claims = jwtutil.Claims(map[string]any{})
|
|
err = token.Claims(&claims)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)
|
|
}
|
|
|
|
// verify audience
|
|
if p.accessTokenAllowedAudiences != nil {
|
|
if audience, ok := claims["aud"].(string); !ok || !slices.Contains(*p.accessTokenAllowedAudiences, audience) {
|
|
return nil, fmt.Errorf("error verifying access token audience claim, invalid audience")
|
|
}
|
|
}
|
|
|
|
err = verifyIssuer(pp, claims)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error verifying access token issuer claim: %w", err)
|
|
}
|
|
|
|
if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") {
|
|
userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{
|
|
TokenType: "Bearer",
|
|
AccessToken: rawAccessToken,
|
|
}))
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error calling user info endpoint: %w", err)
|
|
}
|
|
|
|
err = userInfo.Claims(claims)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error unmarshaling user info claims: %w", err)
|
|
}
|
|
}
|
|
|
|
return claims, nil
|
|
}
|
|
|
|
// newProvider overrides the default round tripper for well-known endpoint call that happens
|
|
// on new provider registration.
|
|
// By default, the "common" (both public and private domains) responds with
|
|
// https://login.microsoftonline.com/{tenantid}/v2.0 for issuer which is not OIDC spec.
|
|
// If {tenantid} is in the issuer string, we force the issuer to match the defaultURL.
|
|
//
|
|
// https://github.com/pomerium/pomerium/issues/1605
|
|
func newProvider(ctx context.Context, o *oauth.Options, options ...pom_oidc.Option) (*pom_oidc.Provider, error) {
|
|
originalClient := http.DefaultClient
|
|
if c, ok := ctx.Value(oauth2.HTTPClient).(*http.Client); ok {
|
|
originalClient = c
|
|
}
|
|
|
|
client := new(http.Client)
|
|
*client = *originalClient
|
|
client.Transport = &wellKnownConfiguration{underlying: client.Transport}
|
|
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, client)
|
|
return pom_oidc.New(ctx, o, options...)
|
|
}
|
|
|
|
type wellKnownConfiguration struct {
|
|
underlying http.RoundTripper
|
|
}
|
|
|
|
func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Response, error) {
|
|
underlying := transport.underlying
|
|
if underlying == nil {
|
|
underlying = http.DefaultTransport
|
|
}
|
|
|
|
res, err := underlying.RoundTrip(req)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer res.Body.Close()
|
|
|
|
bs, err := io.ReadAll(res.Body)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var wk map[string]any
|
|
if err := json.Unmarshal(bs, &wk); err == nil {
|
|
if issuerVar, ok := wk["issuer"]; ok {
|
|
if fmt.Sprint(issuerVar) == nonSpecIssuerURL {
|
|
wk["issuer"] = defaultProviderURL
|
|
}
|
|
}
|
|
bs, _ = json.Marshal(wk)
|
|
}
|
|
|
|
res.Body = io.NopCloser(bytes.NewReader(bs))
|
|
return res, nil
|
|
}
|
|
|
|
const (
|
|
v1IssuerPrefix = "https://sts.windows.net/"
|
|
v1IssuerSuffix = "/"
|
|
v2IssuerPrefix = "https://login.microsoftonline.com/"
|
|
v2IssuerSuffix = "/v2.0"
|
|
)
|
|
|
|
func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error {
|
|
tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL)
|
|
if !ok {
|
|
return fmt.Errorf("failed to find tenant id")
|
|
}
|
|
|
|
iss, ok := claims["iss"].(string)
|
|
if !ok {
|
|
return fmt.Errorf("missing issuer claim")
|
|
}
|
|
|
|
if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) {
|
|
return fmt.Errorf("invalid issuer: %s", iss)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func getTenantIDFromURL(rawTokenURL string) (string, bool) {
|
|
// URLs look like:
|
|
// - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0
|
|
// Or:
|
|
// - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/
|
|
for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} {
|
|
path, ok := strings.CutPrefix(rawTokenURL, prefix)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
idx := strings.Index(path, "/")
|
|
if idx <= 0 {
|
|
continue
|
|
}
|
|
|
|
rawTenantID := path[:idx]
|
|
if _, err := uuid.Parse(rawTenantID); err != nil {
|
|
continue
|
|
}
|
|
|
|
return rawTenantID, true
|
|
}
|
|
|
|
return "", false
|
|
}
|