mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 16:59:22 +02:00
core/authenticate: refactor idp sign out (#4589)
core/authenticate: refactor idp sign out (#4582) Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
parent
57aead4eda
commit
e6ef8b68cc
16 changed files with 318 additions and 93 deletions
|
@ -2,7 +2,6 @@ package identity
|
|||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
@ -11,15 +10,15 @@ import (
|
|||
|
||||
// MockProvider provides a mocked implementation of the providers interface.
|
||||
type MockProvider struct {
|
||||
AuthenticateResponse oauth2.Token
|
||||
AuthenticateError error
|
||||
RefreshResponse oauth2.Token
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
LogOutResponse url.URL
|
||||
LogOutError error
|
||||
UpdateUserInfoError error
|
||||
AuthenticateResponse oauth2.Token
|
||||
AuthenticateError error
|
||||
RefreshResponse oauth2.Token
|
||||
RefreshError error
|
||||
RevokeError error
|
||||
GetSignInURLResponse string
|
||||
GetSignOutURLResponse string
|
||||
GetSignOutURLError error
|
||||
UpdateUserInfoError error
|
||||
}
|
||||
|
||||
// Authenticate is a mocked providers function.
|
||||
|
@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
|
|||
// GetSignInURL is a mocked providers function.
|
||||
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
|
||||
|
||||
// LogOut is a mocked providers function.
|
||||
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
|
||||
// GetSignOutURL is a mocked providers function.
|
||||
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
|
||||
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
|
||||
}
|
||||
|
||||
// UpdateUserInfo is a mocked providers function.
|
||||
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
|
||||
|
|
|
@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
|||
return authURL, nil
|
||||
}
|
||||
|
||||
// GetSignOutURL is not implemented.
|
||||
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
||||
return "", oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// Authenticate converts an authorization code returned from the identity
|
||||
// provider into a token which is then converted into a user session.
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||
|
@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
|
|||
return oauth2Token, nil
|
||||
}
|
||||
|
||||
// LogOut is not implemented by Apple.
|
||||
func (p *Provider) LogOut() (*url.URL, error) {
|
||||
return nil, oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// Refresh renews a user's session.
|
||||
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
|
||||
if t == nil {
|
||||
|
|
|
@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
|||
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
|
||||
}
|
||||
|
||||
// LogOut is not implemented by github.
|
||||
func (p *Provider) LogOut() (*url.URL, error) {
|
||||
return nil, oidc.ErrSignoutNotImplemented
|
||||
// GetSignOutURL is not implemented.
|
||||
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
||||
return "", oidc.ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
// Name returns the provider name.
|
||||
|
|
|
@ -6,10 +6,12 @@ package auth0
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
|
|||
func (p *Provider) Name() string {
|
||||
return Name
|
||||
}
|
||||
|
||||
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
|
||||
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
|
||||
}
|
||||
|
||||
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
|
||||
}
|
||||
|
||||
logoutQuery := url.Values{
|
||||
"client_id": {oa.ClientID},
|
||||
}
|
||||
if redirectToURL != "" {
|
||||
logoutQuery.Set("returnTo", redirectToURL)
|
||||
}
|
||||
logoutURL := authURL.ResolveReference(&url.URL{
|
||||
Path: "/v2/logout",
|
||||
RawQuery: logoutQuery.Encode(),
|
||||
})
|
||||
return logoutURL.String(), nil
|
||||
}
|
||||
|
|
61
internal/identity/oidc/auth0/auth0_test.go
Normal file
61
internal/identity/oidc/auth0/auth0_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package auth0
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
)
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
var srv *httptest.Server
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL, err := url.Parse(srv.URL + "/")
|
||||
require.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": baseURL.String(),
|
||||
"authorization_endpoint": srv.URL + "/authorize",
|
||||
})
|
||||
|
||||
default:
|
||||
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
|
||||
}
|
||||
})
|
||||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
redirectURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
|
||||
t.Run("GetSignOutURL", func(t *testing.T) {
|
||||
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
||||
})
|
||||
}
|
|
@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
|
|||
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
|
||||
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
|
||||
|
||||
// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
|
||||
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
|
||||
Path: "/logout",
|
||||
RawQuery: url.Values{
|
||||
"client_id": []string{opts.ClientID},
|
||||
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
|
||||
}.Encode(),
|
||||
}).String()
|
||||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
|
||||
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
|
||||
oa, err := p.GetOauthConfig()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
|
||||
}
|
||||
|
||||
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
|
||||
}
|
||||
|
||||
logOutQuery := url.Values{
|
||||
"client_id": []string{oa.ClientID},
|
||||
}
|
||||
if returnToURL != "" {
|
||||
logOutQuery.Set("logout_uri", returnToURL)
|
||||
}
|
||||
logOutURL := authURL.ResolveReference(&url.URL{
|
||||
Path: "/logout",
|
||||
RawQuery: logOutQuery.Encode(),
|
||||
})
|
||||
return logOutURL.String(), nil
|
||||
}
|
||||
|
|
61
internal/identity/oidc/cognito/cognito_test.go
Normal file
61
internal/identity/oidc/cognito/cognito_test.go
Normal file
|
@ -0,0 +1,61 @@
|
|||
package cognito
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||
)
|
||||
|
||||
func TestProvider(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
t.Cleanup(clearTimeout)
|
||||
|
||||
var srv *httptest.Server
|
||||
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
baseURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
switch r.URL.Path {
|
||||
case "/.well-known/openid-configuration":
|
||||
json.NewEncoder(w).Encode(map[string]any{
|
||||
"issuer": baseURL.String(),
|
||||
"authorization_endpoint": srv.URL + "/authorize",
|
||||
})
|
||||
|
||||
default:
|
||||
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
|
||||
}
|
||||
})
|
||||
srv = httptest.NewServer(handler)
|
||||
t.Cleanup(srv.Close)
|
||||
|
||||
redirectURL, err := url.Parse(srv.URL)
|
||||
require.NoError(t, err)
|
||||
|
||||
p, err := New(ctx, &oauth.Options{
|
||||
ProviderURL: srv.URL,
|
||||
RedirectURL: redirectURL,
|
||||
ClientID: "CLIENT_ID",
|
||||
ClientSecret: "CLIENT_SECRET",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, p)
|
||||
|
||||
t.Run("GetSignOutURL", func(t *testing.T) {
|
||||
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
||||
})
|
||||
}
|
|
@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
|
|||
return oa.AuthCodeURL(state, opts...), nil
|
||||
}
|
||||
|
||||
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
|
||||
// session to be initiated.
|
||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
|
||||
_, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if p.EndSessionURL == "" {
|
||||
return "", ErrSignoutNotImplemented
|
||||
}
|
||||
|
||||
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
params := endSessionURL.Query()
|
||||
if idTokenHint != "" {
|
||||
params.Add("id_token_hint", idTokenHint)
|
||||
}
|
||||
if oa, err := p.GetOauthConfig(); err == nil {
|
||||
params.Add("client_id", oa.ClientID)
|
||||
}
|
||||
if redirectToURL != "" {
|
||||
params.Add("post_logout_redirect_uri", redirectToURL)
|
||||
}
|
||||
endSessionURL.RawQuery = params.Encode()
|
||||
return endSessionURL.String(), nil
|
||||
}
|
||||
|
||||
// Authenticate converts an authorization code returned from the identity
|
||||
// provider into a token which is then converted into a user session.
|
||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||
|
@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// LogOut returns the EndSessionURL endpoint to allow a logout
|
||||
// session to be initiated.
|
||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||
func (p *Provider) LogOut() (*url.URL, error) {
|
||||
_, err := p.GetProvider()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if p.EndSessionURL == "" {
|
||||
return nil, ErrSignoutNotImplemented
|
||||
}
|
||||
return urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||
}
|
||||
|
||||
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
|
||||
func (p *Provider) GetSubject(v interface{}) (string, error) {
|
||||
b, err := json.Marshal(v)
|
||||
|
|
|
@ -5,7 +5,6 @@ package identity
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/url"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
|
@ -30,8 +29,8 @@ type Authenticator interface {
|
|||
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
||||
Revoke(context.Context, *oauth2.Token) error
|
||||
GetSignInURL(state string) (string, error)
|
||||
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
|
||||
Name() string
|
||||
LogOut() (*url.URL, error)
|
||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue