core/authenticate: refactor idp sign out (#4589)

core/authenticate: refactor idp sign out (#4582)

Co-authored-by: Caleb Doxsey <cdoxsey@pomerium.com>
This commit is contained in:
backport-actions-token[bot] 2023-09-28 08:52:22 -07:00 committed by GitHub
parent 57aead4eda
commit e6ef8b68cc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 318 additions and 93 deletions

View file

@ -2,7 +2,6 @@ package identity
import (
"context"
"net/url"
"golang.org/x/oauth2"
@ -11,15 +10,15 @@ import (
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
LogOutResponse url.URL
LogOutError error
UpdateUserInfoError error
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
GetSignOutURLResponse string
GetSignOutURLError error
UpdateUserInfoError error
}
// Authenticate is a mocked providers function.
@ -40,8 +39,10 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
// LogOut is a mocked providers function.
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError }
// GetSignOutURL is a mocked providers function.
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}
// UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {

View file

@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return authURL, nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
return oauth2Token, nil
}
// LogOut is not implemented by Apple.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
}
// Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil {

View file

@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// LogOut is not implemented by github.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Name returns the provider name.

View file

@ -6,10 +6,12 @@ package auth0
import (
"context"
"fmt"
"net/url"
"strings"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
)
const (
@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
func (p *Provider) Name() string {
return Name
}
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
}
logoutQuery := url.Values{
"client_id": {oa.ClientID},
}
if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL)
}
logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout",
RawQuery: logoutQuery.Encode(),
})
return logoutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package auth0
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL + "/")
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: url.Values{
"client_id": []string{opts.ClientID},
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
}.Encode(),
}).String()
return &p, nil
}
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
}
logOutQuery := url.Values{
"client_id": []string{oa.ClientID},
}
if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL)
}
logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: logOutQuery.Encode(),
})
return logOutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package cognito
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL)
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return oa.AuthCodeURL(state, opts...), nil
}
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
_, err := p.GetProvider()
if err != nil {
return "", err
}
if p.EndSessionURL == "" {
return "", ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return "", err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
}
endSessionURL.RawQuery = params.Encode()
return endSessionURL.String(), nil
}
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
return nil
}
// LogOut returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) LogOut() (*url.URL, error) {
_, err := p.GetProvider()
if err != nil {
return nil, err
}
if p.EndSessionURL == "" {
return nil, ErrSignoutNotImplemented
}
return urlutil.ParseAndValidateURL(p.EndSessionURL)
}
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a
func (p *Provider) GetSubject(v interface{}) (string, error) {
b, err := json.Marshal(v)

View file

@ -5,7 +5,6 @@ package identity
import (
"context"
"fmt"
"net/url"
"golang.org/x/oauth2"
@ -30,8 +29,8 @@ type Authenticator interface {
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) (string, error)
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
Name() string
LogOut() (*url.URL, error)
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
}