From c0a8870717ffc02f3083cdb191faa43c74077288 Mon Sep 17 00:00:00 2001 From: "backport-actions-token[bot]" <87506591+backport-actions-token[bot]@users.noreply.github.com> Date: Wed, 7 Sep 2022 10:16:40 -0600 Subject: [PATCH] authenticate: get/set identity provider id for all sessions (#3608) authenticate: get/set identity provider id for all sessions (#3597) Co-authored-by: Caleb Doxsey --- authenticate/handlers.go | 81 ++++++++++++++++++++++++----------- authenticate/handlers_test.go | 14 +++--- 2 files changed, 64 insertions(+), 31 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 0102107b5..6e9420826 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { defer span.End() state := a.state.Load() - idpID := r.FormValue(urlutil.QueryIdentityProviderID) + options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } sessionState, err := a.getSessionFromCtx(ctx) if err != nil { log.FromRequest(r).Info(). Err(err). - Str("idp_id", idpID). + Str("idp_id", idp.GetId()). Msg("authenticate: session load error") return a.reauthenticateOrFail(w, r, err) } - if sessionState.IdentityProviderID != idpID { + if sessionState.IdentityProviderID != idp.GetId() { log.FromRequest(r).Info(). - Str("idp_id", idpID). + Str("idp_id", idp.GetId()). + Str("session_idp_id", sessionState.IdentityProviderID). Str("id", sessionState.ID). Msg("authenticate: session not associated with identity provider") return a.reauthenticateOrFail(w, r, err) @@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler { if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { log.FromRequest(r).Info(). Err(err). - Str("idp_id", idpID). + Str("idp_id", idp.GetId()). Str("id", sessionState.ID). Msg("authenticate: session not found in databroker") return a.reauthenticateOrFail(w, r, err) @@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { defer span.End() state := a.state.Load() + options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { @@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { } // start over if this is a different identity provider - if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) { - s = sessions.NewState(urlutil.QueryIdentityProviderID) + if s == nil || s.IdentityProviderID != idp.GetId() { + s = sessions.NewState(idp.GetId()) } newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience) @@ -276,8 +286,12 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e defer span.End() options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) if err != nil { return err } @@ -285,7 +299,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e rawIDToken := a.revokeSession(ctx, w, r) redirectString := "" - signOutURL, err := a.options.Load().GetSignOutRedirectURL() + signOutURL, err := options.GetSignOutRedirectURL() if err != nil { return err } @@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e redirectString = uri } - endSessionURL, err := idp.LogOut() + endSessionURL, err := authenticator.LogOut() if err == nil && redirectString != "" { params := url.Values{} params.Add("id_token_hint", rawIDToken) params.Add("post_logout_redirect_uri", redirectString) endSessionURL.RawQuery = params.Encode() redirectString = endSessionURL.String() - } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { + } else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) { log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session") } if redirectString != "" { @@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque return httputil.NewError(http.StatusUnauthorized, err) } - options := a.options.Load() state := a.state.Load() + options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) if err != nil { return err } @@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - signinURL, err := idp.GetSignInURL(encodedState) + signinURL, err := authenticator.GetSignInURL(encodedState) if err != nil { return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("failed to get sign in url: %w", err)) @@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") defer span.End() - options := a.options.Load() state := a.state.Load() + options := a.options.Load() // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // @@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) if err != nil { return nil, httputil.NewError(http.StatusBadRequest, err) } - idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) - idp, err := a.cfg.getIdentityProvider(options, idpID) + idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID)) + if err != nil { + return nil, err + } + + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) if err != nil { return nil, err } @@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) // // Exchange the supplied Authorization Code for a valid user session. var claims identity.SessionClaims - accessToken, err := idp.Authenticate(ctx, code, &claims) + accessToken, err := authenticator.Authenticate(ctx, code, &claims) if err != nil { return nil, fmt.Errorf("error redeeming authenticate code: %w", err) } - s := sessions.NewState(idpID) + s := sessions.NewState(idp.GetId()) err = claims.Claims.Claims(&s) if err != nil { return nil, fmt.Errorf("error unmarshaling session state: %w", err) @@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker( ) error { state := a.state.Load() options := a.options.Load() + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return err + } - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) if err != nil { return err } @@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker( s := &session.Session{ Id: sessionState.ID, - UserId: sessionState.UserID(idp.Name()), + UserId: sessionState.UserID(authenticator.Name()), IssuedAt: timestamppb.Now(), AccessedAt: timestamppb.Now(), ExpiresAt: sessionExpiry, @@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker( Id: s.GetUserId(), } } - err = idp.UpdateUserInfo(ctx, accessToken, &managerUser) + err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser) if err != nil { return fmt.Errorf("authenticate: error retrieving user info: %w", err) } @@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker( // databroker. If successful, it returns the original `id_token` of the session, if failed, returns // and empty string. func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { - options := a.options.Load() state := a.state.Load() + options := a.options.Load() // clear the user's local session no matter what defer state.sessionStore.ClearSession(w, r) - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) + if err != nil { + return "" + } + + authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) if err != nil { return "" } @@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { rawIDToken = s.GetIdToken().GetRaw() - if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { + if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") } } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 54a8c7795..978e0a5c0 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { w.WriteHeader(http.StatusOK) }) + idp, _ := new(config.Options).GetIdentityProviderForID("") + tests := []struct { name string headers map[string]string @@ -491,7 +493,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "good", nil, - &mstore.Store{Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK, @@ -499,7 +501,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "invalid session", nil, - &mstore.Store{Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound, @@ -507,7 +509,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "good refresh expired", nil, - &mstore.Store{Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, nil, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusOK, @@ -515,7 +517,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "expired,refresh error", nil, - &mstore.Store{Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound, @@ -523,7 +525,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "expired,save error", nil, - &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, http.StatusFound, @@ -531,7 +533,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { { "expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, - &mstore.Store{Session: &sessions.State{ID: "xyz"}}, + &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized,