core/authenticate: refactor identity authenticators to initiate redirect (#4858)

* core/authenticate: refactor identity authenticators to initiate redirect, use cookie for redirect url for cognito

* set secure and http only, update test
This commit is contained in:
Caleb Doxsey 2023-12-19 12:04:23 -07:00 committed by GitHub
parent 4c15b202d1
commit 3adbc65d37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 237 additions and 125 deletions

View file

@ -223,14 +223,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
signOutURL = uri
} else if signOutRedirectURL != nil {
signOutURL = signOutRedirectURL.String()
} else {
signOutURL = authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/signed_out",
}).String()
}
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
signOutURL = idpSignOutURL
authenticateSignedOutURL := authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/signed_out",
}).String()
if err := authenticator.SignOut(w, r, rawIDToken, authenticateSignedOutURL, signOutURL); err == nil {
return nil
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
}
@ -275,12 +275,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := authenticator.GetSignInURL(encodedState)
err = authenticator.SignIn(w, r, encodedState)
if err != nil {
return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err))
fmt.Errorf("failed to sign in: %w", err))
}
httputil.Redirect(w, r, signinURL, http.StatusFound)
return nil
}

View file

@ -137,7 +137,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -150,7 +150,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"https://signout-redirect-url.example.com",
"sig",
"ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -163,7 +163,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -176,7 +176,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -189,7 +189,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"",
"sig",
"ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound,
"",
@ -401,7 +401,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
errors.New("hi"),
identity.MockProvider{},
http.StatusFound,
http.StatusOK,
},
{
"expired,refresh error",
@ -409,7 +409,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")},
http.StatusFound,
http.StatusOK,
},
{
"expired,save error",
@ -417,7 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusFound,
http.StatusOK,
},
{
"expired XHR,refresh error",

View file

@ -18,6 +18,12 @@ func (data SignedOutData) ToJSON() map[string]interface{} {
// SignedOut returns a handler that renders the signed out page.
func SignedOut(data SignedOutData) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if redirectURI, ok := httputil.GetSignedOutRedirectURICookie(w, r); ok {
httputil.Redirect(w, r, redirectURI, http.StatusFound)
return nil
}
// otherwise show the signed-out page
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
})
}

View file

@ -0,0 +1,42 @@
package handlers_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/handlers"
)
func TestSignedOut(t *testing.T) {
t.Parallel()
t.Run("ok", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("redirect", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
r.AddCookie(&http.Cookie{
Name: "_pomerium_signed_out_redirect_uri",
Value: "https://www.google.com",
})
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
assert.Equal(t, http.StatusFound, w.Code)
assert.Equal(t, "https://www.google.com", w.Header().Get("Location"))
})
}

View file

@ -0,0 +1,28 @@
package httputil
import "net/http"
const signedOutRedirectURICookieName = "_pomerium_signed_out_redirect_uri"
// GetSignedOutRedirectURICookie gets the redirect uri cookie for the signed-out page.
func GetSignedOutRedirectURICookie(w http.ResponseWriter, r *http.Request) (string, bool) {
cookie, err := r.Cookie(signedOutRedirectURICookieName)
if err != nil {
return "", false
}
cookie.MaxAge = -1
http.SetCookie(w, cookie)
return cookie.Value, true
}
// SetSignedOutRedirectURICookie sets the redirect uri cookie for the signed-out page.
func SetSignedOutRedirectURICookie(w http.ResponseWriter, redirectURI string) {
http.SetCookie(w, &http.Cookie{
Name: signedOutRedirectURICookieName,
Value: redirectURI,
MaxAge: 5 * 60,
HttpOnly: true,
Secure: true,
})
}

View file

@ -2,6 +2,7 @@ package identity
import (
"context"
"net/http"
"golang.org/x/oauth2"
@ -10,15 +11,14 @@ import (
// MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct {
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
GetSignInURLResponse string
GetSignOutURLResponse string
GetSignOutURLError error
UpdateUserInfoError error
AuthenticateResponse oauth2.Token
AuthenticateError error
RefreshResponse oauth2.Token
RefreshError error
RevokeError error
UpdateUserInfoError error
SignInError error
SignOutError error
}
// Authenticate is a mocked providers function.
@ -36,14 +36,6 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
return mp.RevokeError
}
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
// GetSignOutURL is a mocked providers function.
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}
// UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
return mp.UpdateUserInfoError
@ -53,3 +45,13 @@ func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ inte
func (mp MockProvider) Name() string {
return "mock"
}
// SignOut is a mocked providers function.
func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return mp.SignOutError
}
// SignIn is a mocked providers function.
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
return mp.SignInError
}

View file

@ -83,31 +83,6 @@ func (p *Provider) Name() string {
return Name
}
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) GetSignInURL(state string) (string, error) {
opts := []oauth2.AuthCodeOption{}
for k, v := range p.authCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
authURL := p.oauth.AuthCodeURL(state, opts...)
// Apple is very picky here and we need to use %20 instead of +
authURL = strings.ReplaceAll(authURL, "+", "%20")
return authURL, nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -181,3 +156,29 @@ func (p *Provider) UpdateUserInfo(_ context.Context, t *oauth2.Token, v interfac
return idToken.UnsafeClaimsWithoutVerification(v)
}
// SignIn redirects to the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
opts := []oauth2.AuthCodeOption{}
for k, v := range p.authCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
authURL := p.oauth.AuthCodeURL(state, opts...)
// Apple is very picky here and we need to use %20 instead of +
authURL = strings.ReplaceAll(authURL, "+", "%20")
httputil.Redirect(w, r, authURL, http.StatusFound)
return nil
}
// SignOut is not implemented.
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}

View file

@ -239,18 +239,20 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
return nil
}
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Name returns the provider name.
func (p *Provider) Name() string {
return Name
}
// SignIn redirects to the OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
signInURL := p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
httputil.Redirect(w, r, signInURL, http.StatusFound)
return nil
}
// SignOut is not implemented.
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}

View file

@ -6,9 +6,11 @@ package auth0
import (
"context"
"fmt"
"net/http"
"net/url"
"strings"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
@ -50,16 +52,16 @@ func (p *Provider) Name() string {
return Name
}
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
// SignOut implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, redirectToURL string) error {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
return fmt.Errorf("error getting auth0 oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
return fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
}
logoutQuery := url.Values{
@ -67,10 +69,14 @@ func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
}
if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL)
} else if authenticateSignedOutURL != "" {
logoutQuery.Set("returnTo", authenticateSignedOutURL)
}
logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout",
RawQuery: logoutQuery.Encode(),
})
return logoutURL.String(), nil
httputil.Redirect(w, r, logoutURL.String(), http.StatusFound)
return nil
}

View file

@ -53,9 +53,11 @@ func TestProvider(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
t.Run("SignOut", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", w.Header().Get("Location"))
})
}

View file

@ -4,8 +4,10 @@ package cognito
import (
"context"
"fmt"
"net/http"
"net/url"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil"
@ -51,27 +53,31 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
return &p, nil
}
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
// SignOut implements sign out according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, returnToURL string) error {
oa, err := p.GetOauthConfig()
if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
return fmt.Errorf("error getting cognito oauth config: %w", err)
}
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
return fmt.Errorf("error getting cognito endpoint auth url: %w", err)
}
logOutQuery := url.Values{
"client_id": []string{oa.ClientID},
}
if authenticateSignedOutURL != "" {
logOutQuery.Set("logout_uri", authenticateSignedOutURL)
}
if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL)
httputil.SetSignedOutRedirectURICookie(w, returnToURL)
}
logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout",
RawQuery: logOutQuery.Encode(),
})
return logOutURL.String(), nil
httputil.Redirect(w, r, logOutURL.String(), http.StatusFound)
return nil
}

View file

@ -53,9 +53,19 @@ func TestProvider(t *testing.T) {
require.NoError(t, err)
require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
t.Run("SignOut", func(t *testing.T) {
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
assert.NoError(t, err)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fauthenticate.example.com%2F.pomerium%2Fsigned_out", w.Header().Get("Location"))
assert.Equal(t, []*http.Cookie{{
Name: "_pomerium_signed_out_redirect_uri",
Value: "https://www.example.com?a=b",
MaxAge: 300,
Secure: true,
HttpOnly: true,
Raw: "_pomerium_signed_out_redirect_uri=https://www.example.com?a=b; Max-Age=300; HttpOnly; Secure",
}}, w.Result().Cookies())
})
}

View file

@ -96,56 +96,26 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e
return p, nil
}
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
// SignIn redirects to the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) GetSignInURL(state string) (string, error) {
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
oa, err := p.GetOauthConfig()
if err != nil {
return "", err
return err
}
opts := defaultAuthCodeOptions
for k, v := range p.AuthCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
return oa.AuthCodeURL(state, opts...), nil
}
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
_, err := p.GetProvider()
if err != nil {
return "", err
}
if p.EndSessionURL == "" {
return "", ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return "", err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
}
endSessionURL.RawQuery = params.Encode()
return endSessionURL.String(), nil
signInURL := oa.AuthCodeURL(state, opts...)
httputil.Redirect(w, r, signInURL, http.StatusFound)
return nil
}
// Authenticate converts an authorization code returned from the identity
@ -340,3 +310,38 @@ func (p *Provider) GetOauthConfig() (*oauth2.Config, error) {
}
return p.cfg.getOauthConfig(pp), nil
}
// SignOut uses the EndSessionURL endpoint to allow a logout session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error {
_, err := p.GetProvider()
if err != nil {
return err
}
if p.EndSessionURL == "" {
return ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
} else if authenticateSignedOutURL != "" {
params.Add("post_logout_redirect_uri", authenticateSignedOutURL)
}
endSessionURL.RawQuery = params.Encode()
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
return nil
}

View file

@ -5,6 +5,7 @@ package identity
import (
"context"
"fmt"
"net/http"
"golang.org/x/oauth2"
@ -28,10 +29,11 @@ type Authenticator interface {
Authenticate(context.Context, string, identity.State) (*oauth2.Token, error)
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) (string, error)
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
Name() string
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
}
// NewAuthenticator returns a new identity provider based on its name.