core/authenticate: refactor idp sign out

This commit is contained in:
Caleb Doxsey 2023-09-19 16:27:02 -06:00
parent 9088f07cc9
commit f649d9b1bc
11 changed files with 245 additions and 79 deletions

View file

@ -266,30 +266,26 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r) rawIDToken := a.revokeSession(ctx, w, r)
redirectString := "" signOutURL := ""
signOutURL, err := options.GetSignOutRedirectURL() signOutRedirectURL, err := options.GetSignOutRedirectURL()
if err != nil { if err != nil {
return err return err
} }
if signOutURL != nil { if signOutRedirectURL != nil {
redirectString = signOutURL.String() signOutURL = signOutRedirectURL.String()
} }
if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" { if uri := r.FormValue(urlutil.QueryRedirectURI); uri != "" {
redirectString = uri signOutURL = uri
} }
endSessionURL, err := authenticator.LogOut() if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
if err == nil && redirectString != "" { signOutURL = idpSignOutURL
params := endSessionURL.Query() } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
params.Add("id_token_hint", rawIDToken) log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
params.Add("post_logout_redirect_uri", redirectString)
endSessionURL.RawQuery = params.Encode()
redirectString = endSessionURL.String()
} else if err != nil && !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate.SignOut: failed getting session")
} }
if redirectString != "" {
httputil.Redirect(w, r, redirectString, http.StatusFound) if signOutURL != "" {
httputil.Redirect(w, r, signOutURL, http.StatusFound)
return nil return nil
} }
return httputil.NewError(http.StatusOK, errors.New("user logged out")) return httputil.NewError(http.StatusOK, errors.New("user logged out"))

View file

@ -135,7 +135,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -148,7 +148,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"https://signout-redirect-url.example.com", "https://signout-redirect-url.example.com",
"sig", "sig",
"ts", "ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))}, identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -161,7 +161,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{RevokeError: errors.New("OH NO")}, identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -174,7 +174,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{RevokeError: errors.New("OH NO")}, identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -187,24 +187,11 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{LogOutError: oidc.ErrSignoutNotImplemented}, identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
}, },
{
"no redirect uri",
http.MethodPost,
nil,
"",
"",
"sig",
"ts",
identity.MockProvider{LogOutResponse: (*uriParseHelper("https://microsoft.com"))},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusOK,
"{\"Status\":200}\n",
},
} }
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
@ -253,7 +240,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
} }
if tt.signoutRedirectURL != "" { if tt.signoutRedirectURL != "" {
loc := w.Header().Get("Location") loc := w.Header().Get("Location")
assert.Contains(t, loc, url.QueryEscape(tt.signoutRedirectURL)) assert.Contains(t, loc, tt.signoutRedirectURL)
} }
}) })
} }

View file

@ -2,7 +2,6 @@ package identity
import ( import (
"context" "context"
"net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -11,15 +10,15 @@ import (
// MockProvider provides a mocked implementation of the providers interface. // MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct { type MockProvider struct {
AuthenticateResponse oauth2.Token AuthenticateResponse oauth2.Token
AuthenticateError error AuthenticateError error
RefreshResponse oauth2.Token RefreshResponse oauth2.Token
RefreshError error RefreshError error
RevokeError error RevokeError error
GetSignInURLResponse string GetSignInURLResponse string
LogOutResponse url.URL GetSignOutURLResponse string
LogOutError error GetSignOutURLError error
UpdateUserInfoError error UpdateUserInfoError error
} }
// Authenticate is a mocked providers function. // Authenticate is a mocked providers function.
@ -41,7 +40,9 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil } func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
// LogOut is a mocked providers function. // LogOut is a mocked providers function.
func (mp MockProvider) LogOut() (*url.URL, error) { return &mp.LogOutResponse, mp.LogOutError } func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}
// UpdateUserInfo is a mocked providers function. // UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error { func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {

View file

