authenticate: get/set identity provider id for all sessions (#3597)

This commit is contained in:
Caleb Doxsey 2022-09-07 10:06:59 -06:00 committed by GitHub
parent 8d7db85737
commit bdd6145e91
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 64 additions and 31 deletions

View file

@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End() defer span.End()
state := a.state.Load() state := a.state.Load()
idpID := r.FormValue(urlutil.QueryIdentityProviderID) options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Err(err). Err(err).
Str("idp_id", idpID). Str("idp_id", idp.GetId()).
Msg("authenticate: session load error") Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
if sessionState.IdentityProviderID != idpID { if sessionState.IdentityProviderID != idp.GetId() {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Str("idp_id", idpID). Str("idp_id", idp.GetId()).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID). Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider") Msg("authenticate: session not associated with identity provider")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil { if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Err(err). Err(err).
Str("idp_id", idpID). Str("idp_id", idp.GetId()).
Str("id", sessionState.ID). Str("id", sessionState.ID).
Msg("authenticate: session not found in databroker") Msg("authenticate: session not found in databroker")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End() defer span.End()
state := a.state.Load() state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
} }
// start over if this is a different identity provider // start over if this is a different identity provider
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) { if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(urlutil.QueryIdentityProviderID) s = sessions.NewState(idp.GetId())
} }
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience) newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
@ -276,8 +286,12 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End() defer span.End()
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil { if err != nil {
return err return err
} }
@ -285,7 +299,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r) rawIDToken := a.revokeSession(ctx, w, r)
redirectString := "" redirectString := ""
signOutURL, err := a.options.Load().GetSignOutRedirectURL() signOutURL, err := options.GetSignOutRedirectURL()
if err != nil { if err != nil {
return err return err
} }
@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
redirectString = uri redirectString = uri
} }
endSessionURL, err := idp.LogOut() endSessionURL, err := authenticator.LogOut()
if err == nil && redirectString != "" { if err == nil && redirectString != "" {
params := url.Values{} params := url.Values{}
params.Add("id_token_hint", rawIDToken) params.Add("id_token_hint", rawIDToken)
params.Add("post_logout_redirect_uri", redirectString) params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode() endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String() redirectString = endSessionURL.String()
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { } else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session") log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
} }
if redirectString != "" { if redirectString != "" {
@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return httputil.NewError(http.StatusUnauthorized, err) return httputil.NewError(http.StatusUnauthorized, err)
} }
options := a.options.Load()
state := a.state.Load() state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil { if err != nil {
return err return err
} }
@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := idp.GetSignInURL(encodedState) signinURL, err := authenticator.GetSignInURL(encodedState)
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err)) fmt.Errorf("failed to get sign in url: %w", err))
@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End() defer span.End()
options := a.options.Load()
state := a.state.Load() state := a.state.Load()
options := a.options.Load()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
// //
@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
if err != nil { if err != nil {
return nil, httputil.NewError(http.StatusBadRequest, err) return nil, httputil.NewError(http.StatusBadRequest, err)
} }
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
idp, err := a.cfg.getIdentityProvider(options, idpID) idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
// //
// Exchange the supplied Authorization Code for a valid user session. // Exchange the supplied Authorization Code for a valid user session.
var claims identity.SessionClaims var claims identity.SessionClaims
accessToken, err := idp.Authenticate(ctx, code, &claims) accessToken, err := authenticator.Authenticate(ctx, code, &claims)
if err != nil { if err != nil {
return nil, fmt.Errorf("error redeeming authenticate code: %w", err) return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
} }
s := sessions.NewState(idpID) s := sessions.NewState(idp.GetId())
err = claims.Claims.Claims(&s) err = claims.Claims.Claims(&s)
if err != nil { if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err) return nil, fmt.Errorf("error unmarshaling session state: %w", err)
@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker(
) error { ) error {
state := a.state.Load() state := a.state.Load()
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil { if err != nil {
return err return err
} }
@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker(
s := &session.Session{ s := &session.Session{
Id: sessionState.ID, Id: sessionState.ID,
UserId: sessionState.UserID(idp.Name()), UserId: sessionState.UserID(authenticator.Name()),
IssuedAt: timestamppb.Now(), IssuedAt: timestamppb.Now(),
AccessedAt: timestamppb.Now(), AccessedAt: timestamppb.Now(),
ExpiresAt: sessionExpiry, ExpiresAt: sessionExpiry,
@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker(
Id: s.GetUserId(), Id: s.GetUserId(),
} }
} }
err = idp.UpdateUserInfo(ctx, accessToken, &managerUser) err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
if err != nil { if err != nil {
return fmt.Errorf("authenticate: error retrieving user info: %w", err) return fmt.Errorf("authenticate: error retrieving user info: %w", err)
} }
@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns // databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string. // and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load()
state := a.state.Load() state := a.state.Load()
options := a.options.Load()
// clear the user's local session no matter what // clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r) defer state.sessionStore.ClearSession(w, r)
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
if err != nil { if err != nil {
return "" return ""
} }
@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil { if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
rawIDToken = s.GetIdToken().GetRaw() rawIDToken = s.GetIdToken().GetRaw()
if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil { if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token") log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
} }
} }

View file

@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
idp, _ := new(config.Options).GetIdentityProviderForID("")
tests := []struct { tests := []struct {
name string name string
headers map[string]string headers map[string]string
@ -491,7 +493,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"good", "good",
nil, nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil, nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK, http.StatusOK,
@ -499,7 +501,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"invalid session", "invalid session",
nil, nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
errors.New("hi"), errors.New("hi"),
identity.MockProvider{}, identity.MockProvider{},
http.StatusFound, http.StatusFound,
@ -507,7 +509,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"good refresh expired", "good refresh expired",
nil, nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
nil, nil,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusOK, http.StatusOK,
@ -515,7 +517,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"expired,refresh error", "expired,refresh error",
nil, nil,
&mstore.Store{Session: &sessions.State{ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired, sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")}, identity.MockProvider{RefreshError: errors.New("error")},
http.StatusFound, http.StatusFound,
@ -523,7 +525,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"expired,save error", "expired,save error",
nil, nil,
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}}, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired, sessions.ErrExpired,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusFound, http.StatusFound,
@ -531,7 +533,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
{ {
"expired XHR,refresh error", "expired XHR,refresh error",
map[string]string{"X-Requested-With": "XmlHttpRequest"}, map[string]string{"X-Requested-With": "XmlHttpRequest"},
&mstore.Store{Session: &sessions.State{ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired, sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")}, identity.MockProvider{RefreshError: errors.New("error")},
http.StatusUnauthorized, http.StatusUnauthorized,