mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
authenticate: get/set identity provider id for all sessions (#3597)
This commit is contained in:
parent
8d7db85737
commit
bdd6145e91
2 changed files with 64 additions and 31 deletions
|
@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionState, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idpID).
|
||||
Str("idp_id", idp.GetId()).
|
||||
Msg("authenticate: session load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
if sessionState.IdentityProviderID != idpID {
|
||||
if sessionState.IdentityProviderID != idp.GetId() {
|
||||
log.FromRequest(r).Info().
|
||||
Str("idp_id", idpID).
|
||||
Str("idp_id", idp.GetId()).
|
||||
Str("session_idp_id", sessionState.IdentityProviderID).
|
||||
Str("id", sessionState.ID).
|
||||
Msg("authenticate: session not associated with identity provider")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
|
@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idpID).
|
||||
Str("idp_id", idp.GetId()).
|
||||
Str("id", sessionState.ID).
|
||||
Msg("authenticate: session not found in databroker")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
|
@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
defer span.End()
|
||||
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||
if err != nil {
|
||||
|
@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
}
|
||||
|
||||
// start over if this is a different identity provider
|
||||
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) {
|
||||
s = sessions.NewState(urlutil.QueryIdentityProviderID)
|
||||
if s == nil || s.IdentityProviderID != idp.GetId() {
|
||||
s = sessions.NewState(idp.GetId())
|
||||
}
|
||||
|
||||
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
|
||||
|
@ -276,8 +286,12 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
defer span.End()
|
||||
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -285,7 +299,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
rawIDToken := a.revokeSession(ctx, w, r)
|
||||
|
||||
redirectString := ""
|
||||
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
|
||||
signOutURL, err := options.GetSignOutRedirectURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
redirectString = uri
|
||||
}
|
||||
|
||||
endSessionURL, err := idp.LogOut()
|
||||
endSessionURL, err := authenticator.LogOut()
|
||||
if err == nil && redirectString != "" {
|
||||
params := url.Values{}
|
||||
params.Add("id_token_hint", rawIDToken)
|
||||
params.Add("post_logout_redirect_uri", redirectString)
|
||||
endSessionURL.RawQuery = params.Encode()
|
||||
redirectString = endSessionURL.String()
|
||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||
}
|
||||
if redirectString != "" {
|
||||
|
@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
||||
options := a.options.Load()
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
||||
b = append(b, enc...)
|
||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||
signinURL, err := idp.GetSignInURL(encodedState)
|
||||
signinURL, err := authenticator.GetSignInURL(encodedState)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError,
|
||||
fmt.Errorf("failed to get sign in url: %w", err))
|
||||
|
@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||
defer span.End()
|
||||
|
||||
options := a.options.Load()
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
|
||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||
//
|
||||
|
@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
//
|
||||
// Exchange the supplied Authorization Code for a valid user session.
|
||||
var claims identity.SessionClaims
|
||||
accessToken, err := idp.Authenticate(ctx, code, &claims)
|
||||
accessToken, err := authenticator.Authenticate(ctx, code, &claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||
}
|
||||
|
||||
s := sessions.NewState(idpID)
|
||||
s := sessions.NewState(idp.GetId())
|
||||
err = claims.Claims.Claims(&s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
|
||||
|
@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
) error {
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
|
||||
s := &session.Session{
|
||||
Id: sessionState.ID,
|
||||
UserId: sessionState.UserID(idp.Name()),
|
||||
UserId: sessionState.UserID(authenticator.Name()),
|
||||
IssuedAt: timestamppb.Now(),
|
||||
AccessedAt: timestamppb.Now(),
|
||||
ExpiresAt: sessionExpiry,
|
||||
|
@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
Id: s.GetUserId(),
|
||||
}
|
||||
}
|
||||
err = idp.UpdateUserInfo(ctx, accessToken, &managerUser)
|
||||
err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
||||
}
|
||||
|
@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
||||
// and empty string.
|
||||
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
||||
options := a.options.Load()
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
|
||||
// clear the user's local session no matter what
|
||||
defer state.sessionStore.ClearSession(w, r)
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
|
|||
|
||||
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
|
||||
rawIDToken = s.GetIdToken().GetRaw()
|
||||
if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
|
||||
if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
|
||||
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
||||
}
|
||||
}
|
||||
|
|
|
@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
w.WriteHeader(http.StatusOK)
|
||||
})
|
||||
|
||||
idp, _ := new(config.Options).GetIdentityProviderForID("")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
|
@ -491,7 +493,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"good",
|
||||
nil,
|
||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
nil,
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
http.StatusOK,
|
||||
|
@ -499,7 +501,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"invalid session",
|
||||
nil,
|
||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
errors.New("hi"),
|
||||
identity.MockProvider{},
|
||||
http.StatusFound,
|
||||
|
@ -507,7 +509,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"good refresh expired",
|
||||
nil,
|
||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
nil,
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
http.StatusOK,
|
||||
|
@ -515,7 +517,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"expired,refresh error",
|
||||
nil,
|
||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
sessions.ErrExpired,
|
||||
identity.MockProvider{RefreshError: errors.New("error")},
|
||||
http.StatusFound,
|
||||
|
@ -523,7 +525,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"expired,save error",
|
||||
nil,
|
||||
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
sessions.ErrExpired,
|
||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||
http.StatusFound,
|
||||
|
@ -531,7 +533,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
{
|
||||
"expired XHR,refresh error",
|
||||
map[string]string{"X-Requested-With": "XmlHttpRequest"},
|
||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||
sessions.ErrExpired,
|
||||
identity.MockProvider{RefreshError: errors.New("error")},
|
||||
http.StatusUnauthorized,
|
||||
|
|
Loading…
Add table
Reference in a new issue