authenticate: always trust the passed in idp (#3931)

authenticate: always trust the passed in idp (#3917)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
backport-actions-token[bot] 2023-01-30 19:06:35 -07:00 committed by GitHub
parent cc475a3985
commit 3ba74b38ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 43 deletions

View file

@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End()
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
sessionState, err := a.getSessionFromCtx(ctx)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err)
}
if sessionState.IdentityProviderID != idp.GetId() {
if sessionState.IdentityProviderID != idpID {
log.FromRequest(r).Info().
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider")
@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idp.GetId()).
Str("idp_id", idpID).
Msg("authenticate: identity profile load error")
return a.reauthenticateOrFail(w, r, err)
}
@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End()
state := a.state.Load()
options := a.options.Load()
if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return err
}
idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
s, err := a.getSessionFromCtx(ctx)
if err != nil {
@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
}
// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != idp.GetId() {
s = sessions.NewState(idp.GetId())
if s == nil || s.IdentityProviderID != idpID {
s = sessions.NewState(idpID)
}
// re-persist the session, useful when session was evicted from session
@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
state := a.state.Load()
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return err
}
@ -407,12 +393,9 @@ Or contact your administrator.
`, redirectURL.String(), redirectURL.String()))
}
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, err
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, err
}
@ -426,7 +409,7 @@ Or contact your administrator.
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
}
s := sessions.NewState(idp.GetId())
s := sessions.NewState(idpID)
err = claims.Claims.Claims(&s)
if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return ""
}

View file

@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile(
oauthToken *oauth2.Token,
) (*identitypb.Profile, error) {
options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
}
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
}
@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile(
}
return &identitypb.Profile{
ProviderId: idp.GetId(),
ProviderId: idpID,
IdToken: rawIDToken,
OauthToken: rawOAuthToken,
Claims: rawClaims,