mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-05 05:16:04 +02:00
core/authenticate: refactor idp sign out (#4582)
This commit is contained in:
parent
7211a8d819
commit
a0c92896ef
16 changed files with 318 additions and 93 deletions
|
@ -89,6 +89,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
|
||||||
|
|
||||||
// routes that don't need a session:
|
// routes that don't need a session:
|
||||||
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
sr.Path("/sign_out").Handler(httputil.HandlerFunc(a.SignOut))
|
||||||
|
sr.Path("/signed_out").Handler(handlers.SignedOut(handlers.SignedOutData{})).Methods(http.MethodGet)
|
||||||
|
|
||||||
// routes that need a session:
|
// routes that need a session:
|
||||||
sr = sr.NewRoute().Subrouter()
|
sr = sr.NewRoute().Subrouter()
|
||||||
|
@ -266,33 +267,35 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
|
|
||||||
rawIDToken := a.revokeSession(ctx, w, r)
|
rawIDToken := a.revokeSession(ctx, w, r)
|
||||||
|
|
||||||
redirectString := ""
|
authenticateURL, err := options.GetAuthenticateURL()
|
||||||
signOutURL, err := options.GetSignOutRedirectURL()
|
if err != nil {
|
||||||
|
return fmt.Errorf("error getting authenticate url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signOutRedirectURL, err := options.GetSignOutRedirectURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if signOutURL != nil {
|
|
||||||
redirectString = signOutURL.String()
|
var signOutURL string
|
||||||
}
|
|
||||||
if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" {
|
if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" {
|
||||||
redirectString = uri
|
signOutURL = uri
|
||||||
|
} else if signOutRedirectURL != nil {
|
||||||
|
signOutURL = signOutRedirectURL.String()
|
||||||
|
} else {
|
||||||
|
signOutURL = authenticateURL.ResolveReference(&url.URL{
|
||||||
|
Path: "/.pomerium/signed_out",
|
||||||
|
}).String()
|
||||||
}
|
}
|
||||||
|
|
||||||
endSessionURL, err := authenticator.LogOut()
|
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
|
||||||
if err == nil && redirectString != "" {
|
signOutURL = idpSignOutURL
|
||||||
params := endSessionURL.Query()
|
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||||
params.Add("id_token_hint", rawIDToken)
|
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
|
||||||
params.Add("post_logout_redirect_uri", redirectString)
|
|
||||||
endSessionURL.RawQuery = params.Encode()
|
|
||||||
redirectString = endSessionURL.String()
|
|
||||||
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
|
||||||
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
|
|
||||||
}
|
}
|
||||||
if redirectString != "" {
|
|
||||||
httputil.Redirect(w, r, redirectString, http.StatusFound)
|
httputil.Redirect(w, r, signOutURL, http.StatusFound)
|
||||||
return nil
|
return nil
|
||||||
}
|
|
||||||
return httputil.NewError(http.StatusOK, errors.New("user logged out"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// reauthenticateOrFail starts the authenticate process by redirecting the
|
// reauthenticateOrFail starts the authenticate process by redirecting the
|
||||||
|
|
|
@ -135,7 +135,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
|
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -148,7 +148,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"https://signout-redirect-url.example.com",
|
"https://signout-redirect-url.example.com",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
|
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -161,7 +161,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{RevokeError: errors.New("OH NO")},
|
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -174,7 +174,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{RevokeError: errors.New("OH NO")},
|
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -187,24 +187,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented},
|
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
},
|
},
|
||||||
{
|
|
||||||
"no redirect uri",
|
|
||||||
http.MethodPost,
|
|
||||||
nil,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
"sig",
|
|
||||||
"ts",
|
|
||||||
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
|
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
|
||||||
http.StatusOK,
|
|
||||||
"{\"Status\":200}\n",
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
|
@ -253,7 +240,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
}
|
}
|
||||||
if tt.signoutRedirectURL != "" {
|
if tt.signoutRedirectURL != "" {
|
||||||
loc := w.Header().Get("Location")
|
loc := w.Header().Get("Location")
|
||||||
assert.Contains(t, loc, url.QueryEscape(tt.signoutRedirectURL))
|
assert.Contains(t, loc, tt.signoutRedirectURL)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
23
internal/handlers/signedout.go
Normal file
23
internal/handlers/signedout.go
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/ui"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignedOutData is the data for the SignedOut page.
|
||||||
|
type SignedOutData struct{}
|
||||||
|
|
||||||
|
// ToJSON converts the data into a JSON map.
|
||||||
|
func (data SignedOutData) ToJSON() map[string]interface{} {
|
||||||
|
return map[string]interface{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignedOut returns a handler that renders the signed out page.
|
||||||
|
func SignedOut(data SignedOutData) http.Handler {
|
||||||
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
|
||||||
|
})
|
||||||
|
}
|
|
@ -2,7 +2,6 @@ package identity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
@ -11,15 +10,15 @@ import (
|
||||||
|
|
||||||
// MockProvider provides a mocked implementation of the providers interface.
|
// MockProvider provides a mocked implementation of the providers interface.
|
||||||
type MockProvider struct {
|
type MockProvider struct {
|
||||||
AuthenticateResponse oauth2.Token
|
AuthenticateResponse oauth2.Token
|
||||||
AuthenticateError error
|
AuthenticateError error
|
||||||
RefreshResponse oauth2.Token
|
RefreshResponse oauth2.Token
|
||||||
RefreshError error
|
RefreshError error
|
||||||
RevokeError error
|
RevokeError error
|
||||||
GetSignInURLResponse string
|
GetSignInURLResponse string
|
||||||
LogOutResponse url.URL
|
GetSignOutURLResponse string
|
||||||
LogOutError error
|
GetSignOutURLError error
|
||||||
UpdateUserInfoError error
|
UpdateUserInfoError error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate is a mocked providers function.
|
// Authenticate is a mocked providers function.
|
||||||
|
@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||||
// GetSignInURL is a mocked providers function.
|
// GetSignInURL is a mocked providers function.
|
||||||
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
|
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
|
||||||
|
|
||||||
// LogOut is a mocked providers function.
|
// GetSignOutURL is a mocked providers function.
|
||||||
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
|
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
|
||||||
|
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
|
||||||
|
}
|
||||||
|
|
||||||
// UpdateUserInfo is a mocked providers function.
|
// UpdateUserInfo is a mocked providers function.
|
||||||
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
|
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
|
||||||
|
|
|
@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
||||||
return authURL, nil
|
return authURL, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignOutURL is not implemented.
|
||||||
|
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
||||||
|
return "", oidc.ErrSignoutNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticate converts an authorization code returned from the identity
|
// Authenticate converts an authorization code returned from the identity
|
||||||
// provider into a token which is then converted into a user session.
|
// provider into a token which is then converted into a user session.
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||||
|
@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
||||||
return oauth2Token, nil
|
return oauth2Token, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogOut is not implemented by Apple.
|
|
||||||
func (p *Provider) LogOut() (*url.URL, error) {
|
|
||||||
return nil, oidc.ErrSignoutNotImplemented
|
|
||||||
}
|
|
||||||
|
|
||||||
// Refresh renews a user's session.
|
// Refresh renews a user's session.
|
||||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
|
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
|
||||||
if t == nil {
|
if t == nil {
|
||||||
|
|
|
@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
||||||
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
|
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogOut is not implemented by github.
|
// GetSignOutURL is not implemented.
|
||||||
func (p *Provider) LogOut() (*url.URL, error) {
|
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
||||||
return nil, oidc.ErrSignoutNotImplemented
|
return "", oidc.ErrSignoutNotImplemented
|
||||||
}
|
}
|
||||||
|
|
||||||
// Name returns the provider name.
|
// Name returns the provider name.
|
||||||
|
|
|
@ -6,10 +6,12 @@ package auth0
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
||||||
func (p *Provider) Name() string {
|
func (p *Provider) Name() string {
|
||||||
return Name
|
return Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
|
||||||
|
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
|
||||||
|
oa, err := p.GetOauthConfig()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logoutQuery := url.Values{
|
||||||
|
"client_id": {oa.ClientID},
|
||||||
|
}
|
||||||
|
if redirectToURL != "" {
|
||||||
|
logoutQuery.Set("returnTo", redirectToURL)
|
||||||
|
}
|
||||||
|
logoutURL := authURL.ResolveReference(&url.URL{
|
||||||
|
Path: "/v2/logout",
|
||||||
|
RawQuery: logoutQuery.Encode(),
|
||||||
|
})
|
||||||
|
return logoutURL.String(), nil
|
||||||
|
}
|
||||||
|
|
61
internal/identity/oidc/auth0/auth0_test.go
Normal file
61
internal/identity/oidc/auth0/auth0_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package auth0
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProvider(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
var srv *httptest.Server
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
baseURL, err := url.Parse(srv.URL + "/")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/.well-known/openid-configuration":
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"issuer": baseURL.String(),
|
||||||
|
"authorization_endpoint": srv.URL + "/authorize",
|
||||||
|
})
|
||||||
|
|
||||||
|
default:
|
||||||
|
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
srv = httptest.NewServer(handler)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
redirectURL, err := url.Parse(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
p, err := New(ctx, &oauth.Options{
|
||||||
|
ProviderURL: srv.URL,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, p)
|
||||||
|
|
||||||
|
t.Run("GetSignOutURL", func(t *testing.T) {
|
||||||
|
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
||||||
|
})
|
||||||
|
}
|
|
@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
|
||||||
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
|
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
|
||||||
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
|
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
|
||||||
|
|
||||||
// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
|
|
||||||
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
|
|
||||||
Path: "/logout",
|
|
||||||
RawQuery: url.Values{
|
|
||||||
"client_id": []string{opts.ClientID},
|
|
||||||
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
|
|
||||||
}.Encode(),
|
|
||||||
}).String()
|
|
||||||
|
|
||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
|
||||||
|
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
|
||||||
|
oa, err := p.GetOauthConfig()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logOutQuery := url.Values{
|
||||||
|
"client_id": []string{oa.ClientID},
|
||||||
|
}
|
||||||
|
if returnToURL != "" {
|
||||||
|
logOutQuery.Set("logout_uri", returnToURL)
|
||||||
|
}
|
||||||
|
logOutURL := authURL.ResolveReference(&url.URL{
|
||||||
|
Path: "/logout",
|
||||||
|
RawQuery: logOutQuery.Encode(),
|
||||||
|
})
|
||||||
|
return logOutURL.String(), nil
|
||||||
|
}
|
||||||
|
|
61
internal/identity/oidc/cognito/cognito_test.go
Normal file
61
internal/identity/oidc/cognito/cognito_test.go
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
package cognito
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProvider(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
var srv *httptest.Server
|
||||||
|
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
baseURL, err := url.Parse(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
switch r.URL.Path {
|
||||||
|
case "/.well-known/openid-configuration":
|
||||||
|
json.NewEncoder(w).Encode(map[string]any{
|
||||||
|
"issuer": baseURL.String(),
|
||||||
|
"authorization_endpoint": srv.URL + "/authorize",
|
||||||
|
})
|
||||||
|
|
||||||
|
default:
|
||||||
|
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
srv = httptest.NewServer(handler)
|
||||||
|
t.Cleanup(srv.Close)
|
||||||
|
|
||||||
|
redirectURL, err := url.Parse(srv.URL)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
p, err := New(ctx, &oauth.Options{
|
||||||
|
ProviderURL: srv.URL,
|
||||||
|
RedirectURL: redirectURL,
|
||||||
|
ClientID: "CLIENT_ID",
|
||||||
|
ClientSecret: "CLIENT_SECRET",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, p)
|
||||||
|
|
||||||
|
t.Run("GetSignOutURL", func(t *testing.T) {
|
||||||
|
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
||||||
|
})
|
||||||
|
}
|
|
@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
||||||
return oa.AuthCodeURL(state, opts...), nil
|
return oa.AuthCodeURL(state, opts...), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
|
||||||
|
// session to be initiated.
|
||||||
|
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||||
|
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
|
||||||
|
_, err := p.GetProvider()
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.EndSessionURL == "" {
|
||||||
|
return "", ErrSignoutNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := endSessionURL.Query()
|
||||||
|
if idTokenHint != "" {
|
||||||
|
params.Add("id_token_hint", idTokenHint)
|
||||||
|
}
|
||||||
|
if oa, err := p.GetOauthConfig(); err == nil {
|
||||||
|
params.Add("client_id", oa.ClientID)
|
||||||
|
}
|
||||||
|
if redirectToURL != "" {
|
||||||
|
params.Add("post_logout_redirect_uri", redirectToURL)
|
||||||
|
}
|
||||||
|
endSessionURL.RawQuery = params.Encode()
|
||||||
|
return endSessionURL.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// Authenticate converts an authorization code returned from the identity
|
// Authenticate converts an authorization code returned from the identity
|
||||||
// provider into a token which is then converted into a user session.
|
// provider into a token which is then converted into a user session.
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||||
|
@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// LogOut returns the EndSessionURL endpoint to allow a logout
|
|
||||||
// session to be initiated.
|
|
||||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
|
||||||
func (p *Provider) LogOut() (*url.URL, error) {
|
|
||||||
_, err := p.GetProvider()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if p.EndSessionURL == "" {
|
|
||||||
return nil, ErrSignoutNotImplemented
|
|
||||||
}
|
|
||||||
return urlutil.ParseAndValidateURL(p.EndSessionURL)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
|
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
|
||||||
func (p *Provider) GetSubject(v interface{}) (string, error) {
|
func (p *Provider) GetSubject(v interface{}) (string, error) {
|
||||||
b, err := json.Marshal(v)
|
b, err := json.Marshal(v)
|
||||||
|
|
|
@ -5,7 +5,6 @@ package identity
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/url"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
@ -30,8 +29,8 @@ type Authenticator interface {
|
||||||
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
||||||
Revoke(context.Context, *oauth2.Token) error
|
Revoke(context.Context, *oauth2.Token) error
|
||||||
GetSignInURL(state string) (string, error)
|
GetSignInURL(state string) (string, error)
|
||||||
|
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
|
||||||
Name() string
|
Name() string
|
||||||
LogOut() (*url.URL, error)
|
|
||||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,12 +1,13 @@
|
||||||
import Box from "@mui/material/Box";
|
import Box from "@mui/material/Box";
|
||||||
import CssBaseline from "@mui/material/CssBaseline";
|
import CssBaseline from "@mui/material/CssBaseline";
|
||||||
import { ThemeProvider } from "@mui/material/styles";
|
import { ThemeProvider } from "@mui/material/styles";
|
||||||
import React, {FC, useLayoutEffect} from "react";
|
import React, { FC, useLayoutEffect } from "react";
|
||||||
|
|
||||||
import ErrorPage from "./components/ErrorPage";
|
import ErrorPage from "./components/ErrorPage";
|
||||||
import Footer from "./components/Footer";
|
import Footer from "./components/Footer";
|
||||||
import Header from "./components/Header";
|
import Header from "./components/Header";
|
||||||
import SignOutConfirmPage from "./components/SignOutConfirmPage";
|
import SignOutConfirmPage from "./components/SignOutConfirmPage";
|
||||||
|
import SignedOutPage from "./components/SignedOutPage";
|
||||||
import { ToolbarOffset } from "./components/ToolbarOffset";
|
import { ToolbarOffset } from "./components/ToolbarOffset";
|
||||||
import UserInfoPage from "./components/UserInfoPage";
|
import UserInfoPage from "./components/UserInfoPage";
|
||||||
import WebAuthnRegistrationPage from "./components/WebAuthnRegistrationPage";
|
import WebAuthnRegistrationPage from "./components/WebAuthnRegistrationPage";
|
||||||
|
@ -27,6 +28,9 @@ const App: FC = () => {
|
||||||
case "SignOutConfirm":
|
case "SignOutConfirm":
|
||||||
body = <SignOutConfirmPage data={data} />;
|
body = <SignOutConfirmPage data={data} />;
|
||||||
break;
|
break;
|
||||||
|
case "SignedOut":
|
||||||
|
body = <SignedOutPage data={data} />;
|
||||||
|
break;
|
||||||
case "DeviceEnrolled":
|
case "DeviceEnrolled":
|
||||||
case "UserInfo":
|
case "UserInfo":
|
||||||
body = <UserInfoPage data={data} />;
|
body = <UserInfoPage data={data} />;
|
||||||
|
@ -38,18 +42,18 @@ const App: FC = () => {
|
||||||
|
|
||||||
useLayoutEffect(() => {
|
useLayoutEffect(() => {
|
||||||
const favicon = document.getElementById(
|
const favicon = document.getElementById(
|
||||||
'favicon'
|
"favicon"
|
||||||
) as HTMLAnchorElement | null;
|
) as HTMLAnchorElement | null;
|
||||||
if (favicon) {
|
if (favicon) {
|
||||||
favicon.href = data?.faviconUrl || '/.pomerium/favicon.ico';
|
favicon.href = data?.faviconUrl || "/.pomerium/favicon.ico";
|
||||||
}
|
}
|
||||||
const extraFaviconLinks = document.getElementsByClassName(
|
const extraFaviconLinks = document.getElementsByClassName(
|
||||||
'pomerium_favicon'
|
"pomerium_favicon"
|
||||||
) as HTMLCollectionOf<HTMLAnchorElement> | null;
|
) as HTMLCollectionOf<HTMLAnchorElement> | null;
|
||||||
for (const link of extraFaviconLinks) {
|
for (const link of extraFaviconLinks) {
|
||||||
link.style.display = data?.faviconUrl ? 'none' : '';
|
link.style.display = data?.faviconUrl ? "none" : "";
|
||||||
}
|
}
|
||||||
}, [])
|
}, []);
|
||||||
|
|
||||||
return (
|
return (
|
||||||
<ThemeProvider theme={theme}>
|
<ThemeProvider theme={theme}>
|
||||||
|
|
|
@ -57,6 +57,8 @@ const Header: FC<HeaderProps> = ({ includeSidebar, data }) => {
|
||||||
get(data, "user.claims.picture") ||
|
get(data, "user.claims.picture") ||
|
||||||
get(data, "profile.claims.picture") ||
|
get(data, "profile.claims.picture") ||
|
||||||
null;
|
null;
|
||||||
|
const showAvatar =
|
||||||
|
data?.page !== "SignOutConfirm" && data?.page !== "SignedOut";
|
||||||
|
|
||||||
const handleDrawerOpen = () => {
|
const handleDrawerOpen = () => {
|
||||||
setDrawerOpen(true);
|
setDrawerOpen(true);
|
||||||
|
@ -122,9 +124,11 @@ const Header: FC<HeaderProps> = ({ includeSidebar, data }) => {
|
||||||
</a>
|
</a>
|
||||||
)}
|
)}
|
||||||
<Box flexGrow={1} />
|
<Box flexGrow={1} />
|
||||||
<IconButton color="inherit" onClick={handleMenuOpen}>
|
{showAvatar && (
|
||||||
<Avatar name={userName} url={userPictureUrl} />
|
<IconButton color="inherit" onClick={handleMenuOpen}>
|
||||||
</IconButton>
|
<Avatar name={userName} url={userPictureUrl} />
|
||||||
|
</IconButton>
|
||||||
|
)}
|
||||||
<Menu
|
<Menu
|
||||||
onClose={handleMenuClose}
|
onClose={handleMenuClose}
|
||||||
anchorOrigin={{
|
anchorOrigin={{
|
||||||
|
|
16
ui/src/components/SignedOutPage.tsx
Normal file
16
ui/src/components/SignedOutPage.tsx
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
import { Alert } from "@mui/material";
|
||||||
|
import Container from "@mui/material/Container";
|
||||||
|
import React, { FC } from "react";
|
||||||
|
import { SignedOutPageData } from "src/types";
|
||||||
|
|
||||||
|
type SignedOutPageProps = {
|
||||||
|
data: SignedOutPageData;
|
||||||
|
};
|
||||||
|
const SignedOutPage: FC<SignedOutPageProps> = ({ data }) => {
|
||||||
|
return (
|
||||||
|
<Container>
|
||||||
|
<Alert color="info">User has been logged out.</Alert>
|
||||||
|
</Container>
|
||||||
|
);
|
||||||
|
};
|
||||||
|
export default SignedOutPage;
|
|
@ -132,6 +132,10 @@ export type SignOutConfirmPageData = BasePageData & {
|
||||||
url: string;
|
url: string;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
export type SignedOutPageData = BasePageData & {
|
||||||
|
page: "SignedOut";
|
||||||
|
};
|
||||||
|
|
||||||
export type UserInfoPageData = BasePageData &
|
export type UserInfoPageData = BasePageData &
|
||||||
UserInfoData & {
|
UserInfoData & {
|
||||||
page: "UserInfo";
|
page: "UserInfo";
|
||||||
|
@ -150,6 +154,7 @@ export type PageData =
|
||||||
| ErrorPageData
|
| ErrorPageData
|
||||||
| DeviceEnrolledPageData
|
| DeviceEnrolledPageData
|
||||||
| SignOutConfirmPageData
|
| SignOutConfirmPageData
|
||||||
|
| SignedOutPageData
|
||||||
| UserInfoPageData
|
| UserInfoPageData
|
||||||
| WebAuthnRegistrationPageData;
|
| WebAuthnRegistrationPageData;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue