mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 21:17:13 +02:00
authenticate: always trust the passed in idp (#3931)
authenticate: always trust the passed in idp (#3917) Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
parent
cc475a3985
commit
3ba74b38ae
2 changed files with 20 additions and 43 deletions
|
@ -118,24 +118,20 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
sessionState, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idp.GetId()).
|
||||
Str("idp_id", idpID).
|
||||
Msg("authenticate: session load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
if sessionState.IdentityProviderID != idp.GetId() {
|
||||
if sessionState.IdentityProviderID != idpID {
|
||||
log.FromRequest(r).Info().
|
||||
Str("idp_id", idp.GetId()).
|
||||
Str("idp_id", idpID).
|
||||
Str("session_idp_id", sessionState.IdentityProviderID).
|
||||
Str("id", sessionState.ID).
|
||||
Msg("authenticate: session not associated with identity provider")
|
||||
|
@ -146,7 +142,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
if err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idp.GetId()).
|
||||
Str("idp_id", idpID).
|
||||
Msg("authenticate: identity profile load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
@ -169,7 +165,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
|
@ -179,10 +174,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
return err
|
||||
}
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(requestParams.Get(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
|
@ -191,8 +183,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
// start over if this is a different identity provider
|
||||
if s == nil || s.IdentityProviderID != idp.GetId() {
|
||||
s = sessions.NewState(idp.GetId())
|
||||
if s == nil || s.IdentityProviderID != idpID {
|
||||
s = sessions.NewState(idpID)
|
||||
}
|
||||
|
||||
// re-persist the session, useful when session was evicted from session
|
||||
|
@ -240,12 +232,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
defer span.End()
|
||||
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -300,12 +289,9 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -407,12 +393,9 @@ Or contact your administrator.
|
|||
`, redirectURL.String(), redirectURL.String()))
|
||||
}
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -426,7 +409,7 @@ Or contact your administrator.
|
|||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||
}
|
||||
|
||||
s := sessions.NewState(idp.GetId())
|
||||
s := sessions.NewState(idpID)
|
||||
err = claims.Claims.Claims(&s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
|
||||
|
@ -522,12 +505,9 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
|
|||
// clear the user's local session no matter what
|
||||
defer state.sessionStore.ClearSession(w, r)
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -31,12 +31,9 @@ func (a *Authenticate) buildIdentityProfile(
|
|||
oauthToken *oauth2.Token,
|
||||
) (*identitypb.Profile, error) {
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error getting identity provider for id: %w", err)
|
||||
}
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error getting identity provider authenticator: %w", err)
|
||||
}
|
||||
|
@ -57,7 +54,7 @@ func (a *Authenticate) buildIdentityProfile(
|
|||
}
|
||||
|
||||
return &identitypb.Profile{
|
||||
ProviderId: idp.GetId(),
|
||||
ProviderId: idpID,
|
||||
IdToken: rawIDToken,
|
||||
OauthToken: rawOAuthToken,
|
||||
Claims: rawClaims,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue