From 3adbc65d371d896c07844d3d641d7d25f10634bf Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 19 Dec 2023 12:04:23 -0700 Subject: [PATCH] core/authenticate: refactor identity authenticators to initiate redirect (#4858) * core/authenticate: refactor identity authenticators to initiate redirect, use cookie for redirect url for cognito * set secure and http only, update test --- authenticate/handlers.go | 18 ++--- authenticate/handlers_test.go | 16 ++-- internal/handlers/signedout.go | 6 ++ internal/handlers/signedout_test.go | 42 ++++++++++ internal/httputil/signedout.go | 28 +++++++ internal/identity/mock_provider.go | 36 +++++---- internal/identity/oauth/apple/apple.go | 51 ++++++------ internal/identity/oauth/github/github.go | 24 +++--- internal/identity/oidc/auth0/auth0.go | 16 ++-- internal/identity/oidc/auth0/auth0_test.go | 8 +- internal/identity/oidc/cognito/cognito.go | 18 +++-- .../identity/oidc/cognito/cognito_test.go | 16 +++- internal/identity/oidc/oidc.go | 77 ++++++++++--------- internal/identity/providers.go | 6 +- 14 files changed, 237 insertions(+), 125 deletions(-) create mode 100644 internal/handlers/signedout_test.go create mode 100644 internal/httputil/signedout.go diff --git a/authenticate/handlers.go b/authenticate/handlers.go index ca5afe92b..6fe0767fc 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -223,14 +223,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e signOutURL = uri } else if signOutRedirectURL != nil { signOutURL = signOutRedirectURL.String() - } else { - signOutURL = authenticateURL.ResolveReference(&url.URL{ - Path: "/.pomerium/signed_out", - }).String() } - if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil { - signOutURL = idpSignOutURL + authenticateSignedOutURL := authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/signed_out", + }).String() + + if err := authenticator.SignOut(w, r, rawIDToken, authenticateSignedOutURL, signOutURL); err == nil { + return nil } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator") } @@ -275,12 +275,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) b = append(b, enc...) encodedState := base64.URLEncoding.EncodeToString(b) - signinURL, err := authenticator.GetSignInURL(encodedState) + + err = authenticator.SignIn(w, r, encodedState) if err != nil { return httputil.NewError(http.StatusInternalServerError, - fmt.Errorf("failed to get sign in url: %w", err)) + fmt.Errorf("failed to sign in: %w", err)) } - httputil.Redirect(w, r, signinURL, http.StatusFound) return nil } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 7a3424098..ef902ec3a 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -137,7 +137,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"}, + identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -150,7 +150,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "https://signout-redirect-url.example.com", "sig", "ts", - identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, + identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -163,7 +163,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, + identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -176,7 +176,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, + identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -189,7 +189,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, + identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -401,7 +401,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, errors.New("hi"), identity.MockProvider{}, - http.StatusFound, + http.StatusOK, }, { "expired,refresh error", @@ -409,7 +409,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, - http.StatusFound, + http.StatusOK, }, { "expired,save error", @@ -417,7 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, - http.StatusFound, + http.StatusOK, }, { "expired XHR,refresh error", diff --git a/internal/handlers/signedout.go b/internal/handlers/signedout.go index 099a333db..5fd0440ca 100644 --- a/internal/handlers/signedout.go +++ b/internal/handlers/signedout.go @@ -18,6 +18,12 @@ func (data SignedOutData) ToJSON() map[string]interface{} { // SignedOut returns a handler that renders the signed out page. func SignedOut(data SignedOutData) http.Handler { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + if redirectURI, ok := httputil.GetSignedOutRedirectURICookie(w, r); ok { + httputil.Redirect(w, r, redirectURI, http.StatusFound) + return nil + } + + // otherwise show the signed-out page return ui.ServePage(w, r, "SignedOut", data.ToJSON()) }) } diff --git a/internal/handlers/signedout_test.go b/internal/handlers/signedout_test.go new file mode 100644 index 000000000..3f1c0d18b --- /dev/null +++ b/internal/handlers/signedout_test.go @@ -0,0 +1,42 @@ +package handlers_test + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/handlers" +) + +func TestSignedOut(t *testing.T) { + t.Parallel() + + t.Run("ok", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil) + + handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r) + + assert.Equal(t, http.StatusOK, w.Code) + }) + + t.Run("redirect", func(t *testing.T) { + t.Parallel() + + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil) + r.AddCookie(&http.Cookie{ + Name: "_pomerium_signed_out_redirect_uri", + Value: "https://www.google.com", + }) + + handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r) + + assert.Equal(t, http.StatusFound, w.Code) + assert.Equal(t, "https://www.google.com", w.Header().Get("Location")) + }) +} diff --git a/internal/httputil/signedout.go b/internal/httputil/signedout.go new file mode 100644 index 000000000..631e05048 --- /dev/null +++ b/internal/httputil/signedout.go @@ -0,0 +1,28 @@ +package httputil + +import "net/http" + +const signedOutRedirectURICookieName = "_pomerium_signed_out_redirect_uri" + +// GetSignedOutRedirectURICookie gets the redirect uri cookie for the signed-out page. +func GetSignedOutRedirectURICookie(w http.ResponseWriter, r *http.Request) (string, bool) { + cookie, err := r.Cookie(signedOutRedirectURICookieName) + if err != nil { + return "", false + } + + cookie.MaxAge = -1 + http.SetCookie(w, cookie) + return cookie.Value, true +} + +// SetSignedOutRedirectURICookie sets the redirect uri cookie for the signed-out page. +func SetSignedOutRedirectURICookie(w http.ResponseWriter, redirectURI string) { + http.SetCookie(w, &http.Cookie{ + Name: signedOutRedirectURICookieName, + Value: redirectURI, + MaxAge: 5 * 60, + HttpOnly: true, + Secure: true, + }) +} diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index bdd0b2a96..49075c96f 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -2,6 +2,7 @@ package identity import ( "context" + "net/http" "golang.org/x/oauth2" @@ -10,15 +11,14 @@ import ( // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse oauth2.Token - AuthenticateError error - RefreshResponse oauth2.Token - RefreshError error - RevokeError error - GetSignInURLResponse string - GetSignOutURLResponse string - GetSignOutURLError error - UpdateUserInfoError error + AuthenticateResponse oauth2.Token + AuthenticateError error + RefreshResponse oauth2.Token + RefreshError error + RevokeError error + UpdateUserInfoError error + SignInError error + SignOutError error } // Authenticate is a mocked providers function. @@ -36,14 +36,6 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error { return mp.RevokeError } -// GetSignInURL is a mocked providers function. -func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil } - -// GetSignOutURL is a mocked providers function. -func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) { - return mp.GetSignOutURLResponse, mp.GetSignOutURLError -} - // UpdateUserInfo is a mocked providers function. func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error { return mp.UpdateUserInfoError @@ -53,3 +45,13 @@ func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ inte func (mp MockProvider) Name() string { return "mock" } + +// SignOut is a mocked providers function. +func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error { + return mp.SignOutError +} + +// SignIn is a mocked providers function. +func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error { + return mp.SignInError +} diff --git a/internal/identity/oauth/apple/apple.go b/internal/identity/oauth/apple/apple.go index bed70fa94..0ba5c0133 100644 --- a/internal/identity/oauth/apple/apple.go +++ b/internal/identity/oauth/apple/apple.go @@ -83,31 +83,6 @@ func (p *Provider) Name() string { return Name } -// GetSignInURL returns the url of the provider's OAuth 2.0 consent page -// that asks for permissions for the required scopes explicitly. -// -// State is a token to protect the user from CSRF attacks. You must -// always provide a non-empty string and validate that it matches the -// the state query parameter on your redirect callback. -// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. -func (p *Provider) GetSignInURL(state string) (string, error) { - opts := []oauth2.AuthCodeOption{} - for k, v := range p.authCodeOptions { - opts = append(opts, oauth2.SetAuthURLParam(k, v)) - } - authURL := p.oauth.AuthCodeURL(state, opts...) - - // Apple is very picky here and we need to use %20 instead of + - authURL = strings.ReplaceAll(authURL, "+", "%20") - - return authURL, nil -} - -// GetSignOutURL is not implemented. -func (p *Provider) GetSignOutURL(_, _ string) (string, error) { - return "", oidc.ErrSignoutNotImplemented -} - // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { @@ -181,3 +156,29 @@ func (p *Provider) UpdateUserInfo(_ context.Context, t *oauth2.Token, v interfac return idToken.UnsafeClaimsWithoutVerification(v) } + +// SignIn redirects to the url of the provider's OAuth 2.0 consent page +// that asks for permissions for the required scopes explicitly. +// +// State is a token to protect the user from CSRF attacks. You must +// always provide a non-empty string and validate that it matches the +// the state query parameter on your redirect callback. +// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. +func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error { + opts := []oauth2.AuthCodeOption{} + for k, v := range p.authCodeOptions { + opts = append(opts, oauth2.SetAuthURLParam(k, v)) + } + authURL := p.oauth.AuthCodeURL(state, opts...) + + // Apple is very picky here and we need to use %20 instead of + + authURL = strings.ReplaceAll(authURL, "+", "%20") + + httputil.Redirect(w, r, authURL, http.StatusFound) + return nil +} + +// SignOut is not implemented. +func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error { + return oidc.ErrSignoutNotImplemented +} diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index 2e3ee1c88..cfd9adf9d 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -239,18 +239,20 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error { return nil } -// GetSignInURL returns a URL to OAuth 2.0 provider's consent page -// that asks for permissions for the required scopes explicitly. -func (p *Provider) GetSignInURL(state string) (string, error) { - return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil -} - -// GetSignOutURL is not implemented. -func (p *Provider) GetSignOutURL(_, _ string) (string, error) { - return "", oidc.ErrSignoutNotImplemented -} - // Name returns the provider name. func (p *Provider) Name() string { return Name } + +// SignIn redirects to the OAuth 2.0 provider's consent page +// that asks for permissions for the required scopes explicitly. +func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error { + signInURL := p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline) + httputil.Redirect(w, r, signInURL, http.StatusFound) + return nil +} + +// SignOut is not implemented. +func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error { + return oidc.ErrSignoutNotImplemented +} diff --git a/internal/identity/oidc/auth0/auth0.go b/internal/identity/oidc/auth0/auth0.go index b417dacbb..3b50ae153 100644 --- a/internal/identity/oidc/auth0/auth0.go +++ b/internal/identity/oidc/auth0/auth0.go @@ -6,9 +6,11 @@ package auth0 import ( "context" "fmt" + "net/http" "net/url" "strings" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/urlutil" @@ -50,16 +52,16 @@ func (p *Provider) Name() string { return Name } -// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout. -func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) { +// SignOut implements logout as described in https://auth0.com/docs/api/authentication#logout. +func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, redirectToURL string) error { oa, err := p.GetOauthConfig() if err != nil { - return "", fmt.Errorf("error getting auth0 oauth config: %w", err) + return fmt.Errorf("error getting auth0 oauth config: %w", err) } authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) if err != nil { - return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err) + return fmt.Errorf("error parsing auth0 endpoint auth url: %w", err) } logoutQuery := url.Values{ @@ -67,10 +69,14 @@ func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) { } if redirectToURL != "" { logoutQuery.Set("returnTo", redirectToURL) + } else if authenticateSignedOutURL != "" { + logoutQuery.Set("returnTo", authenticateSignedOutURL) } logoutURL := authURL.ResolveReference(&url.URL{ Path: "/v2/logout", RawQuery: logoutQuery.Encode(), }) - return logoutURL.String(), nil + + httputil.Redirect(w, r, logoutURL.String(), http.StatusFound) + return nil } diff --git a/internal/identity/oidc/auth0/auth0_test.go b/internal/identity/oidc/auth0/auth0_test.go index d19a17960..4cff74415 100644 --- a/internal/identity/oidc/auth0/auth0_test.go +++ b/internal/identity/oidc/auth0/auth0_test.go @@ -53,9 +53,11 @@ func TestProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, p) - t.Run("GetSignOutURL", func(t *testing.T) { - signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") + t.Run("SignOut", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil) + err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b") assert.NoError(t, err) - assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) + assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", w.Header().Get("Location")) }) } diff --git a/internal/identity/oidc/cognito/cognito.go b/internal/identity/oidc/cognito/cognito.go index 305d93694..e19e87722 100644 --- a/internal/identity/oidc/cognito/cognito.go +++ b/internal/identity/oidc/cognito/cognito.go @@ -4,8 +4,10 @@ package cognito import ( "context" "fmt" + "net/http" "net/url" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" "github.com/pomerium/pomerium/internal/urlutil" @@ -51,27 +53,31 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) { return &p, nil } -// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html. -func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) { +// SignOut implements sign out according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html. +func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, returnToURL string) error { oa, err := p.GetOauthConfig() if err != nil { - return "", fmt.Errorf("error getting cognito oauth config: %w", err) + return fmt.Errorf("error getting cognito oauth config: %w", err) } authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) if err != nil { - return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err) + return fmt.Errorf("error getting cognito endpoint auth url: %w", err) } logOutQuery := url.Values{ "client_id": []string{oa.ClientID}, } + if authenticateSignedOutURL != "" { + logOutQuery.Set("logout_uri", authenticateSignedOutURL) + } if returnToURL != "" { - logOutQuery.Set("logout_uri", returnToURL) + httputil.SetSignedOutRedirectURICookie(w, returnToURL) } logOutURL := authURL.ResolveReference(&url.URL{ Path: "/logout", RawQuery: logOutQuery.Encode(), }) - return logOutURL.String(), nil + httputil.Redirect(w, r, logOutURL.String(), http.StatusFound) + return nil } diff --git a/internal/identity/oidc/cognito/cognito_test.go b/internal/identity/oidc/cognito/cognito_test.go index bbd28c6c1..20ca014c5 100644 --- a/internal/identity/oidc/cognito/cognito_test.go +++ b/internal/identity/oidc/cognito/cognito_test.go @@ -53,9 +53,19 @@ func TestProvider(t *testing.T) { require.NoError(t, err) require.NotNil(t, p) - t.Run("GetSignOutURL", func(t *testing.T) { - signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") + t.Run("SignOut", func(t *testing.T) { + w := httptest.NewRecorder() + r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil) + err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b") assert.NoError(t, err) - assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) + assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fauthenticate.example.com%2F.pomerium%2Fsigned_out", w.Header().Get("Location")) + assert.Equal(t, []*http.Cookie{{ + Name: "_pomerium_signed_out_redirect_uri", + Value: "https://www.example.com?a=b", + MaxAge: 300, + Secure: true, + HttpOnly: true, + Raw: "_pomerium_signed_out_redirect_uri=https://www.example.com?a=b; Max-Age=300; HttpOnly; Secure", + }}, w.Result().Cookies()) }) } diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index d13e3634f..d8d89c4d2 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -96,56 +96,26 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e return p, nil } -// GetSignInURL returns the url of the provider's OAuth 2.0 consent page +// SignIn redirects to the url of the provider's OAuth 2.0 consent page // that asks for permissions for the required scopes explicitly. // // State is a token to protect the user from CSRF attacks. You must // always provide a non-empty string and validate that it matches the // the state query parameter on your redirect callback. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. -func (p *Provider) GetSignInURL(state string) (string, error) { +func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error { oa, err := p.GetOauthConfig() if err != nil { - return "", err + return err } opts := defaultAuthCodeOptions for k, v := range p.AuthCodeOptions { opts = append(opts, oauth2.SetAuthURLParam(k, v)) } - return oa.AuthCodeURL(state, opts...), nil -} - -// GetSignOutURL returns the EndSessionURL endpoint to allow a logout -// session to be initiated. -// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated -func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) { - _, err := p.GetProvider() - if err != nil { - return "", err - } - - if p.EndSessionURL == "" { - return "", ErrSignoutNotImplemented - } - - endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL) - if err != nil { - return "", err - } - - params := endSessionURL.Query() - if idTokenHint != "" { - params.Add("id_token_hint", idTokenHint) - } - if oa, err := p.GetOauthConfig(); err == nil { - params.Add("client_id", oa.ClientID) - } - if redirectToURL != "" { - params.Add("post_logout_redirect_uri", redirectToURL) - } - endSessionURL.RawQuery = params.Encode() - return endSessionURL.String(), nil + signInURL := oa.AuthCodeURL(state, opts...) + httputil.Redirect(w, r, signInURL, http.StatusFound) + return nil } // Authenticate converts an authorization code returned from the identity @@ -340,3 +310,38 @@ func (p *Provider) GetOauthConfig() (*oauth2.Config, error) { } return p.cfg.getOauthConfig(pp), nil } + +// SignOut uses the EndSessionURL endpoint to allow a logout session to be initiated. +// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated +func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error { + _, err := p.GetProvider() + if err != nil { + return err + } + + if p.EndSessionURL == "" { + return ErrSignoutNotImplemented + } + + endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL) + if err != nil { + return err + } + + params := endSessionURL.Query() + if idTokenHint != "" { + params.Add("id_token_hint", idTokenHint) + } + if oa, err := p.GetOauthConfig(); err == nil { + params.Add("client_id", oa.ClientID) + } + if redirectToURL != "" { + params.Add("post_logout_redirect_uri", redirectToURL) + } else if authenticateSignedOutURL != "" { + params.Add("post_logout_redirect_uri", authenticateSignedOutURL) + } + endSessionURL.RawQuery = params.Encode() + + httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound) + return nil +} diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 27cc518ea..4e8fd51f6 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -5,6 +5,7 @@ package identity import ( "context" "fmt" + "net/http" "golang.org/x/oauth2" @@ -28,10 +29,11 @@ type Authenticator interface { Authenticate(context.Context, string, identity.State) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error - GetSignInURL(state string) (string, error) - GetSignOutURL(idTokenHint, redirectToURL string) (string, error) Name() string UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error + + SignIn(w http.ResponseWriter, r *http.Request, state string) error + SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error } // NewAuthenticator returns a new identity provider based on its name.