@ -103,6 +103,11 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return authURL, nil return authURL, nil
} }
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Authenticate converts an authorization code returned from the identity // Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session. // provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -123,11 +128,6 @@ func (p *Provider) Authenticate(ctx context.Context, code string, v identity.Sta
return oauth2Token, nil return oauth2Token, nil
} }
// LogOut is not implemented by Apple.
func (p *Provider) LogOut() (*url.URL, error) {
return nil, oidc.ErrSignoutNotImplemented
}
// Refresh renews a user's session. // Refresh renews a user's session.
func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) { func (p *Provider) Refresh(ctx context.Context, t *oauth2.Token, v identity.State) (*oauth2.Token, error) {
if t == nil { if t == nil {

View file

@ -245,9 +245,9 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
} }
// LogOut is not implemented by github. // GetSignOutURL is not implemented.
func (p *Provider) LogOut() (*url.URL, error) { func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
return nil, oidc.ErrSignoutNotImplemented return "", oidc.ErrSignoutNotImplemented
} }
// Name returns the provider name. // Name returns the provider name.

View file

@ -6,10 +6,12 @@ package auth0
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"strings" "strings"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
) )
const ( const (
@ -47,3 +49,28 @@ func New(ctx context.Context, o *oauth.Options) (*Provider, error) {
func (p *Provider) Name() string { func (p *Provider) Name() string {
return Name return Name
} }
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
}
logoutQuery := url.Values{
"client_id": {oa.ClientID},
}
if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL)
}
logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout",
RawQuery: logoutQuery.Encode(),
})
return logoutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package auth0
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL + "/")
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -48,14 +48,30 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
// https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html // https://docs.aws.amazon.com/cognito/latest/developerguide/revocation-endpoint.html
p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String() p.RevocationURL = cognitoURL.ResolveReference(&url.URL{Path: "/oauth2/revoke"}).String()
// https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html
p.EndSessionURL = cognitoURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: url.Values{
"client_id": []string{opts.ClientID},
"logout_uri": []string{opts.RedirectURL.ResolveReference(&url.URL{Path: "/"}).String()},
}.Encode(),
}).String()
return &p, nil return &p, nil
} }
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(idTokenHint, returnToURL string) (string, error) {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
}
logOutQuery := url.Values{
"client_id": []string{oa.ClientID},
}
if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL)
}
logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: logOutQuery.Encode(),
})
return logOutURL.String(), nil
}

View file

@ -0,0 +1,61 @@
package cognito
import (
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/identity/oauth"
)
func TestProvider(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
var srv *httptest.Server
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
baseURL, err := url.Parse(srv.URL)
require.NoError(t, err)
w.Header().Set("Content-Type", "application/json")
switch r.URL.Path {
case "/.well-known/openid-configuration":
json.NewEncoder(w).Encode(map[string]any{
"issuer": baseURL.String(),
"authorization_endpoint": srv.URL + "/authorize",
})
default:
assert.Failf(t, "unexpected http request", "url: %s", r.URL.String())
}
})
srv = httptest.NewServer(handler)
t.Cleanup(srv.Close)
redirectURL, err := url.Parse(srv.URL)
require.NoError(t, err)
p, err := New(ctx, &oauth.Options{
ProviderURL: srv.URL,
RedirectURL: redirectURL,
ClientID: "CLIENT_ID",
ClientSecret: "CLIENT_SECRET",
})
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
})
}

View file

@ -116,6 +116,38 @@ func (p *Provider) GetSignInURL(state string) (string, error) {
return oa.AuthCodeURL(state, opts...), nil return oa.AuthCodeURL(state, opts...), nil
} }
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
_, err := p.GetProvider()
if err != nil {
return "", err
}
if p.EndSessionURL == "" {
return "", ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return "", err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
}
endSessionURL.RawQuery = params.Encode()
return endSessionURL.String(), nil
}
// Authenticate converts an authorization code returned from the identity // Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session. // provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -259,20 +291,6 @@ func (p *Provider) Revoke(ctx context.Context, t *oauth2.Token) error {
return nil return nil
} }
// LogOut returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) LogOut() (*url.URL, error) {
_, err := p.GetProvider()
if err != nil {
return nil, err
}
if p.EndSessionURL == "" {
return nil, ErrSignoutNotImplemented
}
return urlutil.ParseAndValidateURL(p.EndSessionURL)
}
// GetSubject gets the RFC 7519 Subject claim (`sub`) from a // GetSubject gets the RFC 7519 Subject claim (`sub`) from a
func (p *Provider) GetSubject(v interface{}) (string, error) { func (p *Provider) GetSubject(v interface{}) (string, error) {
b, err := json.Marshal(v) b, err := json.Marshal(v)

View file

@ -5,7 +5,6 @@ package identity
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -30,8 +29,8 @@ type Authenticator interface {
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) (string, error) GetSignInURL(state string) (string, error)
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
Name() string Name() string
LogOut() (*url.URL, error)
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
} }