mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
core/authenticate: refactor identity authenticators to initiate redirect (#4858)
* core/authenticate: refactor identity authenticators to initiate redirect, use cookie for redirect url for cognito * set secure and http only, update test
This commit is contained in:
parent
4c15b202d1
commit
3adbc65d37
14 changed files with 237 additions and 125 deletions
|
@ -223,14 +223,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
signOutURL = uri
|
signOutURL = uri
|
||||||
} else if signOutRedirectURL != nil {
|
} else if signOutRedirectURL != nil {
|
||||||
signOutURL = signOutRedirectURL.String()
|
signOutURL = signOutRedirectURL.String()
|
||||||
} else {
|
|
||||||
signOutURL = authenticateURL.ResolveReference(&url.URL{
|
|
||||||
Path: "/.pomerium/signed_out",
|
|
||||||
}).String()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil {
|
authenticateSignedOutURL := authenticateURL.ResolveReference(&url.URL{
|
||||||
signOutURL = idpSignOutURL
|
Path: "/.pomerium/signed_out",
|
||||||
|
}).String()
|
||||||
|
|
||||||
|
if err := authenticator.SignOut(w, r, rawIDToken, authenticateSignedOutURL, signOutURL); err == nil {
|
||||||
|
return nil
|
||||||
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
|
||||||
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
|
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
|
||||||
}
|
}
|
||||||
|
@ -275,12 +275,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
|
||||||
b = append(b, enc...)
|
b = append(b, enc...)
|
||||||
encodedState := base64.URLEncoding.EncodeToString(b)
|
encodedState := base64.URLEncoding.EncodeToString(b)
|
||||||
signinURL, err := authenticator.GetSignInURL(encodedState)
|
|
||||||
|
err = authenticator.SignIn(w, r, encodedState)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError,
|
return httputil.NewError(http.StatusInternalServerError,
|
||||||
fmt.Errorf("failed to get sign in url: %w", err))
|
fmt.Errorf("failed to sign in: %w", err))
|
||||||
}
|
}
|
||||||
httputil.Redirect(w, r, signinURL, http.StatusFound)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -137,7 +137,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"},
|
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -150,7 +150,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"https://signout-redirect-url.example.com",
|
"https://signout-redirect-url.example.com",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
|
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -163,7 +163,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -176,7 +176,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -189,7 +189,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
"",
|
"",
|
||||||
"sig",
|
"sig",
|
||||||
"ts",
|
"ts",
|
||||||
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented},
|
identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{}},
|
||||||
http.StatusFound,
|
http.StatusFound,
|
||||||
"",
|
"",
|
||||||
|
@ -401,7 +401,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
errors.New("hi"),
|
errors.New("hi"),
|
||||||
identity.MockProvider{},
|
identity.MockProvider{},
|
||||||
http.StatusFound,
|
http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"expired,refresh error",
|
"expired,refresh error",
|
||||||
|
@ -409,7 +409,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
sessions.ErrExpired,
|
sessions.ErrExpired,
|
||||||
identity.MockProvider{RefreshError: errors.New("error")},
|
identity.MockProvider{RefreshError: errors.New("error")},
|
||||||
http.StatusFound,
|
http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"expired,save error",
|
"expired,save error",
|
||||||
|
@ -417,7 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
|
||||||
sessions.ErrExpired,
|
sessions.ErrExpired,
|
||||||
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
|
||||||
http.StatusFound,
|
http.StatusOK,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"expired XHR,refresh error",
|
"expired XHR,refresh error",
|
||||||
|
|
|
@ -18,6 +18,12 @@ func (data SignedOutData) ToJSON() map[string]interface{} {
|
||||||
// SignedOut returns a handler that renders the signed out page.
|
// SignedOut returns a handler that renders the signed out page.
|
||||||
func SignedOut(data SignedOutData) http.Handler {
|
func SignedOut(data SignedOutData) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if redirectURI, ok := httputil.GetSignedOutRedirectURICookie(w, r); ok {
|
||||||
|
httputil.Redirect(w, r, redirectURI, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// otherwise show the signed-out page
|
||||||
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
|
return ui.ServePage(w, r, "SignedOut", data.ToJSON())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
42
internal/handlers/signedout_test.go
Normal file
42
internal/handlers/signedout_test.go
Normal file
|
@ -0,0 +1,42 @@
|
||||||
|
package handlers_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignedOut(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
t.Run("ok", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
|
||||||
|
|
||||||
|
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusOK, w.Code)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("redirect", func(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
|
||||||
|
r.AddCookie(&http.Cookie{
|
||||||
|
Name: "_pomerium_signed_out_redirect_uri",
|
||||||
|
Value: "https://www.google.com",
|
||||||
|
})
|
||||||
|
|
||||||
|
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusFound, w.Code)
|
||||||
|
assert.Equal(t, "https://www.google.com", w.Header().Get("Location"))
|
||||||
|
})
|
||||||
|
}
|
28
internal/httputil/signedout.go
Normal file
28
internal/httputil/signedout.go
Normal file
|
@ -0,0 +1,28 @@
|
||||||
|
package httputil
|
||||||
|
|
||||||
|
import "net/http"
|
||||||
|
|
||||||
|
const signedOutRedirectURICookieName = "_pomerium_signed_out_redirect_uri"
|
||||||
|
|
||||||
|
// GetSignedOutRedirectURICookie gets the redirect uri cookie for the signed-out page.
|
||||||
|
func GetSignedOutRedirectURICookie(w http.ResponseWriter, r *http.Request) (string, bool) {
|
||||||
|
cookie, err := r.Cookie(signedOutRedirectURICookieName)
|
||||||
|
if err != nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
cookie.MaxAge = -1
|
||||||
|
http.SetCookie(w, cookie)
|
||||||
|
return cookie.Value, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSignedOutRedirectURICookie sets the redirect uri cookie for the signed-out page.
|
||||||
|
func SetSignedOutRedirectURICookie(w http.ResponseWriter, redirectURI string) {
|
||||||
|
http.SetCookie(w, &http.Cookie{
|
||||||
|
Name: signedOutRedirectURICookieName,
|
||||||
|
Value: redirectURI,
|
||||||
|
MaxAge: 5 * 60,
|
||||||
|
HttpOnly: true,
|
||||||
|
Secure: true,
|
||||||
|
})
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package identity
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
@ -15,10 +16,9 @@ type MockProvider struct {
|
||||||
RefreshResponse oauth2.Token
|
RefreshResponse oauth2.Token
|
||||||
RefreshError error
|
RefreshError error
|
||||||
RevokeError error
|
RevokeError error
|
||||||
GetSignInURLResponse string
|
|
||||||
GetSignOutURLResponse string
|
|
||||||
GetSignOutURLError error
|
|
||||||
UpdateUserInfoError error
|
UpdateUserInfoError error
|
||||||
|
SignInError error
|
||||||
|
SignOutError error
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate is a mocked providers function.
|
// Authenticate is a mocked providers function.
|
||||||
|
@ -36,14 +36,6 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
|
||||||
return mp.RevokeError
|
return mp.RevokeError
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignInURL is a mocked providers function.
|
|
||||||
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
|
|
||||||
|
|
||||||
// GetSignOutURL is a mocked providers function.
|
|
||||||
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
|
|
||||||
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateUserInfo is a mocked providers function.
|
// UpdateUserInfo is a mocked providers function.
|
||||||
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
|
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
|
||||||
return mp.UpdateUserInfoError
|
return mp.UpdateUserInfoError
|
||||||
|
@ -53,3 +45,13 @@ func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ inte
|
||||||
func (mp MockProvider) Name() string {
|
func (mp MockProvider) Name() string {
|
||||||
return "mock"
|
return "mock"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignOut is a mocked providers function.
|
||||||
|
func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||||
|
return mp.SignOutError
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignIn is a mocked providers function.
|
||||||
|
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
|
||||||
|
return mp.SignInError
|
||||||
|
}
|
||||||
|
|
|
@ -83,31 +83,6 @@ func (p *Provider) Name() string {
|
||||||
return Name
|
return Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
|
|
||||||
// that asks for permissions for the required scopes explicitly.
|
|
||||||
//
|
|
||||||
// State is a token to protect the user from CSRF attacks. You must
|
|
||||||
// always provide a non-empty string and validate that it matches the
|
|
||||||
// the state query parameter on your redirect callback.
|
|
||||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
|
||||||
func (p *Provider) GetSignInURL(state string) (string, error) {
|
|
||||||
opts := []oauth2.AuthCodeOption{}
|
|
||||||
for k, v := range p.authCodeOptions {
|
|
||||||
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
|
||||||
}
|
|
||||||
authURL := p.oauth.AuthCodeURL(state, opts...)
|
|
||||||
|
|
||||||
// Apple is very picky here and we need to use %20 instead of +
|
|
||||||
authURL = strings.ReplaceAll(authURL, "+", "%20")
|
|
||||||
|
|
||||||
return authURL, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSignOutURL is not implemented.
|
|
||||||
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
|
||||||
return "", oidc.ErrSignoutNotImplemented
|
|
||||||
}
|
|
||||||
|
|
||||||
// Authenticate converts an authorization code returned from the identity
|
// Authenticate converts an authorization code returned from the identity
|
||||||
// provider into a token which is then converted into a user session.
|
// provider into a token which is then converted into a user session.
|
||||||
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
|
||||||
|
@ -181,3 +156,29 @@ func (p *Provider) UpdateUserInfo(_ context.Context, t *oauth2.Token, v interfac
|
||||||
|
|
||||||
return idToken.UnsafeClaimsWithoutVerification(v)
|
return idToken.UnsafeClaimsWithoutVerification(v)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignIn redirects to the url of the provider's OAuth 2.0 consent page
|
||||||
|
// that asks for permissions for the required scopes explicitly.
|
||||||
|
//
|
||||||
|
// State is a token to protect the user from CSRF attacks. You must
|
||||||
|
// always provide a non-empty string and validate that it matches the
|
||||||
|
// the state query parameter on your redirect callback.
|
||||||
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||||
|
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
|
||||||
|
opts := []oauth2.AuthCodeOption{}
|
||||||
|
for k, v := range p.authCodeOptions {
|
||||||
|
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||||
|
}
|
||||||
|
authURL := p.oauth.AuthCodeURL(state, opts...)
|
||||||
|
|
||||||
|
// Apple is very picky here and we need to use %20 instead of +
|
||||||
|
authURL = strings.ReplaceAll(authURL, "+", "%20")
|
||||||
|
|
||||||
|
httputil.Redirect(w, r, authURL, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOut is not implemented.
|
||||||
|
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||||
|
return oidc.ErrSignoutNotImplemented
|
||||||
|
}
|
||||||
|
|
|
@ -239,18 +239,20 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
|
|
||||||
// that asks for permissions for the required scopes explicitly.
|
|
||||||
func (p *Provider) GetSignInURL(state string) (string, error) {
|
|
||||||
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetSignOutURL is not implemented.
|
|
||||||
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
|
|
||||||
return "", oidc.ErrSignoutNotImplemented
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the provider name.
|
// Name returns the provider name.
|
||||||
func (p *Provider) Name() string {
|
func (p *Provider) Name() string {
|
||||||
return Name
|
return Name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignIn redirects to the OAuth 2.0 provider's consent page
|
||||||
|
// that asks for permissions for the required scopes explicitly.
|
||||||
|
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
|
||||||
|
signInURL := p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
|
||||||
|
httputil.Redirect(w, r, signInURL, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignOut is not implemented.
|
||||||
|
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
|
||||||
|
return oidc.ErrSignoutNotImplemented
|
||||||
|
}
|
||||||
|
|
|
@ -6,9 +6,11 @@ package auth0
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -50,16 +52,16 @@ func (p *Provider) Name() string {
|
||||||
return Name
|
return Name
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout.
|
// SignOut implements logout as described in https://auth0.com/docs/api/authentication#logout.
|
||||||
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
|
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, redirectToURL string) error {
|
||||||
oa, err := p.GetOauthConfig()
|
oa, err := p.GetOauthConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error getting auth0 oauth config: %w", err)
|
return fmt.Errorf("error getting auth0 oauth config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
|
return fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logoutQuery := url.Values{
|
logoutQuery := url.Values{
|
||||||
|
@ -67,10 +69,14 @@ func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
|
||||||
}
|
}
|
||||||
if redirectToURL != "" {
|
if redirectToURL != "" {
|
||||||
logoutQuery.Set("returnTo", redirectToURL)
|
logoutQuery.Set("returnTo", redirectToURL)
|
||||||
|
} else if authenticateSignedOutURL != "" {
|
||||||
|
logoutQuery.Set("returnTo", authenticateSignedOutURL)
|
||||||
}
|
}
|
||||||
logoutURL := authURL.ResolveReference(&url.URL{
|
logoutURL := authURL.ResolveReference(&url.URL{
|
||||||
Path: "/v2/logout",
|
Path: "/v2/logout",
|
||||||
RawQuery: logoutQuery.Encode(),
|
RawQuery: logoutQuery.Encode(),
|
||||||
})
|
})
|
||||||
return logoutURL.String(), nil
|
|
||||||
|
httputil.Redirect(w, r, logoutURL.String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,9 +53,11 @@ func TestProvider(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, p)
|
require.NotNil(t, p)
|
||||||
|
|
||||||
t.Run("GetSignOutURL", func(t *testing.T) {
|
t.Run("SignOut", func(t *testing.T) {
|
||||||
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
|
||||||
|
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", w.Header().Get("Location"))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,8 +4,10 @@ package cognito
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oauth"
|
"github.com/pomerium/pomerium/internal/identity/oauth"
|
||||||
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -51,27 +53,31 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
|
||||||
return &p, nil
|
return &p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
|
// SignOut implements sign out according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
|
||||||
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) {
|
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, returnToURL string) error {
|
||||||
oa, err := p.GetOauthConfig()
|
oa, err := p.GetOauthConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error getting cognito oauth config: %w", err)
|
return fmt.Errorf("error getting cognito oauth config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err)
|
return fmt.Errorf("error getting cognito endpoint auth url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logOutQuery := url.Values{
|
logOutQuery := url.Values{
|
||||||
"client_id": []string{oa.ClientID},
|
"client_id": []string{oa.ClientID},
|
||||||
}
|
}
|
||||||
|
if authenticateSignedOutURL != "" {
|
||||||
|
logOutQuery.Set("logout_uri", authenticateSignedOutURL)
|
||||||
|
}
|
||||||
if returnToURL != "" {
|
if returnToURL != "" {
|
||||||
logOutQuery.Set("logout_uri", returnToURL)
|
httputil.SetSignedOutRedirectURICookie(w, returnToURL)
|
||||||
}
|
}
|
||||||
logOutURL := authURL.ResolveReference(&url.URL{
|
logOutURL := authURL.ResolveReference(&url.URL{
|
||||||
Path: "/logout",
|
Path: "/logout",
|
||||||
RawQuery: logOutQuery.Encode(),
|
RawQuery: logOutQuery.Encode(),
|
||||||
})
|
})
|
||||||
return logOutURL.String(), nil
|
httputil.Redirect(w, r, logOutURL.String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -53,9 +53,19 @@ func TestProvider(t *testing.T) {
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, p)
|
require.NotNil(t, p)
|
||||||
|
|
||||||
t.Run("GetSignOutURL", func(t *testing.T) {
|
t.Run("SignOut", func(t *testing.T) {
|
||||||
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b")
|
w := httptest.NewRecorder()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
|
||||||
|
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL)
|
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fauthenticate.example.com%2F.pomerium%2Fsigned_out", w.Header().Get("Location"))
|
||||||
|
assert.Equal(t, []*http.Cookie{{
|
||||||
|
Name: "_pomerium_signed_out_redirect_uri",
|
||||||
|
Value: "https://www.example.com?a=b",
|
||||||
|
MaxAge: 300,
|
||||||
|
Secure: true,
|
||||||
|
HttpOnly: true,
|
||||||
|
Raw: "_pomerium_signed_out_redirect_uri=https://www.example.com?a=b; Max-Age=300; HttpOnly; Secure",
|
||||||
|
}}, w.Result().Cookies())
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -96,56 +96,26 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e
|
||||||
return p, nil
|
return p, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
|
// SignIn redirects to the url of the provider's OAuth 2.0 consent page
|
||||||
// that asks for permissions for the required scopes explicitly.
|
// that asks for permissions for the required scopes explicitly.
|
||||||
//
|
//
|
||||||
// State is a token to protect the user from CSRF attacks. You must
|
// State is a token to protect the user from CSRF attacks. You must
|
||||||
// always provide a non-empty string and validate that it matches the
|
// always provide a non-empty string and validate that it matches the
|
||||||
// the state query parameter on your redirect callback.
|
// the state query parameter on your redirect callback.
|
||||||
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
|
||||||
func (p *Provider) GetSignInURL(state string) (string, error) {
|
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
|
||||||
oa, err := p.GetOauthConfig()
|
oa, err := p.GetOauthConfig()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
opts := defaultAuthCodeOptions
|
opts := defaultAuthCodeOptions
|
||||||
for k, v := range p.AuthCodeOptions {
|
for k, v := range p.AuthCodeOptions {
|
||||||
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
opts = append(opts, oauth2.SetAuthURLParam(k, v))
|
||||||
}
|
}
|
||||||
return oa.AuthCodeURL(state, opts...), nil
|
signInURL := oa.AuthCodeURL(state, opts...)
|
||||||
}
|
httputil.Redirect(w, r, signInURL, http.StatusFound)
|
||||||
|
return nil
|
||||||
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
|
|
||||||
// session to be initiated.
|
|
||||||
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
|
||||||
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
|
|
||||||
_, err := p.GetProvider()
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if p.EndSessionURL == "" {
|
|
||||||
return "", ErrSignoutNotImplemented
|
|
||||||
}
|
|
||||||
|
|
||||||
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
params := endSessionURL.Query()
|
|
||||||
if idTokenHint != "" {
|
|
||||||
params.Add("id_token_hint", idTokenHint)
|
|
||||||
}
|
|
||||||
if oa, err := p.GetOauthConfig(); err == nil {
|
|
||||||
params.Add("client_id", oa.ClientID)
|
|
||||||
}
|
|
||||||
if redirectToURL != "" {
|
|
||||||
params.Add("post_logout_redirect_uri", redirectToURL)
|
|
||||||
}
|
|
||||||
endSessionURL.RawQuery = params.Encode()
|
|
||||||
return endSessionURL.String(), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Authenticate converts an authorization code returned from the identity
|
// Authenticate converts an authorization code returned from the identity
|
||||||
|
@ -340,3 +310,38 @@ func (p *Provider) GetOauthConfig() (*oauth2.Config, error) {
|
||||||
}
|
}
|
||||||
return p.cfg.getOauthConfig(pp), nil
|
return p.cfg.getOauthConfig(pp), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SignOut uses the EndSessionURL endpoint to allow a logout session to be initiated.
|
||||||
|
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
|
||||||
|
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error {
|
||||||
|
_, err := p.GetProvider()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.EndSessionURL == "" {
|
||||||
|
return ErrSignoutNotImplemented
|
||||||
|
}
|
||||||
|
|
||||||
|
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
params := endSessionURL.Query()
|
||||||
|
if idTokenHint != "" {
|
||||||
|
params.Add("id_token_hint", idTokenHint)
|
||||||
|
}
|
||||||
|
if oa, err := p.GetOauthConfig(); err == nil {
|
||||||
|
params.Add("client_id", oa.ClientID)
|
||||||
|
}
|
||||||
|
if redirectToURL != "" {
|
||||||
|
params.Add("post_logout_redirect_uri", redirectToURL)
|
||||||
|
} else if authenticateSignedOutURL != "" {
|
||||||
|
params.Add("post_logout_redirect_uri", authenticateSignedOutURL)
|
||||||
|
}
|
||||||
|
endSessionURL.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -5,6 +5,7 @@ package identity
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
|
@ -28,10 +29,11 @@ type Authenticator interface {
|
||||||
Authenticate(context.Context, string, identity.State) (*oauth2.Token, error)
|
Authenticate(context.Context, string, identity.State) (*oauth2.Token, error)
|
||||||
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
|
||||||
Revoke(context.Context, *oauth2.Token) error
|
Revoke(context.Context, *oauth2.Token) error
|
||||||
GetSignInURL(state string) (string, error)
|
|
||||||
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
|
|
||||||
Name() string
|
Name() string
|
||||||
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
|
||||||
|
|
||||||
|
SignIn(w http.ResponseWriter, r *http.Request, state string) error
|
||||||
|
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewAuthenticator returns a new identity provider based on its name.
|
// NewAuthenticator returns a new identity provider based on its name.
|
||||||
|
|
Loading…
Add table
Reference in a new issue