core/authenticate: refactor idp sign out (#4582)

This commit is contained in:
Caleb Doxsey 2023-09-28 08:41:19 -07:00 committed by GitHub
parent 7211a8d819
commit a0c92896ef
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 318 additions and 93 deletions

View file

@ -89,6 +89,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
// routes that don't need a session:
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
sr.Path("/signed_out").Handler(handlers.SignedOut(handlers.SignedOutData{})).Methods(http.MethodGet)
// routes that need a session:
sr = sr.NewRoute().Subrouter()
@ -266,33 +267,35 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r)
redirectString := ""
signOutURL, err := options.GetSignOutRedirectURL()
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return fmt.Errorf("error getting authenticate url: %w", err)
}
signOutRedirectURL, err := options.GetSignOutRedirectURL()
if err != nil {
return err
}
if signOutURL != nil {
redirectString = signOutURL.String()
}
var signOutURL string
if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" {
redirectString = uri
signOutURL = uri
} else if signOutRedirectURL != nil {
signOutURL = signOutRedirectURL.String()
} else {
signOutURL = authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/signed_out",
}).String()
}
endSessionURL, err := authenticator.LogOut()
if err == nil && redirectString != "" {
params := endSessionURL.Query()
params.Add("id_token_hint", rawIDToken)
params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
signOutURL = idpSignOutURL
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
}
if redirectString != "" {
httputil.Redirect(w, r, redirectString, http.StatusFound)
return nil
}
return httputil.NewError(http.StatusOK, errors.New("user logged out"))
httputil.Redirect(w, r, signOutURL, http.StatusFound)
return nil
}
// reauthenticateOrFail starts the authenticate process by redirecting the

View file

@ -135,7 +135,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -148,7 +148,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"https://signout-redirect-url.example.com",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -161,7 +161,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{RevokeError: errors.New("OH NO")},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -174,7 +174,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{RevokeError: errors.New("OH NO")},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -187,24 +187,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented},
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
},
{
"no redirect uri",
http.MethodPost,
nil,
"",
"",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusOK,
"{\"Status\":200}\n",
},
}
for _, tt := range tests {
tt := tt
@ -253,7 +240,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
}
if tt.signoutRedirectURL != "" {
loc := w.Header().Get("Location")
assert.Contains(t, loc, url.QueryEscape(tt.signoutRedirectURL))
assert.Contains(t, loc, tt.signoutRedirectURL)
}
})
}

View file

@ -0,0 +1,23 @@
package handlers
import (
"net/http"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/ui"
)
// SignedOutData is the data for the SignedOut page.
type SignedOutData struct{}
// ToJSON converts the data into a JSON map.
func (data SignedOutData) ToJSON() map[string]interface{} {
return map[string]interface{}{}
}
// SignedOut returns a handler that renders the signed out page.
func SignedOut(data SignedOutData) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
})
}

View file

@ -2,7 +2,6 @@ package identity
import (
"context"
"net/url"
"golang.org/x/oauth2"
@ -11,15 +10,15 @@ import (
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
LogOutResponse url.URL
LogOutError error
UpdateUserInfoError error
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
GetSignOutURLResponse string
GetSignOutURLError error
UpdateUserInfoError error
}
// Authenticate is a mocked providers function.
@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
// LogOut is a mocked providers function.
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
// GetSignOutURL is a mocked providers function.
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}
// UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {

View file

@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return authURL, nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
return oauth2Token, nil
}
// LogOut is not implemented by Apple.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
}
// Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil {

View file

@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// LogOut is not implemented by github.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Name returns the provider name.

View file

@ -6,10 +6,12 @@ package auth0
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
)
const (
@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
func (p *Provider) Name() string {
return Name
}
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
}
logoutQuery := url.Values{
"client_id": {oa.ClientID},
}
if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL)
}
logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout",
RawQuery: logoutQuery.Encode(),
})
return logoutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package auth0
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL + "/")
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: url.Values{
"client_id": []string{opts.ClientID},
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
}.Encode(),
}).String()
return &p, nil
}
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
}
logOutQuery := url.Values{
"client_id": []string{oa.ClientID},
}
if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL)
}
logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: logOutQuery.Encode(),
})
return logOutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package cognito
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL)
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return oa.AuthCodeURL(state, opts...), nil
}
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
_, err := p.GetProvider()
if err != nil {
return "", err
}
if p.EndSessionURL == "" {
return "", ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return "", err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
}
endSessionURL.RawQuery = params.Encode()
return endSessionURL.String(), nil
}
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
return nil
}
// LogOut returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) LogOut() (*url.URL, error) {
_, err := p.GetProvider()
if err != nil {
return nil, err
}
if p.EndSessionURL == "" {
return nil, ErrSignoutNotImplemented
}
return urlutil.ParseAndValidateURL(p.EndSessionURL)
}
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
func (p *Provider) GetSubject(v interface{}) (string, error) {
b, err := json.Marshal(v)

View file

@ -5,7 +5,6 @@ package identity
import (
"context"
"fmt"
"net/url"
"golang.org/x/oauth2"
@ -30,8 +29,8 @@ type Authenticator interface {
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) (string, error)
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
Name() string
LogOut() (*url.URL, error)
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
}

View file

@ -1,12 +1,13 @@
import Box from "@mui/material/Box";
import CssBaseline from "@mui/material/CssBaseline";
import { ThemeProvider } from "@mui/material/styles";
import React, {FC, useLayoutEffect} from "react";
import React, { FC, useLayoutEffect } from "react";
import ErrorPage from "./components/ErrorPage";
import Footer from "./components/Footer";
import Header from "./components/Header";
import SignOutConfirmPage from "./components/SignOutConfirmPage";
import SignedOutPage from "./components/SignedOutPage";
import { ToolbarOffset } from "./components/ToolbarOffset";
import UserInfoPage from "./components/UserInfoPage";
import WebAuthnRegistrationPage from "./components/WebAuthnRegistrationPage";
@ -27,6 +28,9 @@ const App: FC = () => {
case "SignOutConfirm":
body = <SignOutConfirmPage data={data} />;
break;
case "SignedOut":
body = <SignedOutPage data={data} />;
break;
case "DeviceEnrolled":
case "UserInfo":
body = <UserInfoPage data={data} />;
@ -38,18 +42,18 @@ const App: FC = () => {
useLayoutEffect(() => {
const favicon = document.getElementById(
'favicon'
"favicon"
) as HTMLAnchorElement | null;
if (favicon) {
favicon.href = data?.faviconUrl || '/.pomerium/favicon.ico';
favicon.href = data?.faviconUrl || "/.pomerium/favicon.ico";
}
const extraFaviconLinks = document.getElementsByClassName(
'pomerium_favicon'
"pomerium_favicon"
) as HTMLCollectionOf<HTMLAnchorElement> | null;
for (const link of extraFaviconLinks) {
link.style.display = data?.faviconUrl ? 'none' : '';
link.style.display = data?.faviconUrl ? "none" : "";
}
}, [])
}, []);
return (
<ThemeProvider theme={theme}>

View file

@ -57,6 +57,8 @@ const Header: FC<HeaderProps> = ({ includeSidebar, data }) => {
get(data, "user.claims.picture") ||
get(data, "profile.claims.picture") ||
null;
const showAvatar =
data?.page !== "SignOutConfirm" && data?.page !== "SignedOut";
const handleDrawerOpen = () => {
setDrawerOpen(true);
@ -122,9 +124,11 @@ const Header: FC<HeaderProps> = ({ includeSidebar, data }) => {
</a>
)}
<Box flexGrow={1} />
<IconButton color="inherit" onClick={handleMenuOpen}>
<Avatar name={userName} url={userPictureUrl} />
</IconButton>
{showAvatar && (
<IconButton color="inherit" onClick={handleMenuOpen}>
<Avatar name={userName} url={userPictureUrl} />
</IconButton>
)}
<Menu
onClose={handleMenuClose}
anchorOrigin={{

View file

@ -0,0 +1,16 @@
import { Alert } from "@mui/material";
import Container from "@mui/material/Container";
import React, { FC } from "react";
import { SignedOutPageData } from "src/types";
type SignedOutPageProps = {
data: SignedOutPageData;
};
const SignedOutPage: FC<SignedOutPageProps> = ({ data }) => {
return (
<Container>
<Alert color="info">User has been logged out.</Alert>
</Container>
);
};
export default SignedOutPage;

View file

@ -132,6 +132,10 @@ export type SignOutConfirmPageData = BasePageData & {
url: string;
};
export type SignedOutPageData = BasePageData & {
page: "SignedOut";
};
export type UserInfoPageData = BasePageData &
UserInfoData & {
page: "UserInfo";
@ -150,6 +154,7 @@ export type PageData =
| ErrorPageData
| DeviceEnrolledPageData
| SignOutConfirmPageData
| SignedOutPageData
| UserInfoPageData
| WebAuthnRegistrationPageData;