authenticate: always trust the passed in idp (#3931)

authenticate: always trust the passed in idp (#3917)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
backport-actions-token[bot] 2023-01-30 19:06:35 -07:00 committed by GitHub
parent cc475a3985
commit 3ba74b38ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 43 deletions

View file

@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
defer span.End() defer span.End()
state := a.state.Load() state := a.state.Load()
options := a.options.Load() idpID := r.FormValue(urlutil.QueryIdentityProviderID)
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
sessionState, err := a.getSessionFromCtx(ctx) sessionState, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Err(err). Err(err).
Str("idp_id", idp.GetId()). Str("idp_id", idpID).
Msg("authenticate: session load error") Msg("authenticate: session load error")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
if sessionState.IdentityProviderID != idp.GetId() { if sessionState.IdentityProviderID != idpID {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Str("idp_id", idp.GetId()). Str("idp_id", idpID).
Str("session_idp_id", sessionState.IdentityProviderID). Str("session_idp_id", sessionState.IdentityProviderID).
Str("id", sessionState.ID). Str("id", sessionState.ID).
Msg("authenticate: session not associated with identity provider") Msg("authenticate: session not associated with identity provider")
@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
if err != nil { if err != nil {
log.FromRequest(r).Info(). log.FromRequest(r).Info().
Err(err). Err(err).
Str("idp_id", idp.GetId()). Str("idp_id", idpID).
Msg("authenticate: identity profile load error") Msg("authenticate: identity profile load error")
return a.reauthenticateOrFail(w, r, err) return a.reauthenticateOrFail(w, r, err)
} }
@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
defer span.End() defer span.End()
state := a.state.Load() state := a.state.Load()
options := a.options.Load()
if err := r.ParseForm(); err != nil { if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return err return err
} }
idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID)) idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
if err != nil {
return err
}
s, err := a.getSessionFromCtx(ctx) s, err := a.getSessionFromCtx(ctx)
if err != nil { if err != nil {
@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
} }
// start over if this is a different identity provider // start over if this is a different identity provider
if s == nil || s.IdentityProviderID != idp.GetId() { if s == nil || s.IdentityProviderID != idpID {
s = sessions.NewState(idp.GetId()) s = sessions.NewState(idpID)
} }
// re-persist the session, useful when session was evicted from session // re-persist the session, useful when session was evicted from session
@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
defer span.End() defer span.End()
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) idpID := r.FormValue(urlutil.QueryIdentityProviderID)
if err != nil {
return err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil { if err != nil {
return err return err
} }
@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
state := a.state.Load() state := a.state.Load()
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) idpID := r.FormValue(urlutil.QueryIdentityProviderID)
if err != nil {
return err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil { if err != nil {
return err return err
} }
@ -407,12 +393,9 @@ Or contact your administrator.
`, redirectURL.String(), redirectURL.String())) `, redirectURL.String(), redirectURL.String()))
} }
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID)) idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
if err != nil {
return nil, err
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -426,7 +409,7 @@ Or contact your administrator.
return nil, fmt.Errorf("error redeeming authenticate code: %w", err) return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
} }
s := sessions.NewState(idp.GetId()) s := sessions.NewState(idpID)
err = claims.Claims.Claims(&s) err = claims.Claims.Claims(&s)
if err != nil { if err != nil {
return nil, fmt.Errorf("error unmarshaling session state: %w", err) return nil, fmt.Errorf("error unmarshaling session state: %w", err)
@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
// clear the user's local session no matter what // clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r) defer state.sessionStore.ClearSession(w, r)
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) idpID := r.FormValue(urlutil.QueryIdentityProviderID)
if err != nil {
return ""
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil { if err != nil {
return "" return ""
} }

View file

@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile(
oauthToken *oauth2.Token, oauthToken *oauth2.Token,
) (*identitypb.Profile, error) { ) (*identitypb.Profile, error) {
options := a.options.Load() options := a.options.Load()
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID)) idpID := r.FormValue(urlutil.QueryIdentityProviderID)
if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
}
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId()) authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil { if err != nil {
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err) return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
} }
@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile(
} }
return &identitypb.Profile{ return &identitypb.Profile{
ProviderId: idp.GetId(), ProviderId: idpID,
IdToken: rawIDToken, IdToken: rawIDToken,
OauthToken: rawOAuthToken, OauthToken: rawOAuthToken,
Claims: rawClaims, Claims: rawClaims,