authenticate: get/set identity provider id for all sessions (#3597)

This commit is contained in:
Caleb Doxsey 2022-09-07 10:06:59 -06:00 committed by GitHub
parent 8d7db85737
commit bdd6145e91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 31 deletions

View file

@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End()
state := a.state.Load()
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
sessionState, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}
if sessionState.IdentityProviderID != idpID {
if sessionState.IdentityProviderID != idp.GetId() {
log.FromRequest(r).Info().
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
return a.reauthenticateOrFail(w, r, err)
@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Str("idp_id", idp.GetId()).
Str("id", sessionState.ID).
Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err)
@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End()
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil {
@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}
// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) {
s = sessions.NewState(urlutil.QueryIdentityProviderID)
if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(idp.GetId())
}
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
@ -276,8 +286,12 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
@ -285,7 +299,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r)
redirectString := ""
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
signOutURL, err := options.GetSignOutRedirectURL()
if err != nil {
return err
}
@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
redirectString = uri
}
endSessionURL, err := idp.LogOut()
endSessionURL, err := authenticator.LogOut()
if err == nil && redirectString != "" {
params := url.Values{}
params.Add("id_token_hint", rawIDToken)
params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
}
if redirectString != "" {
@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return httputil.NewError(http.StatusUnauthorized, err)
}
options := a.options.Load()
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := idp.GetSignInURL(encodedState)
signinURL, err := authenticator.GetSignInURL(encodedState)
if err != nil {
return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err))
@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End()
options := a.options.Load()
state := a.state.Load()
options := a.options.Load()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
//
@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err)
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
idp, err := a.cfg.getIdentityProvider(options, idpID)
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return nil, err
}
@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
//
// Exchange the supplied Authorization Code for a valid user session.
var claims identity.SessionClaims
accessToken, err := idp.Authenticate(ctx, code, &claims)
accessToken, err := authenticator.Authenticate(ctx, code, &claims)
if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
s := sessions.NewState(idpID)
s := sessions.NewState(idp.GetId())
err = claims.Claims.Claims(&s)
if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker(
) error {
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return err
}
@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker(
s := &session.Session{
Id: sessionState.ID,
UserId: sessionState.UserID(idp.Name()),
UserId: sessionState.UserID(authenticator.Name()),
IssuedAt: timestamppb.Now(),
AccessedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry,
@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker(
Id: s.GetUserId(),
}
}
err = idp.UpdateUserInfo(ctx, accessToken, &managerUser)
err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
}
@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load()
state := a.state.Load()
options := a.options.Load()
// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil {
return ""
}
@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
rawIDToken = s.GetIdToken().GetRaw()
if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
}
}

View file

@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK)
})
idp, _ := new(config.Options).GetIdentityProviderForID("")
tests := []struct {
name string
headers map[string]string
@ -491,7 +493,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"good",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
@ -499,7 +501,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"invalid session",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
errors.New("hi"),
identity.MockProvider{},
http.StatusFound,
@ -507,7 +509,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"good refresh expired",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK,
@ -515,7 +517,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"expired,refresh error",
nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")},
http.StatusFound,
@ -523,7 +525,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"expired,save error",
nil,
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}},
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusFound,
@ -531,7 +533,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{
"expired XHR,refresh error",
map[string]string{"X-Requested-With": "XmlHttpRequest"},
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")},
http.StatusUnauthorized,