mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-02 20:06:03 +02:00
authenticate: get/set identity provider id for all sessions (#3597)
This commit is contained in:
parent
8d7db85737
commit
bdd6145e91
2 changed files with 64 additions and 31 deletions
|
@ -138,20 +138,25 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
options := a.options.Load()
|
||||||
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
sessionState, err := a.getSessionFromCtx(ctx)
|
sessionState, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Info().
|
log.FromRequest(r).Info().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("idp_id", idpID).
|
Str("idp_id", idp.GetId()).
|
||||||
Msg("authenticate: session load error")
|
Msg("authenticate: session load error")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if sessionState.IdentityProviderID != idpID {
|
if sessionState.IdentityProviderID != idp.GetId() {
|
||||||
log.FromRequest(r).Info().
|
log.FromRequest(r).Info().
|
||||||
Str("idp_id", idpID).
|
Str("idp_id", idp.GetId()).
|
||||||
|
Str("session_idp_id", sessionState.IdentityProviderID).
|
||||||
Str("id", sessionState.ID).
|
Str("id", sessionState.ID).
|
||||||
Msg("authenticate: session not associated with identity provider")
|
Msg("authenticate: session not associated with identity provider")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
|
@ -163,7 +168,7 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
|
if _, err = session.Get(ctx, state.dataBrokerClient, sessionState.ID); err != nil {
|
||||||
log.FromRequest(r).Info().
|
log.FromRequest(r).Info().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("idp_id", idpID).
|
Str("idp_id", idp.GetId()).
|
||||||
Str("id", sessionState.ID).
|
Str("id", sessionState.ID).
|
||||||
Msg("authenticate: session not found in databroker")
|
Msg("authenticate: session not found in databroker")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
|
@ -187,6 +192,11 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
options := a.options.Load()
|
||||||
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -216,8 +226,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// start over if this is a different identity provider
|
// start over if this is a different identity provider
|
||||||
if s == nil || s.IdentityProviderID != r.FormValue(urlutil.QueryIdentityProviderID) {
|
if s == nil || s.IdentityProviderID != idp.GetId() {
|
||||||
s = sessions.NewState(urlutil.QueryIdentityProviderID)
|
s = sessions.NewState(idp.GetId())
|
||||||
}
|
}
|
||||||
|
|
||||||
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
|
newSession := s.WithNewIssuer(state.redirectURL.Host, jwtAudience)
|
||||||
|
@ -276,8 +286,12 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
options := a.options.Load()
|
options := a.options.Load()
|
||||||
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -285,7 +299,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
rawIDToken := a.revokeSession(ctx, w, r)
|
rawIDToken := a.revokeSession(ctx, w, r)
|
||||||
|
|
||||||
redirectString := ""
|
redirectString := ""
|
||||||
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
|
signOutURL, err := options.GetSignOutRedirectURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -296,14 +310,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
redirectString = uri
|
redirectString = uri
|
||||||
}
|
}
|
||||||
|
|
||||||
endSessionURL, err := idp.LogOut()
|
endSessionURL, err := authenticator.LogOut()
|
||||||
if err == nil && redirectString != "" {
|
if err == nil && redirectString != "" {
|
||||||
params := url.Values{}
|
params := url.Values{}
|
||||||
params.Add("id_token_hint", rawIDToken)
|
params.Add("id_token_hint", rawIDToken)
|
||||||
params.Add("post_logout_redirect_uri", redirectString)
|
params.Add("post_logout_redirect_uri", redirectString)
|
||||||
endSessionURL.RawQuery = params.Encode()
|
endSessionURL.RawQuery = params.Encode()
|
||||||
redirectString = endSessionURL.String()
|
redirectString = endSessionURL.String()
|
||||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||||
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
|
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
|
||||||
}
|
}
|
||||||
if redirectString != "" {
|
if redirectString != "" {
|
||||||
|
@ -330,10 +344,14 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
return httputil.NewError(http.StatusUnauthorized, err)
|
return httputil.NewError(http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
options := a.options.Load()
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
options := a.options.Load()
|
||||||
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -346,7 +364,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
||||||
b = append(b, enc...)
|
b = append(b, enc...)
|
||||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||||
signinURL, err := idp.GetSignInURL(encodedState)
|
signinURL, err := authenticator.GetSignInURL(encodedState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError,
|
return httputil.NewError(http.StatusInternalServerError,
|
||||||
fmt.Errorf("failed to get sign in url: %w", err))
|
fmt.Errorf("failed to get sign in url: %w", err))
|
||||||
|
@ -381,8 +399,8 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
options := a.options.Load()
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
options := a.options.Load()
|
||||||
|
|
||||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||||
//
|
//
|
||||||
|
@ -428,9 +446,13 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, err)
|
return nil, httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, idpID)
|
idp, err := options.GetIdentityProviderForID(redirectURL.Query().Get(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -439,12 +461,12 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
//
|
//
|
||||||
// Exchange the supplied Authorization Code for a valid user session.
|
// Exchange the supplied Authorization Code for a valid user session.
|
||||||
var claims identity.SessionClaims
|
var claims identity.SessionClaims
|
||||||
accessToken, err := idp.Authenticate(ctx, code, &claims)
|
accessToken, err := authenticator.Authenticate(ctx, code, &claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
return nil, fmt.Errorf("error redeeming authenticate code: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := sessions.NewState(idpID)
|
s := sessions.NewState(idp.GetId())
|
||||||
err = claims.Claims.Claims(&s)
|
err = claims.Claims.Claims(&s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
|
return nil, fmt.Errorf("error unmarshaling session state: %w", err)
|
||||||
|
@ -582,8 +604,12 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
) error {
|
) error {
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
options := a.options.Load()
|
options := a.options.Load()
|
||||||
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -593,7 +619,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
|
|
||||||
s := &session.Session{
|
s := &session.Session{
|
||||||
Id: sessionState.ID,
|
Id: sessionState.ID,
|
||||||
UserId: sessionState.UserID(idp.Name()),
|
UserId: sessionState.UserID(authenticator.Name()),
|
||||||
IssuedAt: timestamppb.Now(),
|
IssuedAt: timestamppb.Now(),
|
||||||
AccessedAt: timestamppb.Now(),
|
AccessedAt: timestamppb.Now(),
|
||||||
ExpiresAt: sessionExpiry,
|
ExpiresAt: sessionExpiry,
|
||||||
|
@ -617,7 +643,7 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
Id: s.GetUserId(),
|
Id: s.GetUserId(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
err = idp.UpdateUserInfo(ctx, accessToken, &managerUser)
|
err = authenticator.UpdateUserInfo(ctx, accessToken, &managerUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
return fmt.Errorf("authenticate: error retrieving user info: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -648,13 +674,18 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
||||||
// and empty string.
|
// and empty string.
|
||||||
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
||||||
options := a.options.Load()
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
options := a.options.Load()
|
||||||
|
|
||||||
// clear the user's local session no matter what
|
// clear the user's local session no matter what
|
||||||
defer state.sessionStore.ClearSession(w, r)
|
defer state.sessionStore.ClearSession(w, r)
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
idp, err := options.GetIdentityProviderForID(r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := a.cfg.getIdentityProvider(options, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -667,7 +698,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
|
||||||
|
|
||||||
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
|
if s, _ := session.Get(ctx, state.dataBrokerClient, sessionState.ID); s != nil && s.OauthToken != nil {
|
||||||
rawIDToken = s.GetIdToken().GetRaw()
|
rawIDToken = s.GetIdToken().GetRaw()
|
||||||
if err := idp.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
|
if err := authenticator.Revoke(ctx, manager.FromOAuthToken(s.OauthToken)); err != nil {
|
||||||
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -478,6 +478,8 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
|
idp, _ := new(config.Options).GetIdentityProviderForID("")
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
headers map[string]string
|
headers map[string]string
|
||||||
|
@ -491,7 +493,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"good",
|
"good",
|
||||||
nil,
|
nil,
|
||||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
nil,
|
nil,
|
||||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
|
@ -499,7 +501,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"invalid session",
|
"invalid session",
|
||||||
nil,
|
nil,
|
||||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
errors.New("hi"),
|
errors.New("hi"),
|
||||||
identity.MockProvider{},
|
identity.MockProvider{},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
|
@ -507,7 +509,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"good refresh expired",
|
"good refresh expired",
|
||||||
nil,
|
nil,
|
||||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
nil,
|
nil,
|
||||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
|
@ -515,7 +517,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"expired,refresh error",
|
"expired,refresh error",
|
||||||
nil,
|
nil,
|
||||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
sessions.ErrExpired,
|
sessions.ErrExpired,
|
||||||
identity.MockProvider{RefreshError: errors.New("error")},
|
identity.MockProvider{RefreshError: errors.New("error")},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
|
@ -523,7 +525,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"expired,save error",
|
"expired,save error",
|
||||||
nil,
|
nil,
|
||||||
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
sessions.ErrExpired,
|
sessions.ErrExpired,
|
||||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
|
@ -531,7 +533,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
{
|
{
|
||||||
"expired XHR,refresh error",
|
"expired XHR,refresh error",
|
||||||
map[string]string{"X-Requested-With": "XmlHttpRequest"},
|
map[string]string{"X-Requested-With": "XmlHttpRequest"},
|
||||||
&mstore.Store{Session: &sessions.State{ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
sessions.ErrExpired,
|
sessions.ErrExpired,
|
||||||
identity.MockProvider{RefreshError: errors.New("error")},
|
identity.MockProvider{RefreshError: errors.New("error")},
|
||||||
http.StatusUnauthorized,
|
http.StatusUnauthorized,
|
||||||
|
|
Loading…
Add table
Reference in a new issue