diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 30e90d17f..c854f042f 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -89,6 +89,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) { // routes that don't need a session: sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut)) + sr.Path("/signed_out").Handler(handlers.SignedOut(handlers.SignedOutData{})).Methods(http.MethodGet) // routes that need a session: sr = sr.NewRoute().Subrouter() @@ -266,33 +267,35 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e rawIDToken := a.revokeSession(ctx, w, r) - redirectString := "" - signOutURL, err := options.GetSignOutRedirectURL() + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return fmt.Errorf("error getting authenticate url: %w", err) + } + + signOutRedirectURL, err := options.GetSignOutRedirectURL() if err != nil { return err } - if signOutURL != nil { - redirectString = signOutURL.String() - } + + var signOutURL string if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" { - redirectString = uri + signOutURL = uri + } else if signOutRedirectURL != nil { + signOutURL = signOutRedirectURL.String() + } else { + signOutURL = authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/signed_out", + }).String() } - endSessionURL, err := authenticator.LogOut() - if err == nil && redirectString != "" { - params := endSessionURL.Query() - params.Add("id_token_hint", rawIDToken) - params.Add("post_logout_redirect_uri", redirectString) - endSessionURL.RawQuery = params.Encode() - redirectString = endSessionURL.String() - } else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) { - log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session") + if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil { + signOutURL = idpSignOutURL + } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { + log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator") } - if redirectString != "" { - httputil.Redirect(w, r, redirectString, http.StatusFound) - return nil - } - return httputil.NewError(http.StatusOK, errors.New("user logged out")) + + httputil.Redirect(w, r, signOutURL, http.StatusFound) + return nil } // reauthenticateOrFail starts the authenticate process by redirecting the diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 1a33a8203..943977314 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -135,7 +135,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, + identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -148,7 +148,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "https://signout-redirect-url.example.com", "sig", "ts", - identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, + identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -161,7 +161,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{RevokeError: errors.New("OH NO")}, + identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -174,7 +174,7 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{RevokeError: errors.New("OH NO")}, + identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", @@ -187,24 +187,11 @@ func TestAuthenticate_SignOut(t *testing.T) { "", "sig", "ts", - identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, + identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, &mstore.Store{Encrypted: true, Session: &sessions.State{}}, http.StatusFound, "", }, - { - "no redirect uri", - http.MethodPost, - nil, - "", - "", - "sig", - "ts", - identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, - &mstore.Store{Encrypted: true, Session: &sessions.State{}}, - http.StatusOK, - "{\"Status\":200}\n", - }, } for _, tt := range tests { tt := tt @@ -253,7 +240,7 @@ func TestAuthenticate_SignOut(t *testing.T) { } if tt.signoutRedirectURL != "" { loc := w.Header().Get("Location") - assert.Contains(t, loc, url.QueryEscape(tt.signoutRedirectURL)) + assert.Contains(t, loc, tt.signoutRedirectURL) } }) } diff --git a/internal/handlers/signedout.go b/internal/handlers/signedout.go new file mode 100644 index 000000000..099a333db --- /dev/null +++ b/internal/handlers/signedout.go @@ -0,0 +1,23 @@ +package handlers + +import ( + "net/http" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/ui" +) + +// SignedOutData is the data for the SignedOut page. +type SignedOutData struct{} + +// ToJSON converts the data into a JSON map. +func (data SignedOutData) ToJSON() map[string]interface{} { + return map[string]interface{}{} +} + +// SignedOut returns a handler that renders the signed out page. +func SignedOut(data SignedOutData) http.Handler { + return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { + return ui.ServePage(w, r, "SignedOut", data.ToJSON()) + }) +} diff --git a/internal/identity/mock_provider.go b/internal/identity/mock_provider.go index 221e4d24f..bdd0b2a96 100644 --- a/internal/identity/mock_provider.go +++ b/internal/identity/mock_provider.go @@ -2,7 +2,6 @@ package identity import ( "context" - "net/url" "golang.org/x/oauth2" @@ -11,15 +10,15 @@ import ( // MockProvider provides a mocked implementation of the providers interface. type MockProvider struct { - AuthenticateResponse oauth2.Token - AuthenticateError error - RefreshResponse oauth2.Token - RefreshError error - RevokeError error - GetSignInURLResponse string - LogOutResponse url.URL - LogOutError error - UpdateUserInfoError error + AuthenticateResponse oauth2.Token + AuthenticateError error + RefreshResponse oauth2.Token + RefreshError error + RevokeError error + GetSignInURLResponse string + GetSignOutURLResponse string + GetSignOutURLError error + UpdateUserInfoError error } // Authenticate is a mocked providers function. @@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error { // GetSignInURL is a mocked providers function. func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil } -// LogOut is a mocked providers function. -func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError } +// GetSignOutURL is a mocked providers function. +func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) { + return mp.GetSignOutURLResponse, mp.GetSignOutURLError +} // UpdateUserInfo is a mocked providers function. func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error { diff --git a/internal/identity/oauth/apple/apple.go b/internal/identity/oauth/apple/apple.go index d734bfef8..53878606f 100644 --- a/internal/identity/oauth/apple/apple.go +++ b/internal/identity/oauth/apple/apple.go @@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) { return authURL, nil } +// GetSignOutURL is not implemented. +func (p *Provider) GetSignOutURL(_, _ string) (string, error) { + return "", oidc.ErrSignoutNotImplemented +} + // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { @@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta return oauth2Token, nil } -// LogOut is not implemented by Apple. -func (p *Provider) LogOut() (*url.URL, error) { - return nil, oidc.ErrSignoutNotImplemented -} - // Refresh renews a user's session. func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { if t == nil { diff --git a/internal/identity/oauth/github/github.go b/internal/identity/oauth/github/github.go index e5e3d98da..2e3ee1c88 100644 --- a/internal/identity/oauth/github/github.go +++ b/internal/identity/oauth/github/github.go @@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) { return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil } -// LogOut is not implemented by github. -func (p *Provider) LogOut() (*url.URL, error) { - return nil, oidc.ErrSignoutNotImplemented +// GetSignOutURL is not implemented. +func (p *Provider) GetSignOutURL(_, _ string) (string, error) { + return "", oidc.ErrSignoutNotImplemented } // Name returns the provider name. diff --git a/internal/identity/oidc/auth0/auth0.go b/internal/identity/oidc/auth0/auth0.go index abb12ece6..b417dacbb 100644 --- a/internal/identity/oidc/auth0/auth0.go +++ b/internal/identity/oidc/auth0/auth0.go @@ -6,10 +6,12 @@ package auth0 import ( "context" "fmt" + "net/url" "strings" "github.com/pomerium/pomerium/internal/identity/oauth" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" + "github.com/pomerium/pomerium/internal/urlutil" ) const ( @@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) { func (p *Provider) Name() string { return Name } + +// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout. +func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return "", fmt.Errorf("error getting auth0 oauth config: %w", err) + } + + authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) + if err != nil { + return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err) + } + + logoutQuery := url.Values{ + "client_id": {oa.ClientID}, + } + if redirectToURL != "" { + logoutQuery.Set("returnTo", redirectToURL) + } + logoutURL := authURL.ResolveReference(&url.URL{ + Path: "/v2/logout", + RawQuery: logoutQuery.Encode(), + }) + return logoutURL.String(), nil +} diff --git a/internal/identity/oidc/auth0/auth0_test.go b/internal/identity/oidc/auth0/auth0_test.go new file mode 100644 index 000000000..d19a17960 --- /dev/null +++ b/internal/identity/oidc/auth0/auth0_test.go @@ -0,0 +1,61 @@ +package auth0 + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/identity/oauth" +) + +func TestProvider(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + var srv *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + baseURL, err := url.Parse(srv.URL + "/") + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/.well-known/openid-configuration": + json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL.String(), + "authorization_endpoint": srv.URL + "/authorize", + }) + + default: + assert.Failf(t, "unexpected http request", "url: %s", r.URL.String()) + } + }) + srv = httptest.NewServer(handler) + t.Cleanup(srv.Close) + + redirectURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + p, err := New(ctx, &oauth.Options{ + ProviderURL: srv.URL, + RedirectURL: redirectURL, + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + }) + require.NoError(t, err) + require.NotNil(t, p) + + t.Run("GetSignOutURL", func(t *testing.T) { + signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") + assert.NoError(t, err) + assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) + }) +} diff --git a/internal/identity/oidc/cognito/cognito.go b/internal/identity/oidc/cognito/cognito.go index 74a88e9a0..305d93694 100644 --- a/internal/identity/oidc/cognito/cognito.go +++ b/internal/identity/oidc/cognito/cognito.go @@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) { // https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String() - // https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html - p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{ - Path: "/logout", - RawQuery: url.Values{ - "client_id": []string{opts.ClientID}, - "logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()}, - }.Encode(), - }).String() - return &p, nil } + +// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html. +func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) { + oa, err := p.GetOauthConfig() + if err != nil { + return "", fmt.Errorf("error getting cognito oauth config: %w", err) + } + + authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) + if err != nil { + return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err) + } + + logOutQuery := url.Values{ + "client_id": []string{oa.ClientID}, + } + if returnToURL != "" { + logOutQuery.Set("logout_uri", returnToURL) + } + logOutURL := authURL.ResolveReference(&url.URL{ + Path: "/logout", + RawQuery: logOutQuery.Encode(), + }) + return logOutURL.String(), nil +} diff --git a/internal/identity/oidc/cognito/cognito_test.go b/internal/identity/oidc/cognito/cognito_test.go new file mode 100644 index 000000000..bbd28c6c1 --- /dev/null +++ b/internal/identity/oidc/cognito/cognito_test.go @@ -0,0 +1,61 @@ +package cognito + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/identity/oauth" +) + +func TestProvider(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + var srv *httptest.Server + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + baseURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + w.Header().Set("Content-Type", "application/json") + switch r.URL.Path { + case "/.well-known/openid-configuration": + json.NewEncoder(w).Encode(map[string]any{ + "issuer": baseURL.String(), + "authorization_endpoint": srv.URL + "/authorize", + }) + + default: + assert.Failf(t, "unexpected http request", "url: %s", r.URL.String()) + } + }) + srv = httptest.NewServer(handler) + t.Cleanup(srv.Close) + + redirectURL, err := url.Parse(srv.URL) + require.NoError(t, err) + + p, err := New(ctx, &oauth.Options{ + ProviderURL: srv.URL, + RedirectURL: redirectURL, + ClientID: "CLIENT_ID", + ClientSecret: "CLIENT_SECRET", + }) + require.NoError(t, err) + require.NotNil(t, p) + + t.Run("GetSignOutURL", func(t *testing.T) { + signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") + assert.NoError(t, err) + assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) + }) +} diff --git a/internal/identity/oidc/oidc.go b/internal/identity/oidc/oidc.go index 13fbce41f..832a62a85 100644 --- a/internal/identity/oidc/oidc.go +++ b/internal/identity/oidc/oidc.go @@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) { return oa.AuthCodeURL(state, opts...), nil } +// GetSignOutURL returns the EndSessionURL endpoint to allow a logout +// session to be initiated. +// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated +func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) { + _, err := p.GetProvider() + if err != nil { + return "", err + } + + if p.EndSessionURL == "" { + return "", ErrSignoutNotImplemented + } + + endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL) + if err != nil { + return "", err + } + + params := endSessionURL.Query() + if idTokenHint != "" { + params.Add("id_token_hint", idTokenHint) + } + if oa, err := p.GetOauthConfig(); err == nil { + params.Add("client_id", oa.ClientID) + } + if redirectToURL != "" { + params.Add("post_logout_redirect_uri", redirectToURL) + } + endSessionURL.RawQuery = params.Encode() + return endSessionURL.String(), nil +} + // Authenticate converts an authorization code returned from the identity // provider into a token which is then converted into a user session. func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { @@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error { return nil } -// LogOut returns the EndSessionURL endpoint to allow a logout -// session to be initiated. -// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated -func (p *Provider) LogOut() (*url.URL, error) { - _, err := p.GetProvider() - if err != nil { - return nil, err - } - if p.EndSessionURL == "" { - return nil, ErrSignoutNotImplemented - } - return urlutil.ParseAndValidateURL(p.EndSessionURL) -} - // GetSubject gets the RFC 7519 Subject claim (`sub`) from a func (p *Provider) GetSubject(v interface{}) (string, error) { b, err := json.Marshal(v) diff --git a/internal/identity/providers.go b/internal/identity/providers.go index f6fe06c57..27cc518ea 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -5,7 +5,6 @@ package identity import ( "context" "fmt" - "net/url" "golang.org/x/oauth2" @@ -30,8 +29,8 @@ type Authenticator interface { Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Revoke(context.Context, *oauth2.Token) error GetSignInURL(state string) (string, error) + GetSignOutURL(idTokenHint, redirectToURL string) (string, error) Name() string - LogOut() (*url.URL, error) UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error } diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 56851e5b2..48d645910 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -1,12 +1,13 @@ import Box from "@mui/material/Box"; import CssBaseline from "@mui/material/CssBaseline"; import { ThemeProvider } from "@mui/material/styles"; -import React, {FC, useLayoutEffect} from "react"; +import React, { FC, useLayoutEffect } from "react"; import ErrorPage from "./components/ErrorPage"; import Footer from "./components/Footer"; import Header from "./components/Header"; import SignOutConfirmPage from "./components/SignOutConfirmPage"; +import SignedOutPage from "./components/SignedOutPage"; import { ToolbarOffset } from "./components/ToolbarOffset"; import UserInfoPage from "./components/UserInfoPage"; import WebAuthnRegistrationPage from "./components/WebAuthnRegistrationPage"; @@ -27,6 +28,9 @@ const App: FC = () => { case "SignOutConfirm": body = ; break; + case "SignedOut": + body = ; + break; case "DeviceEnrolled": case "UserInfo": body = ; @@ -38,18 +42,18 @@ const App: FC = () => { useLayoutEffect(() => { const favicon = document.getElementById( - 'favicon' + "favicon" ) as HTMLAnchorElement | null; if (favicon) { - favicon.href = data?.faviconUrl || '/.pomerium/favicon.ico'; + favicon.href = data?.faviconUrl || "/.pomerium/favicon.ico"; } const extraFaviconLinks = document.getElementsByClassName( - 'pomerium_favicon' + "pomerium_favicon" ) as HTMLCollectionOf | null; for (const link of extraFaviconLinks) { - link.style.display = data?.faviconUrl ? 'none' : ''; + link.style.display = data?.faviconUrl ? "none" : ""; } - }, []) + }, []); return ( diff --git a/ui/src/components/Header.tsx b/ui/src/components/Header.tsx index dfa11e3d9..e562fb85a 100644 --- a/ui/src/components/Header.tsx +++ b/ui/src/components/Header.tsx @@ -57,6 +57,8 @@ const Header: FC = ({ includeSidebar, data }) => { get(data, "user.claims.picture") || get(data, "profile.claims.picture") || null; + const showAvatar = + data?.page !== "SignOutConfirm" && data?.page !== "SignedOut"; const handleDrawerOpen = () => { setDrawerOpen(true); @@ -122,9 +124,11 @@ const Header: FC = ({ includeSidebar, data }) => { )} - - - + {showAvatar && ( + + + + )} = ({ data }) => { + return ( + + User has been logged out. + + ); +}; +export default SignedOutPage; diff --git a/ui/src/types/index.ts b/ui/src/types/index.ts index dbf62c207..d8158402f 100644 --- a/ui/src/types/index.ts +++ b/ui/src/types/index.ts @@ -132,6 +132,10 @@ export type SignOutConfirmPageData = BasePageData & { url: string; }; +export type SignedOutPageData = BasePageData & { + page: "SignedOut"; +}; + export type UserInfoPageData = BasePageData & UserInfoData & { page: "UserInfo"; @@ -150,6 +154,7 @@ export type PageData = | ErrorPageData | DeviceEnrolledPageData | SignOutConfirmPageData + | SignedOutPageData | UserInfoPageData | WebAuthnRegistrationPageData;