From 2c0bd9e434bb3eec6a74bad88fde3026b5a1ba61 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 12 Feb 2025 11:32:30 -0700 Subject: [PATCH] wip --- authenticate/handlers.go | 6 ++ authorize/databroker.go | 1 - authorize/grpc.go | 6 +- authorize/state.go | 2 +- config/config.go | 6 +- internal/sessions/idptokens/api.go | 20 +++-- internal/sessions/idptokens/idptokens.go | 12 +-- pkg/identity/oidc/azure/microsoft.go | 101 +++++++++++++++++++++++ 8 files changed, 134 insertions(+), 20 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index bd9700ad1..d39c41c99 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -44,6 +44,12 @@ func (a *Authenticate) Handler() http.Handler { func (a *Authenticate) Mount(r *mux.Router) { r.StrictSlash(true) r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) + r.Use(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r = csrf.UnsafeSkipCheck(r) + h.ServeHTTP(w, r) + }) + }) r.Use(func(h http.Handler) http.Handler { options := a.options.Load() state := a.state.Load() diff --git a/authorize/databroker.go b/authorize/databroker.go index 2c59e4c30..a5c4b98c1 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -13,7 +13,6 @@ import ( ) type sessionOrServiceAccount interface { - GetId() string GetUserId() string Validate() error } diff --git a/authorize/grpc.go b/authorize/grpc.go index ed7611867..5fe762dd0 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -45,10 +45,12 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe requestID := requestid.FromHTTPHeader(hreq.Header) ctx = requestid.WithValue(ctx, requestID) + var sessionID string var s sessionOrServiceAccount var u *user.User if sess, err := a.state.Load().idpTokensLoader.LoadSession(hreq); err == nil { s = sess + sessionID = sess.GetId() } else if !errors.Is(err, sessions.ErrNoSessionFound) { log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("error verifying idp tokens") } else { @@ -60,6 +62,8 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe return nil, err } else if err != nil { log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account") + } else { + sessionID = sessionState.ID } } } @@ -67,7 +71,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error } - req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, s.GetId()) + req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionID) if err != nil { log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request") return nil, err diff --git a/authorize/state.go b/authorize/state.go index 70d4f23ff..05429f50a 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -90,7 +90,7 @@ func newAuthorizeStateFromConfig( return nil, err } - state.idpTokensLoader = idptokens.NewLoader(cfg.Options, state.dataBrokerClient) + state.idpTokensLoader = idptokens.NewLoader(cfg, state.dataBrokerClient) return state, nil } diff --git a/config/config.go b/config/config.go index 14a51b2e4..6a22b59f4 100644 --- a/config/config.go +++ b/config/config.go @@ -215,7 +215,7 @@ func (cfg *Config) GetCertificatePool() (*x509.CertPool, error) { // GetAuthenticateKeyFetcher returns a key fetcher for the authenticate service func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) { - authenticateURL, transport, err := cfg.resolveAuthenticateURL() + authenticateURL, transport, err := cfg.ResolveAuthenticateURL() if err != nil { return nil, err } @@ -225,7 +225,9 @@ func (cfg *Config) GetAuthenticateKeyFetcher() (hpke.KeyFetcher, error) { return hpke.NewKeyFetcher(hpkeURL, transport), nil } -func (cfg *Config) resolveAuthenticateURL() (*url.URL, *http.Transport, error) { +// ResolveAuthenticateURL resolves the authenticate service URL and returns a transport suitable +// for accessing the authenticate service. +func (cfg *Config) ResolveAuthenticateURL() (*url.URL, *http.Transport, error) { authenticateURL, err := cfg.Options.GetInternalAuthenticateURL() if err != nil { return nil, nil, fmt.Errorf("invalid authenticate service url: %w", err) diff --git a/internal/sessions/idptokens/api.go b/internal/sessions/idptokens/api.go index 76504f55f..0300326d7 100644 --- a/internal/sessions/idptokens/api.go +++ b/internal/sessions/idptokens/api.go @@ -9,7 +9,7 @@ import ( "net/http" "net/url" - "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/config" ) // endpoints @@ -39,11 +39,11 @@ type VerifyIdentityTokenRequest struct { func apiVerifyAccessToken( ctx context.Context, - authenticateServiceURL string, + cfg *config.Config, request *VerifyAccessTokenRequest, ) (*VerifyTokenResponse, error) { var response VerifyTokenResponse - err := api(ctx, authenticateServiceURL, "verify-access-token", request, &response) + err := api(ctx, cfg, "verify-access-token", request, &response) if err != nil { return nil, err } @@ -52,11 +52,11 @@ func apiVerifyAccessToken( func apiVerifyIdentityToken( ctx context.Context, - authenticateServiceURL string, + cfg *config.Config, request *VerifyIdentityTokenRequest, ) (*VerifyTokenResponse, error) { var response VerifyTokenResponse - err := api(ctx, authenticateServiceURL, "verify-identity-token", request, &response) + err := api(ctx, cfg, "verify-identity-token", request, &response) if err != nil { return nil, err } @@ -65,15 +65,15 @@ func apiVerifyIdentityToken( func api( ctx context.Context, - authenticateServiceURL string, + cfg *config.Config, endpoint string, request, response any, ) error { - u, err := urlutil.ParseAndValidateURL(authenticateServiceURL) + authenticateURL, transport, err := cfg.ResolveAuthenticateURL() if err != nil { return fmt.Errorf("invalid authenticate service url: %w", err) } - u = u.ResolveReference(&url.URL{ + u := authenticateURL.ResolveReference(&url.URL{ Path: "/.pomerium/" + endpoint, }) @@ -87,7 +87,9 @@ func api( return fmt.Errorf("error creating %s http request: %w", endpoint, err) } - res, err := http.DefaultClient.Do(req) + res, err := (&http.Client{ + Transport: transport, + }).Do(req) if err != nil { return fmt.Errorf("error executing %s http request: %w", endpoint, err) } diff --git a/internal/sessions/idptokens/idptokens.go b/internal/sessions/idptokens/idptokens.go index f098db247..44aa24fab 100644 --- a/internal/sessions/idptokens/idptokens.go +++ b/internal/sessions/idptokens/idptokens.go @@ -26,14 +26,14 @@ var ( // A Loader loads sessions from IdP access and identity tokens. type Loader struct { - options *config.Options + cfg *config.Config dataBrokerServiceClient databroker.DataBrokerServiceClient } // NewLoader creates a new Loader. -func NewLoader(options *config.Options, dataBrokerServiceClient databroker.DataBrokerServiceClient) *Loader { +func NewLoader(cfg *config.Config, dataBrokerServiceClient databroker.DataBrokerServiceClient) *Loader { return &Loader{ - options: options, + cfg: cfg, dataBrokerServiceClient: dataBrokerServiceClient, } } @@ -42,7 +42,7 @@ func NewLoader(options *config.Options, dataBrokerServiceClient databroker.DataB func (l *Loader) LoadSession(r *http.Request) (*session.Session, error) { ctx := r.Context() - idp, err := l.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) + idp, err := l.cfg.Options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (l *Loader) loadSessionFromAccessToken(ctx context.Context, idp *identity.P return nil, err } - res, err := apiVerifyAccessToken(ctx, idp.GetAuthenticateServiceUrl(), &VerifyAccessTokenRequest{ + res, err := apiVerifyAccessToken(ctx, l.cfg, &VerifyAccessTokenRequest{ AccessToken: rawAccessToken, IdentityProviderID: idp.GetId(), }) @@ -121,7 +121,7 @@ func (l *Loader) loadSessionFromIdentityToken(ctx context.Context, idp *identity return nil, err } - res, err := apiVerifyIdentityToken(ctx, idp.GetAuthenticateServiceUrl(), &VerifyIdentityTokenRequest{ + res, err := apiVerifyIdentityToken(ctx, l.cfg, &VerifyIdentityTokenRequest{ IdentityToken: rawIdentityToken, IdentityProviderID: idp.GetId(), }) diff --git a/pkg/identity/oidc/azure/microsoft.go b/pkg/identity/oidc/azure/microsoft.go index 1ae77ea15..0e5b79bd5 100644 --- a/pkg/identity/oidc/azure/microsoft.go +++ b/pkg/identity/oidc/azure/microsoft.go @@ -10,8 +10,11 @@ import ( "fmt" "io" "net/http" + "slices" + "strings" go_oidc "github.com/coreos/go-oidc/v3/oidc" + "github.com/google/uuid" "golang.org/x/oauth2" "github.com/pomerium/pomerium/pkg/identity/oauth" @@ -73,6 +76,52 @@ func (p *Provider) Name() string { return Name } +// VerifyAccessToken verifies a raw access token using the oidc UserInfo endpoint. +func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) (claims map[string]any, err error) { + pp, err := p.GetProvider() + if err != nil { + return nil, fmt.Errorf("error getting oidc provider: %w", err) + } + + verifier := pp.Verifier(&go_oidc.Config{ + SkipClientIDCheck: true, + SkipIssuerCheck: true, // checked later + }) + + token, err := verifier.Verify(ctx, rawAccessToken) + if err != nil { + return nil, fmt.Errorf("error verifying access token: %w", err) + } + + claims = map[string]any{} + err = token.Claims(&claims) + if err != nil { + return nil, fmt.Errorf("error unmarshaling access token claims: %w", err) + } + + err = verifyIssuer(pp, claims) + if err != nil { + return nil, fmt.Errorf("error verifying access token issuer claim: %w", err) + } + + if scope, ok := claims["scp"].(string); ok && slices.Contains(strings.Fields(scope), "openid") { + userInfo, err := pp.UserInfo(ctx, oauth2.StaticTokenSource(&oauth2.Token{ + TokenType: "Bearer", + AccessToken: rawAccessToken, + })) + if err != nil { + return nil, fmt.Errorf("error calling user info endpoint: %w", err) + } + + err = userInfo.Claims(claims) + if err != nil { + return nil, fmt.Errorf("error unmarshaling user info claims: %w", err) + } + } + + return claims, nil +} + // newProvider overrides the default round tripper for well-known endpoint call that happens // on new provider registration. // By default, the "common" (both public and private domains) responds with @@ -128,3 +177,55 @@ func (transport *wellKnownConfiguration) RoundTrip(req *http.Request) (*http.Res res.Body = io.NopCloser(bytes.NewReader(bs)) return res, nil } + +const ( + v1IssuerPrefix = "https://sts.windows.net/" + v1IssuerSuffix = "/" + v2IssuerPrefix = "https://login.microsoftonline.com/" + v2IssuerSuffix = "/v2.0" +) + +func verifyIssuer(pp *go_oidc.Provider, claims map[string]any) error { + tenantID, ok := getTenantIDFromURL(pp.Endpoint().TokenURL) + if !ok { + return fmt.Errorf("failed to find tenant id") + } + + iss, ok := claims["iss"].(string) + if !ok { + return fmt.Errorf("missing issuer claim") + } + + if !(iss == v1IssuerPrefix+tenantID+v1IssuerSuffix || iss == v2IssuerPrefix+tenantID+v2IssuerSuffix) { + return fmt.Errorf("invalid issuer: %s", iss) + } + + return nil +} + +func getTenantIDFromURL(rawTokenURL string) (string, bool) { + // URLs look like: + // - https://login.microsoftonline.com/f42bce3b-671c-4162-b24c-00ecc7641897/v2.0 + // Or: + // - https://sts.windows.net/f42bce3b-671c-4162-b24c-00ecc7641897/ + for _, prefix := range []string{v1IssuerPrefix, v2IssuerPrefix} { + path, ok := strings.CutPrefix(rawTokenURL, prefix) + if !ok { + continue + } + + idx := strings.Index(path, "/") + if idx <= 0 { + continue + } + + rawTenantID := path[:idx] + if _, err := uuid.Parse(rawTenantID); err != nil { + continue + } + + return rawTenantID, true + } + + return "", false +}