From b95ad4dbc3a25c8d59f197d0cd5f5d4c86c16dd0 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 14 Feb 2025 14:43:23 -0700 Subject: [PATCH] implement session creation --- authenticate/handlers_verify.go | 24 +- authorize/authorize.go | 6 +- authorize/check_response.go | 8 +- authorize/check_response_test.go | 22 +- authorize/databroker.go | 4 +- authorize/grpc.go | 100 ++++++--- authorize/grpc_test.go | 44 ++-- authorize/log.go | 2 +- config/bearer_token_format.go | 2 +- config/constants.go | 4 +- config/session.go | 293 ++++++++++++++++++++----- internal/httputil/headers.go | 13 +- internal/jwtutil/jwtutil.go | 160 ++++++++++++++ pkg/authenticateapi/authenticateapi.go | 109 +++++++++ pkg/identity/oidc/azure/microsoft.go | 3 +- 15 files changed, 646 insertions(+), 148 deletions(-) create mode 100644 internal/jwtutil/jwtutil.go create mode 100644 pkg/authenticateapi/authenticateapi.go diff --git a/authenticate/handlers_verify.go b/authenticate/handlers_verify.go index 0ab865b64..ad5f5ee2b 100644 --- a/authenticate/handlers_verify.go +++ b/authenticate/handlers_verify.go @@ -6,25 +6,11 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/authenticateapi" ) -type VerifyAccessTokenRequest struct { - AccessToken string `json:"accessToken"` - IdentityProviderID string `json:"identityProviderId,omitempty"` -} - -type VerifyIdentityTokenRequest struct { - IdentityToken string `json:"identityToken"` - IdentityProviderID string `json:"identityProviderId,omitempty"` -} - -type VerifyTokenResponse struct { - Valid bool `json:"valid"` - Claims map[string]any `json:"claims,omitempty"` -} - func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) error { - var req VerifyAccessTokenRequest + var req authenticateapi.VerifyAccessTokenRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { return httputil.NewError(http.StatusBadRequest, err) @@ -35,7 +21,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) return err } - var res VerifyTokenResponse + var res authenticateapi.VerifyTokenResponse claims, err := authenticator.VerifyAccessToken(r.Context(), req.AccessToken) if err == nil { res.Valid = true @@ -57,7 +43,7 @@ func (a *Authenticate) verifyAccessToken(w http.ResponseWriter, r *http.Request) } func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Request) error { - var req VerifyIdentityTokenRequest + var req authenticateapi.VerifyIdentityTokenRequest err := json.NewDecoder(r.Body).Decode(&req) if err != nil { return httputil.NewError(http.StatusBadRequest, err) @@ -68,7 +54,7 @@ func (a *Authenticate) verifyIdentityToken(w http.ResponseWriter, r *http.Reques return err } - var res VerifyTokenResponse + var res authenticateapi.VerifyTokenResponse claims, err := authenticator.VerifyIdentityToken(r.Context(), req.IdentityToken) if err == nil { res.Valid = true diff --git a/authorize/authorize.go b/authorize/authorize.go index a25da96fe..d72b7712a 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -29,7 +29,7 @@ import ( type Authorize struct { state *atomicutil.Value[*authorizeState] store *store.Store - currentOptions *atomicutil.Value[*config.Options] + currentConfig *atomicutil.Value[*config.Config] accessTracker *AccessTracker globalCache storage.Cache groupsCacheWarmer *cacheWarmer @@ -43,7 +43,7 @@ func New(ctx context.Context, cfg *config.Config) (*Authorize, error) { tracerProvider := trace.NewTracerProvider(ctx, "Authorize") tracer := tracerProvider.Tracer(trace.PomeriumCoreTracer) a := &Authorize{ - currentOptions: config.NewAtomicOptions(), + currentConfig: atomicutil.NewValue(&config.Config{Options: new(config.Options)}), store: store.New(), globalCache: storage.NewGlobalCache(time.Minute), tracerProvider: tracerProvider, @@ -155,7 +155,7 @@ func newPolicyEvaluator( // OnConfigChange updates internal structures based on config.Options func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { currentState := a.state.Load() - a.currentOptions.Store(cfg.Options) + a.currentConfig.Store(cfg) if newState, err := newAuthorizeStateFromConfig(ctx, a.tracerProvider, cfg, a.store, currentState.evaluator); err != nil { log.Ctx(ctx).Error().Err(err).Msg("authorize: error updating state") } else { diff --git a/authorize/check_response.go b/authorize/check_response.go index 027402af4..873acc754 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -186,7 +186,7 @@ func (a *Authorize) deniedResponse( Err: errors.New(reason), DebugURL: debugEndpoint, RequestID: requestid.FromContext(ctx), - BrandingOptions: a.currentOptions.Load().BrandingOptions, + BrandingOptions: a.currentConfig.Load().Options.BrandingOptions, } httpErr.ErrorResponse(ctx, w, r) @@ -213,7 +213,7 @@ func (a *Authorize) requireLoginResponse( in *envoy_service_auth_v3.CheckRequest, request *evaluator.Request, ) (*envoy_service_auth_v3.CheckResponse, error) { - options := a.currentOptions.Load() + options := a.currentConfig.Load().Options state := a.state.Load() if !a.shouldRedirect(in) { @@ -251,7 +251,7 @@ func (a *Authorize) requireWebAuthnResponse( request *evaluator.Request, result *evaluator.Result, ) (*envoy_service_auth_v3.CheckResponse, error) { - opts := a.currentOptions.Load() + opts := a.currentConfig.Load().Options state := a.state.Load() // always assume https scheme @@ -327,7 +327,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti // userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's // session that lives on the authenticate service. func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) { - opts := a.currentOptions.Load() + opts := a.currentConfig.Load().Options authenticateURL, err := opts.GetAuthenticateURL() if err != nil { return nil, err diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 3fbbaffb8..5ed6a9629 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -127,8 +127,9 @@ func TestAuthorize_okResponse(t *testing.T) { }}, JWTClaimsHeaders: config.NewJWTClaimHeaders("email"), } - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(opt) + a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{ + Options: opt, + }), state: atomicutil.NewValue(new(authorizeState))} a.store = store.New() pe, err := newPolicyEvaluator(context.Background(), opt, a.store, nil) require.NoError(t, err) @@ -183,15 +184,16 @@ func TestAuthorize_okResponse(t *testing.T) { func TestAuthorize_deniedResponse(t *testing.T) { t.Parallel() - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ - Policies: []config.Policy{{ - From: "https://example.com", - SubPolicies: []config.SubPolicy{{ - Rego: []string{"allow = true"}, + a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{ + Options: &config.Options{ + Policies: []config.Policy{{ + From: "https://example.com", + SubPolicies: []config.SubPolicy{{ + Rego: []string{"allow = true"}, + }}, }}, - }}, - }) + }, + }), state: atomicutil.NewValue(new(authorizeState))} t.Run("json", func(t *testing.T) { t.Parallel() diff --git a/authorize/databroker.go b/authorize/databroker.go index e82550522..2c59e4c30 100644 --- a/authorize/databroker.go +++ b/authorize/databroker.go @@ -3,15 +3,17 @@ package authorize import ( "context" + "google.golang.org/grpc" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpcutil" "github.com/pomerium/pomerium/pkg/storage" - "google.golang.org/grpc" ) type sessionOrServiceAccount interface { + GetId() string GetUserId() string Validate() error } diff --git a/authorize/grpc.go b/authorize/grpc.go index f7da2ab0a..258fcbc79 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -3,6 +3,7 @@ package authorize import ( "context" "encoding/pem" + "errors" "io" "net/http" "net/url" @@ -21,6 +22,7 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/contextutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/telemetry/requestid" @@ -44,31 +46,25 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe requestID := requestid.FromHTTPHeader(hreq.Header) ctx = requestid.WithValue(ctx, requestID) - sessionState, _ := state.sessionStore.LoadSessionStateAndCheckIDP(hreq) - - var s sessionOrServiceAccount - var u *user.User - var err error - if sessionState != nil { - s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion) - if status.Code(err) == codes.Unavailable { - log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable") - return nil, err - } else if err != nil { - log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account") - sessionState = nil - } - } - if sessionState != nil && s != nil { - u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error - } - - req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in, sessionState) + req, err := a.getEvaluatorRequestFromCheckRequest(ctx, in) if err != nil { log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error building evaluator request") return nil, err } + // load the session + s, err := a.loadSession(ctx, hreq, req) + if err != nil { + return nil, err + } + + // if there's a session or service account, load the user + var u *user.User + if s != nil { + req.Session.ID = s.GetId() + u, _ = a.getDataBrokerUser(ctx, s.GetUserId()) // ignore any missing user error + } + res, err := state.evaluator.Evaluate(ctx, req) if err != nil { log.Ctx(ctx).Error().Err(err).Str("request-id", requestID).Msg("error during OPA evaluation") @@ -88,10 +84,65 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe return resp, err } +func (a *Authorize) loadSession( + ctx context.Context, + hreq *http.Request, + req *evaluator.Request, +) (s sessionOrServiceAccount, err error) { + requestID := requestid.FromHTTPHeader(hreq.Header) + + // attempt to create a session from an incoming idp token + s, err = config.NewIncomingIDPTokenSessionCreator( + func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) { + return getDataBrokerRecord(ctx, recordType, recordID, 0) + }, + func(ctx context.Context, records []*databroker.Record) error { + _, err := a.state.Load().dataBrokerClient.Put(ctx, &databroker.PutRequest{ + Records: records, + }) + if err != nil { + return err + } + // invalidate cache + for _, record := range records { + storage.GetQuerier(ctx).InvalidateCache(ctx, &databroker.QueryRequest{ + Type: record.GetType(), + Query: record.GetId(), + Limit: 1, + }) + } + return nil + }, + ).CreateSession(ctx, a.currentConfig.Load(), req.Policy, hreq) + if err == nil { + return s, nil + } else if !errors.Is(err, sessions.ErrNoSessionFound) { + log.Ctx(ctx).Info(). + Str("request-id", requestID). + Err(err). + Msg("error creating session for incoming idp token") + } + + sessionState, _ := a.state.Load().sessionStore.LoadSessionStateAndCheckIDP(hreq) + if sessionState == nil { + return nil, nil + } + + s, err = a.getDataBrokerSessionOrServiceAccount(ctx, sessionState.ID, sessionState.DatabrokerRecordVersion) + if status.Code(err) == codes.Unavailable { + log.Ctx(ctx).Debug().Str("request-id", requestID).Err(err).Msg("temporary error checking authorization: data broker unavailable") + return nil, err + } else if err != nil { + log.Ctx(ctx).Info().Err(err).Str("request-id", requestID).Msg("clearing session due to missing or invalid session or service account") + return nil, nil + } + + return s, nil +} + func (a *Authorize) getEvaluatorRequestFromCheckRequest( ctx context.Context, in *envoy_service_auth_v3.CheckRequest, - sessionState *sessions.State, ) (*evaluator.Request, error) { requestURL := getCheckRequestURL(in) attrs := in.GetAttributes() @@ -106,17 +157,12 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest( attrs.GetSource().GetAddress().GetSocketAddress().GetAddress(), ), } - if sessionState != nil { - req.Session = evaluator.RequestSession{ - ID: sessionState.ID, - } - } req.Policy = a.getMatchingPolicy(envoyconfig.ExtAuthzContextExtensionsRouteID(attrs.GetContextExtensions())) return req, nil } func (a *Authorize) getMatchingPolicy(routeID uint64) *config.Policy { - options := a.currentOptions.Load() + options := a.currentConfig.Load().Options for p := range options.GetAllPolicies() { id, _ := p.RouteID() diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 8d239e4c7..bb3603329 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -18,7 +18,6 @@ import ( "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" @@ -49,15 +48,16 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` func Test_getEvaluatorRequest(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ - Policies: []config.Policy{{ - From: "https://example.com", - SubPolicies: []config.SubPolicy{{ - Rego: []string{"allow = true"}, + a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{ + Options: &config.Options{ + Policies: []config.Policy{{ + From: "https://example.com", + SubPolicies: []config.SubPolicy{{ + Rego: []string{"allow = true"}, + }}, }}, - }}, - }) + }, + }), state: atomicutil.NewValue(new(authorizeState))} actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), &envoy_service_auth_v3.CheckRequest{ @@ -88,13 +88,10 @@ func Test_getEvaluatorRequest(t *testing.T) { }, }, }, - &sessions.State{ - ID: "SESSION_ID", - }, ) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &a.currentConfig.Load().Options.Policies[0], Session: evaluator.RequestSession{ ID: "SESSION_ID", }, @@ -117,15 +114,16 @@ func Test_getEvaluatorRequest(t *testing.T) { } func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ - Policies: []config.Policy{{ - From: "https://example.com", - SubPolicies: []config.SubPolicy{{ - Rego: []string{"allow = true"}, + a := &Authorize{currentConfig: atomicutil.NewValue(&config.Config{ + Options: &config.Options{ + Policies: []config.Policy{{ + From: "https://example.com", + SubPolicies: []config.SubPolicy{{ + Rego: []string{"allow = true"}, + }}, }}, - }}, - }) + }, + }), state: atomicutil.NewValue(new(authorizeState))} actual, err := a.getEvaluatorRequestFromCheckRequest(context.Background(), &envoy_service_auth_v3.CheckRequest{ @@ -145,10 +143,10 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { }, }, }, - }, nil) + }) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &a.currentConfig.Load().Options.Policies[0], Session: evaluator.RequestSession{}, HTTP: evaluator.NewRequestHTTP( http.MethodGet, diff --git a/authorize/log.go b/authorize/log.go index 3f1a571ec..c59cf38ce 100644 --- a/authorize/log.go +++ b/authorize/log.go @@ -31,7 +31,7 @@ func (a *Authorize) logAuthorizeCheck( impersonateDetails := a.getImpersonateDetails(ctx, s) evt := log.Ctx(ctx).Info().Str("service", "authorize") - fields := a.currentOptions.Load().GetAuthorizeLogFields() + fields := a.currentConfig.Load().Options.GetAuthorizeLogFields() for _, field := range fields { evt = populateLogEvent(ctx, field, evt, in, s, u, hdrs, impersonateDetails, res) } diff --git a/config/bearer_token_format.go b/config/bearer_token_format.go index d55ddfc73..7d0723bce 100644 --- a/config/bearer_token_format.go +++ b/config/bearer_token_format.go @@ -67,7 +67,7 @@ func (bearerTokenFormat *BearerTokenFormat) ToPB() *configpb.BearerTokenFormat { case BearerTokenFormatIDPIdentityToken: return configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_IDP_IDENTITY_TOKEN.Enum() default: - panic(fmt.Sprintf("unknown bearer token format: %s", bearerTokenFormat)) + panic(fmt.Sprintf("unknown bearer token format: %v", bearerTokenFormat)) } } diff --git a/config/constants.go b/config/constants.go index 32307b52f..0cb260c29 100644 --- a/config/constants.go +++ b/config/constants.go @@ -4,9 +4,10 @@ import ( "errors" "github.com/mitchellh/mapstructure" - "github.com/pomerium/pomerium/config/otelconfig" "github.com/spf13/viper" "google.golang.org/protobuf/encoding/protojson" + + "github.com/pomerium/pomerium/config/otelconfig" ) const ( @@ -37,6 +38,7 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( DecodePolicyBase64Hook(), decodeNullBoolHookFunc(), decodeJWTClaimHeadersHookFunc(), + decodeBearerTokenFormatHookFunc(), decodeCodecTypeHookFunc(), decodePPLPolicyHookFunc(), decodeSANMatcherHookFunc(), diff --git a/config/session.go b/config/session.go index 1fd37b22e..595a05d9b 100644 --- a/config/session.go +++ b/config/session.go @@ -5,15 +5,28 @@ import ( "fmt" "net/http" "strings" + "time" + + "github.com/google/uuid" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/jwtutil" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions/cookie" "github.com/pomerium/pomerium/internal/sessions/header" "github.com/pomerium/pomerium/internal/sessions/queryparam" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/authenticateapi" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/grpcutil" + "github.com/pomerium/pomerium/pkg/identity" + "github.com/pomerium/pomerium/pkg/protoutil" + "github.com/pomerium/pomerium/pkg/storage" ) // A SessionStore saves and loads sessions based on the options. @@ -116,82 +129,253 @@ func (store *SessionStore) SaveSession(w http.ResponseWriter, r *http.Request, v return store.store.SaveSession(w, r, v) } -// An IDPTokenSessionHandler handles incoming idp access and identity tokens. -type IDPTokenSessionHandler struct { - options *Options - getSession func(ctx context.Context, id string) (*session.Session, error) - putSession func(ctx context.Context, s *session.Session) error +var ( + accessTokenUUIDNamespace = uuid.MustParse("0194f6f8-e760-76a0-8917-e28ac927a34d") + identityTokenUUIDNamespace = uuid.MustParse("0194f6f9-aec0-704e-bb4a-51054f17ad17") +) + +type IncomingIDPTokenSessionCreator interface { + CreateSession(ctx context.Context, cfg *Config, policy *Policy, r *http.Request) (*session.Session, error) } -// NewIDPTokenSessionHandler creates a new IDPTokenSessionHandler. -func NewIDPTokenSessionHandler( - options *Options, - getSession func(ctx context.Context, id string) (*session.Session, error), - putSession func(ctx context.Context, s *session.Session) error, -) *IDPTokenSessionHandler { - return &IDPTokenSessionHandler{ - options: options, - getSession: getSession, - putSession: putSession, +type incomingIDPTokenSessionCreator struct { + getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error) + putRecords func(ctx context.Context, records []*databroker.Record) error +} + +func NewIncomingIDPTokenSessionCreator( + getRecord func(ctx context.Context, recordType, recordID string) (*databroker.Record, error), + putRecords func(ctx context.Context, records []*databroker.Record) error, +) IncomingIDPTokenSessionCreator { + return &incomingIDPTokenSessionCreator{getRecord: getRecord, putRecords: putRecords} +} + +// CreateSession attempts to create a session for incoming idp access and +// identity tokens. If no access or identity token is passed ErrNoSessionFound will be returned. +// If the tokens are not valid an error will be returned. +func (c *incomingIDPTokenSessionCreator) CreateSession( + ctx context.Context, + cfg *Config, + policy *Policy, + r *http.Request, +) (session *session.Session, err error) { + if rawAccessToken, ok := cfg.GetIncomingIDPAccessTokenForPolicy(policy, r); ok { + return c.createSessionAccessToken(ctx, cfg, policy, rawAccessToken) } + + if rawIdentityToken, ok := cfg.GetIncomingIDPIdentityTokenForPolicy(policy, r); ok { + return c.createSessionForIdentityToken(ctx, cfg, policy, rawIdentityToken) + } + + return nil, sessions.ErrNoSessionFound } -// // CreateSessionForIncomingIDPToken creates a session from an incoming idp access or identity token. -// // If no such tokens are found or they are invalid ErrNoSessionFound will be returned. -// func (h *IDPTokenSessionHandler) CreateSessionForIncomingIDPToken(r *http.Request) (*session.Session, error) { -// idp, err := h.options.GetIdentityProviderForRequestURL(urlutil.GetAbsoluteURL(r).String()) -// if err != nil { -// return nil, err -// } +func (c *incomingIDPTokenSessionCreator) createSessionAccessToken( + ctx context.Context, + cfg *Config, + policy *Policy, + rawAccessToken string, +) (*session.Session, error) { + sessionID := uuid.NewSHA1(accessTokenUUIDNamespace, []byte(rawAccessToken)).String() + s, err := c.getSession(ctx, sessionID) + if err == nil { + return s, nil + } else if !storage.IsNotFound(err) { + return nil, err + } -// return nil, sessions.ErrNoSessionFound -// } + idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) + if err != nil { + return nil, fmt.Errorf("error getting identity provider to verify access token: %w", err) + } -// func (h *IDPTokenSessionHandler) getIncomingIDPAccessToken(r *http.Request) (rawAccessToken string, ok bool) { -// if h.options. + authenticateURL, transport, err := cfg.resolveAuthenticateURL() + if err != nil { + return nil, fmt.Errorf("error resolving authenticate url to verify access token: %w", err) + } -// return "", false -// } + res, err := authenticateapi.New(authenticateURL, transport).VerifyAccessToken(ctx, &authenticateapi.VerifyAccessTokenRequest{ + AccessToken: rawAccessToken, + IdentityProviderID: idp.GetId(), + }) + if err != nil { + return nil, fmt.Errorf("error verifying access token: %w", err) + } else if !res.Valid { + return nil, fmt.Errorf("invalid access token") + } -// func (h *IDPTokenSessionHandler) getIncomingIDPIdentityToken(r *http.Request) (rawIdentityToken string, ok bool) { -// return "", false -// } + s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s.OauthToken = &session.OAuthToken{ + TokenType: "Bearer", + AccessToken: rawAccessToken, + ExpiresAt: s.ExpiresAt, + } + u := c.newUserFromIDPClaims(res.Claims) + err = c.putSessionAndUser(ctx, s, u) + if err != nil { + return nil, fmt.Errorf("error saving session and user: %w", err) + } -// func CreateSessionForIncomingIDPToken( -// r *http.Request, -// options *Options, -// policy *Policy, -// getSession func(ctx context.Context, id string) (*session.Session, error), -// putSession func(ctx context.Context, s *session.Session) error)(*session.Session, error) { -// } + return s, nil +} + +func (c *incomingIDPTokenSessionCreator) createSessionForIdentityToken( + ctx context.Context, + cfg *Config, + policy *Policy, + rawIdentityToken string, +) (*session.Session, error) { + sessionID := uuid.NewSHA1(identityTokenUUIDNamespace, []byte(rawIdentityToken)).String() + s, err := c.getSession(ctx, sessionID) + if err == nil { + return s, nil + } else if !storage.IsNotFound(err) { + return nil, err + } + + idp, err := cfg.Options.GetIdentityProviderForPolicy(policy) + if err != nil { + return nil, fmt.Errorf("error getting identity provider to verify identity token: %w", err) + } + + authenticateURL, transport, err := cfg.resolveAuthenticateURL() + if err != nil { + return nil, fmt.Errorf("error resolving authenticate url to verify identity token: %w", err) + } + + res, err := authenticateapi.New(authenticateURL, transport).VerifyIdentityToken(ctx, &authenticateapi.VerifyIdentityTokenRequest{ + IdentityToken: rawIdentityToken, + IdentityProviderID: idp.GetId(), + }) + if err != nil { + return nil, fmt.Errorf("error verifying identity token: %w", err) + } else if !res.Valid { + return nil, fmt.Errorf("invalid identity token") + } + + s = c.newSessionFromIDPClaims(cfg, sessionID, res.Claims) + s.SetRawIDToken(rawIdentityToken) + u := c.newUserFromIDPClaims(res.Claims) + err = c.putSessionAndUser(ctx, s, u) + if err != nil { + return nil, fmt.Errorf("error saving session and user: %w", err) + } + + return s, nil +} + +func (c *incomingIDPTokenSessionCreator) newSessionFromIDPClaims( + cfg *Config, + sessionID string, + claims jwtutil.Claims, +) *session.Session { + now := time.Now() + s := new(session.Session) + s.Id = sessionID + if userID, ok := claims.GetUserID(); ok { + s.UserId = userID + } + if issuedAt, ok := claims.GetIssuedAt(); ok { + s.IssuedAt = timestamppb.New(issuedAt) + } else { + s.IssuedAt = timestamppb.New(now) + } + if expiresAt, ok := claims.GetExpirationTime(); ok { + s.ExpiresAt = timestamppb.New(expiresAt) + } else { + s.ExpiresAt = timestamppb.New(now.Add(cfg.Options.CookieExpire)) + } + s.AccessedAt = timestamppb.New(now) + s.AddClaims(identity.Claims(claims).Flatten()) + if aud, ok := claims.GetAudience(); ok { + s.Audience = aud + } + return s +} + +func (c *incomingIDPTokenSessionCreator) newUserFromIDPClaims( + claims jwtutil.Claims, +) *user.User { + u := new(user.User) + if userID, ok := claims.GetUserID(); ok { + u.Id = userID + } + if name, ok := claims.GetString("name"); ok { + u.Name = name + } + if email, ok := claims.GetString("email"); ok { + u.Email = email + } + u.Claims = identity.Claims(claims).Flatten().ToPB() + return u +} + +func (c *incomingIDPTokenSessionCreator) getSession(ctx context.Context, sessionID string) (*session.Session, error) { + record, err := c.getRecord(ctx, grpcutil.GetTypeURL(new(session.Session)), sessionID) + if err != nil { + return nil, err + } + + msg, err := record.GetData().UnmarshalNew() + if err != nil { + return nil, storage.ErrNotFound + } + + s, ok := msg.(*session.Session) + if !ok { + return nil, storage.ErrNotFound + } + + return s, nil +} + +func (c *incomingIDPTokenSessionCreator) putSessionAndUser(ctx context.Context, s *session.Session, u *user.User) error { + var records []*databroker.Record + if id := s.GetId(); id != "" { + records = append(records, &databroker.Record{ + Type: grpcutil.GetTypeURL(s), + Id: id, + Data: protoutil.NewAny(s), + }) + } + if id := u.GetId(); id != "" { + records = append(records, &databroker.Record{ + Type: grpcutil.GetTypeURL(u), + Id: id, + Data: protoutil.NewAny(u), + }) + } + return c.putRecords(ctx, records) +} // GetIncomingIDPAccessTokenForPolicy returns the raw idp access token from a request if there is one. -func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) { +func (cfg *Config) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *http.Request) (rawAccessToken string, ok bool) { bearerTokenFormat := BearerTokenFormatDefault - if options != nil && options.BearerTokenFormat != nil { - bearerTokenFormat = *options.BearerTokenFormat + if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil { + bearerTokenFormat = *cfg.Options.BearerTokenFormat } if policy != nil && policy.BearerTokenFormat != nil { bearerTokenFormat = *policy.BearerTokenFormat } - if token := r.Header.Get("X-Pomerium-IDP-Access-Token"); token != "" { + if token := r.Header.Get(httputil.HeaderPomeriumIDPAccessToken); token != "" { return token, true } - if auth := r.Header.Get("Authorization"); auth != "" { - prefix := "Pomerium-IDP-Access-Token " + if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" { + prefix := httputil.AuthorizationTypePomeriumIDPAccessToken + " " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { return strings.TrimPrefix(auth, prefix), true } - prefix = "Bearer Pomerium-IDP-Access-Token-" + prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPAccessToken + "-" if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { return strings.TrimPrefix(auth, prefix), true } prefix = "Bearer " - if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPAccessToken { + if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && + bearerTokenFormat == BearerTokenFormatIDPAccessToken { return strings.TrimPrefix(auth, prefix), true } } @@ -200,32 +384,33 @@ func (options *Options) GetIncomingIDPAccessTokenForPolicy(policy *Policy, r *ht } // GetIncomingIDPAccessTokenForPolicy returns the raw idp identity token from a request if there is one. -func (options *Options) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) { +func (cfg *Config) GetIncomingIDPIdentityTokenForPolicy(policy *Policy, r *http.Request) (rawIdentityToken string, ok bool) { bearerTokenFormat := BearerTokenFormatDefault - if options != nil && options.BearerTokenFormat != nil { - bearerTokenFormat = *options.BearerTokenFormat + if cfg.Options != nil && cfg.Options.BearerTokenFormat != nil { + bearerTokenFormat = *cfg.Options.BearerTokenFormat } if policy != nil && policy.BearerTokenFormat != nil { bearerTokenFormat = *policy.BearerTokenFormat } - if token := r.Header.Get("X-Pomerium-IDP-Identity-Token"); token != "" { + if token := r.Header.Get(httputil.HeaderPomeriumIDPIdentityToken); token != "" { return token, true } - if auth := r.Header.Get("Authorization"); auth != "" { - prefix := "Pomerium-IDP-Identity-Token " + if auth := r.Header.Get(httputil.HeaderAuthorization); auth != "" { + prefix := httputil.AuthorizationTypePomeriumIDPIdentityToken + " " if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { return strings.TrimPrefix(auth, prefix), true } - prefix = "Bearer Pomerium-IDP-Identity-Token-" + prefix = "Bearer " + httputil.AuthorizationTypePomeriumIDPIdentityToken + "-" if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) { return strings.TrimPrefix(auth, prefix), true } prefix = "Bearer " - if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && bearerTokenFormat == BearerTokenFormatIDPIdentityToken { + if strings.HasPrefix(strings.ToLower(auth), strings.ToLower(prefix)) && + bearerTokenFormat == BearerTokenFormatIDPIdentityToken { return strings.TrimPrefix(auth, prefix), true } } diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go index 355699c7f..0cf60b726 100644 --- a/internal/httputil/headers.go +++ b/internal/httputil/headers.go @@ -1,7 +1,12 @@ package httputil -// AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers -const AuthorizationTypePomerium = "Pomerium" +// Pomerium authorization types +const ( + // AuthorizationTypePomerium is for Authorization: Pomerium JWT... headers + AuthorizationTypePomerium = "Pomerium" + AuthorizationTypePomeriumIDPAccessToken = "Pomerium-IDP-Access-Token" //nolint: gosec + AuthorizationTypePomeriumIDPIdentityToken = "Pomerium-IDP-Identity-Token" //nolint: gosec +) // Standard headers const ( @@ -16,7 +21,9 @@ const ( // HeaderPomeriumAuthorization is the header key for a pomerium authorization JWT. It // can be used in place of the standard authorization header if that header is being // used by upstream applications. - HeaderPomeriumAuthorization = "x-pomerium-authorization" + HeaderPomeriumAuthorization = "x-pomerium-authorization" + HeaderPomeriumIDPAccessToken = "x-pomerium-idp-access-token" //nolint: gosec + HeaderPomeriumIDPIdentityToken = "x-pomerium-idp-identity-token" //nolint: gosec // HeaderPomeriumResponse is set when pomerium itself creates a response, // as opposed to the upstream application and can be used to distinguish // between an application error, and a pomerium related error when debugging. diff --git a/internal/jwtutil/jwtutil.go b/internal/jwtutil/jwtutil.go new file mode 100644 index 000000000..7903dec2f --- /dev/null +++ b/internal/jwtutil/jwtutil.go @@ -0,0 +1,160 @@ +// Package jwtutil contains functions for working with JWTs. +package jwtutil + +import ( + "bytes" + "encoding/json" + "fmt" + "reflect" + "time" +) + +// Claims represent claims in a JWT. +type Claims map[string]any + +// UnmarshalJSON implements a custom unmarshaller for claims data. +func (claims *Claims) UnmarshalJSON(raw []byte) error { + dst := map[string]any{} + dec := json.NewDecoder(bytes.NewReader(raw)) + dec.UseNumber() + err := dec.Decode(&dst) + if err != nil { + return err + } + *claims = Claims(dst) + return nil +} + +// registered claims + +// GetIssuer gets the iss claim. +func (claims Claims) GetIssuer() (issuer string, ok bool) { + return claims.GetString("iss") +} + +// GetSubject gets the sub claim. +func (claims Claims) GetSubject() (subject string, ok bool) { + return claims.GetString("sub") +} + +// GetAudience gets the aud claim. +func (claims Claims) GetAudience() (audiences []string, ok bool) { + return claims.GetStringSlice("aud") +} + +// GetExpirationTime gets the exp claim. +func (claims Claims) GetExpirationTime() (expirationTime time.Time, ok bool) { + return claims.GetNumericDate("exp") +} + +// GetNotBefore gets the nbf claim. +func (claims Claims) GetNotBefore() (notBefore time.Time, ok bool) { + return claims.GetNumericDate("nbf") +} + +// GetIssuedAt gets the iat claim. +func (claims Claims) GetIssuedAt() (issuedAt time.Time, ok bool) { + return claims.GetNumericDate("iat") +} + +// GetJWTID gets the jti claim. +func (claims Claims) GetJWTID() (jwtID string, ok bool) { + return claims.GetString("jti") +} + +// custom claims + +// GetUserID returns the oid or sub claim. +func (claims Claims) GetUserID() (userID string, ok bool) { + if oid, ok := claims.GetString("oid"); ok { + return oid, true + } + + if sub, ok := claims.GetSubject(); ok { + return sub, true + } + + return "", false +} + +// GetNumericDate returns the claim as a numeric date. +func (claims Claims) GetNumericDate(name string) (tm time.Time, ok bool) { + if claims == nil { + return tm, false + } + + raw, ok := claims[name] + if !ok { + return tm, false + } + + switch v := raw.(type) { + case float64: + return time.Unix(int64(v), 0), true + case int64: + return time.Unix(v, 0), true + case json.Number: + i, err := v.Int64() + if err != nil { + if f, err := v.Float64(); err == nil { + i = int64(f) + } + } + if err != nil { + return tm, false + } + return time.Unix(i, 0), true + } + + return tm, false +} + +// GetString returns the claim as a string. +func (claims Claims) GetString(name string) (value string, ok bool) { + if claims == nil { + return value, false + } + + raw, ok := claims[name] + if !ok { + return value, false + } + + return toString(raw), true +} + +// GetStringSlice returns the claim as a slice of strings. +func (claims Claims) GetStringSlice(name string) (values []string, ok bool) { + if claims == nil { + return nil, false + } + + raw, ok := claims[name] + if !ok { + return nil, false + } + + return toStringSlice(raw), true +} + +func toString(data any) string { + switch v := data.(type) { + case string: + return v + } + return fmt.Sprint(data) +} + +func toStringSlice(obj any) []string { + v := reflect.ValueOf(obj) + switch v.Kind() { + case reflect.Slice: + vs := make([]string, v.Len()) + for i := 0; i < v.Len(); i++ { + vs[i] = toString(v.Index(i).Interface()) + } + return vs + } + + return []string{toString(obj)} +} diff --git a/pkg/authenticateapi/authenticateapi.go b/pkg/authenticateapi/authenticateapi.go new file mode 100644 index 000000000..deef9f822 --- /dev/null +++ b/pkg/authenticateapi/authenticateapi.go @@ -0,0 +1,109 @@ +// Package authenticateapi has the types and methods for the authenticate api. +package authenticateapi + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + + "github.com/pomerium/pomerium/internal/jwtutil" +) + +// VerifyAccessTokenRequest is used to verify access tokens. +type VerifyAccessTokenRequest struct { + AccessToken string `json:"accessToken"` + IdentityProviderID string `json:"identityProviderId,omitempty"` +} + +// VerifyIdentityTokenRequest is used to verify identity tokens. +type VerifyIdentityTokenRequest struct { + IdentityToken string `json:"identityToken"` + IdentityProviderID string `json:"identityProviderId,omitempty"` +} + +// VerifyTokenResponse is the result of verifying an access or identity token. +type VerifyTokenResponse struct { + Valid bool `json:"valid"` + Claims jwtutil.Claims `json:"claims,omitempty"` +} + +// An API is an api client for the authenticate service. +type API struct { + authenticateURL *url.URL + transport http.RoundTripper +} + +// New creates a new API client. +func New( + authenticateURL *url.URL, + transport http.RoundTripper, +) *API { + return &API{ + authenticateURL: authenticateURL, + transport: transport, + } +} + +// VerifyAccessToken verifies an access token. +func (api *API) VerifyAccessToken(ctx context.Context, request *VerifyAccessTokenRequest) (*VerifyTokenResponse, error) { + var response VerifyTokenResponse + err := api.call(ctx, "verify-access-token", request, &response) + if err != nil { + return nil, err + } + return &response, nil +} + +// VerifyIdentityToken verifies an identity token. +func (api *API) VerifyIdentityToken(ctx context.Context, request *VerifyIdentityTokenRequest) (*VerifyTokenResponse, error) { + var response VerifyTokenResponse + err := api.call(ctx, "verify-identity-token", request, &response) + if err != nil { + return nil, err + } + return &response, nil +} + +func (api *API) call( + ctx context.Context, + endpoint string, + request, response any, +) error { + u := api.authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/" + endpoint, + }) + + body, err := json.Marshal(request) + if err != nil { + return fmt.Errorf("error marshaling %s http request: %w", endpoint, err) + } + + req, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), bytes.NewReader(body)) + if err != nil { + return fmt.Errorf("error creating %s http request: %w", endpoint, err) + } + + res, err := (&http.Client{ + Transport: api.transport, + }).Do(req) + if err != nil { + return fmt.Errorf("error executing %s http request: %w", endpoint, err) + } + defer res.Body.Close() + + body, err = io.ReadAll(res.Body) + if err != nil { + return fmt.Errorf("error reading %s http response: %w", endpoint, err) + } + + err = json.Unmarshal(body, &response) + if err != nil { + return fmt.Errorf("error reading %s http response (body=%s): %w", endpoint, body, err) + } + + return nil +} diff --git a/pkg/identity/oidc/azure/microsoft.go b/pkg/identity/oidc/azure/microsoft.go index 4b98fee56..f5339a895 100644 --- a/pkg/identity/oidc/azure/microsoft.go +++ b/pkg/identity/oidc/azure/microsoft.go @@ -17,6 +17,7 @@ import ( "github.com/google/uuid" "golang.org/x/oauth2" + "github.com/pomerium/pomerium/internal/jwtutil" "github.com/pomerium/pomerium/pkg/identity/oauth" pom_oidc "github.com/pomerium/pomerium/pkg/identity/oidc" ) @@ -95,7 +96,7 @@ func (p *Provider) VerifyAccessToken(ctx context.Context, rawAccessToken string) return nil, fmt.Errorf("error verifying access token: %w", err) } - claims = map[string]any{} + claims = jwtutil.Claims(map[string]any{}) err = token.Claims(&claims) if err != nil { return nil, fmt.Errorf("error unmarshaling access token claims: %w", err)