core/authenticate: refactor identity authenticators to initiate redirect (#4858)

* core/authenticate: refactor identity authenticators to initiate redirect, use cookie for redirect url for cognito

* set secure and http only, update test
This commit is contained in:
Caleb Doxsey 2023-12-19 12:04:23 -07:00 committed by GitHub
parent 4c15b202d1
commit 3adbc65d37
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
14 changed files with 237 additions and 125 deletions

View file

@ -223,14 +223,14 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
signOutURL = uri signOutURL = uri
} else if signOutRedirectURL != nil { } else if signOutRedirectURL != nil {
signOutURL = signOutRedirectURL.String() signOutURL = signOutRedirectURL.String()
} else {
signOutURL = authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/signed_out",
}).String()
} }
if idpSignOutURL, err := authenticator.GetSignOutURL(rawIDToken, signOutURL); err == nil { authenticateSignedOutURL := authenticateURL.ResolveReference(&url.URL{
signOutURL = idpSignOutURL Path: "/.pomerium/signed_out",
}).String()
if err := authenticator.SignOut(w, r, rawIDToken, authenticateSignedOutURL, signOutURL); err == nil {
return nil
} else if !errors.Is(err, oidc.ErrSignoutNotImplemented) { } else if !errors.Is(err, oidc.ErrSignoutNotImplemented) {
log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator") log.Warn(r.Context()).Err(err).Msg("authenticate: failed to get sign out url for authenticator")
} }
@ -275,12 +275,12 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b) enc := cryptutil.Encrypt(state.cookieCipher, []byte(redirectURL.String()), b)
b = append(b, enc...) b = append(b, enc...)
encodedState := base64.URLEncoding.EncodeToString(b) encodedState := base64.URLEncoding.EncodeToString(b)
signinURL, err := authenticator.GetSignInURL(encodedState)
err = authenticator.SignIn(w, r, encodedState)
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, return httputil.NewError(http.StatusInternalServerError,
fmt.Errorf("failed to get sign in url: %w", err)) fmt.Errorf("failed to sign in: %w", err))
} }
httputil.Redirect(w, r, signinURL, http.StatusFound)
return nil return nil
} }

View file

@ -137,7 +137,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{GetSignOutURLResponse: "https://microsoft.com"}, identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -150,7 +150,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"https://signout-redirect-url.example.com", "https://signout-redirect-url.example.com",
"sig", "sig",
"ts", "ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -163,7 +163,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -176,7 +176,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")}, identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented, RevokeError: errors.New("OH NO")},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -189,7 +189,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
"", "",
"sig", "sig",
"ts", "ts",
identity.MockProvider{GetSignOutURLError: oidc.ErrSignoutNotImplemented}, identity.MockProvider{SignOutError: oidc.ErrSignoutNotImplemented},
&mstore.Store{Encrypted: true, Session: &sessions.State{}}, &mstore.Store{Encrypted: true, Session: &sessions.State{}},
http.StatusFound, http.StatusFound,
"", "",
@ -401,7 +401,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
errors.New("hi"), errors.New("hi"),
identity.MockProvider{}, identity.MockProvider{},
http.StatusFound, http.StatusOK,
}, },
{ {
"expired,refresh error", "expired,refresh error",
@ -409,7 +409,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, &mstore.Store{Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired, sessions.ErrExpired,
identity.MockProvider{RefreshError: errors.New("error")}, identity.MockProvider{RefreshError: errors.New("error")},
http.StatusFound, http.StatusOK,
}, },
{ {
"expired,save error", "expired,save error",
@ -417,7 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
&mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}}, &mstore.Store{SaveError: errors.New("error"), Session: &sessions.State{IdentityProviderID: idp.GetId(), ID: "xyz"}},
sessions.ErrExpired, sessions.ErrExpired,
identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}, identity.MockProvider{RefreshResponse: oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}},
http.StatusFound, http.StatusOK,
}, },
{ {
"expired XHR,refresh error", "expired XHR,refresh error",

View file

@ -18,6 +18,12 @@ func (data SignedOutData) ToJSON() map[string]interface{} {
// SignedOut returns a handler that renders the signed out page. // SignedOut returns a handler that renders the signed out page.
func SignedOut(data SignedOutData) http.Handler { func SignedOut(data SignedOutData) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if redirectURI, ok := httputil.GetSignedOutRedirectURICookie(w, r); ok {
httputil.Redirect(w, r, redirectURI, http.StatusFound)
return nil
}
// otherwise show the signed-out page
return ui.ServePage(w, r, "SignedOut", data.ToJSON()) return ui.ServePage(w, r, "SignedOut", data.ToJSON())
}) })
} }

View file

@ -0,0 +1,42 @@
package handlers_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/handlers"
)
func TestSignedOut(t *testing.T) {
t.Parallel()
t.Run("ok", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
assert.Equal(t, http.StatusOK, w.Code)
})
t.Run("redirect", func(t *testing.T) {
t.Parallel()
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "/.pomerium/signed_out", nil)
r.AddCookie(&http.Cookie{
Name: "_pomerium_signed_out_redirect_uri",
Value: "https://www.google.com",
})
handlers.SignedOut(handlers.SignedOutData{}).ServeHTTP(w, r)
assert.Equal(t, http.StatusFound, w.Code)
assert.Equal(t, "https://www.google.com", w.Header().Get("Location"))
})
}

View file

@ -0,0 +1,28 @@
package httputil
import "net/http"
const signedOutRedirectURICookieName = "_pomerium_signed_out_redirect_uri"
// GetSignedOutRedirectURICookie gets the redirect uri cookie for the signed-out page.
func GetSignedOutRedirectURICookie(w http.ResponseWriter, r *http.Request) (string, bool) {
cookie, err := r.Cookie(signedOutRedirectURICookieName)
if err != nil {
return "", false
}
cookie.MaxAge = -1
http.SetCookie(w, cookie)
return cookie.Value, true
}
// SetSignedOutRedirectURICookie sets the redirect uri cookie for the signed-out page.
func SetSignedOutRedirectURICookie(w http.ResponseWriter, redirectURI string) {
http.SetCookie(w, &http.Cookie{
Name: signedOutRedirectURICookieName,
Value: redirectURI,
MaxAge: 5 * 60,
HttpOnly: true,
Secure: true,
})
}

View file

@ -2,6 +2,7 @@ package identity
import ( import (
"context" "context"
"net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -10,15 +11,14 @@ import (
// MockProvider provides a mocked implementation of the providers interface. // MockProvider provides a mocked implementation of the providers interface.
type MockProvider struct { type MockProvider struct {
AuthenticateResponse oauth2.Token AuthenticateResponse oauth2.Token
AuthenticateError error AuthenticateError error
RefreshResponse oauth2.Token RefreshResponse oauth2.Token
RefreshError error RefreshError error
RevokeError error RevokeError error
GetSignInURLResponse string UpdateUserInfoError error
GetSignOutURLResponse string SignInError error
GetSignOutURLError error SignOutError error
UpdateUserInfoError error
} }
// Authenticate is a mocked providers function. // Authenticate is a mocked providers function.
@ -36,14 +36,6 @@ func (mp MockProvider) Revoke(_ context.Context, _ *oauth2.Token) error {
return mp.RevokeError return mp.RevokeError
} }
// GetSignInURL is a mocked providers function.
func (mp MockProvider) GetSignInURL(_ string) (string, error) { return mp.GetSignInURLResponse, nil }
// GetSignOutURL is a mocked providers function.
func (mp MockProvider) GetSignOutURL(_, _ string) (string, error) {
return mp.GetSignOutURLResponse, mp.GetSignOutURLError
}
// UpdateUserInfo is a mocked providers function. // UpdateUserInfo is a mocked providers function.
func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error { func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ interface{}) error {
return mp.UpdateUserInfoError return mp.UpdateUserInfoError
@ -53,3 +45,13 @@ func (mp MockProvider) UpdateUserInfo(_ context.Context, _ *oauth2.Token, _ inte
func (mp MockProvider) Name() string { func (mp MockProvider) Name() string {
return "mock" return "mock"
} }
// SignOut is a mocked providers function.
func (mp MockProvider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return mp.SignOutError
}
// SignIn is a mocked providers function.
func (mp MockProvider) SignIn(_ http.ResponseWriter, _ *http.Request, _ string) error {
return mp.SignInError
}

View file

@ -83,31 +83,6 @@ func (p *Provider) Name() string {
return Name return Name
} }
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) GetSignInURL(state string) (string, error) {
opts := []oauth2.AuthCodeOption{}
for k, v := range p.authCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
authURL := p.oauth.AuthCodeURL(state, opts...)
// Apple is very picky here and we need to use %20 instead of +
authURL = strings.ReplaceAll(authURL, "+", "%20")
return authURL, nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Authenticate converts an authorization code returned from the identity // Authenticate converts an authorization code returned from the identity
// provider into a token which is then converted into a user session. // provider into a token which is then converted into a user session.
func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) { func (p *Provider) Authenticate(ctx context.Context, code string, v identity.State) (*oauth2.Token, error) {
@ -181,3 +156,29 @@ func (p *Provider) UpdateUserInfo(_ context.Context, t *oauth2.Token, v interfac
return idToken.UnsafeClaimsWithoutVerification(v) return idToken.UnsafeClaimsWithoutVerification(v)
} }
// SignIn redirects to the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly.
//
// State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
opts := []oauth2.AuthCodeOption{}
for k, v := range p.authCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v))
}
authURL := p.oauth.AuthCodeURL(state, opts...)
// Apple is very picky here and we need to use %20 instead of +
authURL = strings.ReplaceAll(authURL, "+", "%20")
httputil.Redirect(w, r, authURL, http.StatusFound)
return nil
}
// SignOut is not implemented.
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}

View file

@ -239,18 +239,20 @@ func (p *Provider) Revoke(ctx context.Context, token *oauth2.Token) error {
return nil return nil
} }
// GetSignInURL returns a URL to OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
func (p *Provider) GetSignInURL(state string) (string, error) {
return p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline), nil
}
// GetSignOutURL is not implemented.
func (p *Provider) GetSignOutURL(_, _ string) (string, error) {
return "", oidc.ErrSignoutNotImplemented
}
// Name returns the provider name. // Name returns the provider name.
func (p *Provider) Name() string { func (p *Provider) Name() string {
return Name return Name
} }
// SignIn redirects to the OAuth 2.0 provider's consent page
// that asks for permissions for the required scopes explicitly.
func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
signInURL := p.Oauth.AuthCodeURL(state, oauth2.AccessTypeOffline)
httputil.Redirect(w, r, signInURL, http.StatusFound)
return nil
}
// SignOut is not implemented.
func (p *Provider) SignOut(_ http.ResponseWriter, _ *http.Request, _, _, _ string) error {
return oidc.ErrSignoutNotImplemented
}

View file

@ -6,9 +6,11 @@ package auth0
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"strings" "strings"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -50,16 +52,16 @@ func (p *Provider) Name() string {
return Name return Name
} }
// GetSignOutURL implements logout as described in https://auth0.com/docs/api/authentication#logout. // SignOut implements logout as described in https://auth0.com/docs/api/authentication#logout.
func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) { func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, redirectToURL string) error {
oa, err := p.GetOauthConfig() oa, err := p.GetOauthConfig()
if err != nil { if err != nil {
return "", fmt.Errorf("error getting auth0 oauth config: %w", err) return fmt.Errorf("error getting auth0 oauth config: %w", err)
} }
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error parsing auth0 endpoint auth url: %w", err) return fmt.Errorf("error parsing auth0 endpoint auth url: %w", err)
} }
logoutQuery := url.Values{ logoutQuery := url.Values{
@ -67,10 +69,14 @@ func (p *Provider) GetSignOutURL(_, redirectToURL string) (string, error) {
} }
if redirectToURL != "" { if redirectToURL != "" {
logoutQuery.Set("returnTo", redirectToURL) logoutQuery.Set("returnTo", redirectToURL)
} else if authenticateSignedOutURL != "" {
logoutQuery.Set("returnTo", authenticateSignedOutURL)
} }
logoutURL := authURL.ResolveReference(&url.URL{ logoutURL := authURL.ResolveReference(&url.URL{
Path: "/v2/logout", Path: "/v2/logout",
RawQuery: logoutQuery.Encode(), RawQuery: logoutQuery.Encode(),
}) })
return logoutURL.String(), nil
httputil.Redirect(w, r, logoutURL.String(), http.StatusFound)
return nil
} }

View file

@ -53,9 +53,11 @@ func TestProvider(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, p) require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) { t.Run("SignOut", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) assert.Equal(t, srv.URL+"/v2/logout?client_id=CLIENT_ID&returnTo=https%3A%2F%2Fwww.example.com%3Fa%3Db", w.Header().Get("Location"))
}) })
} }

View file

@ -4,8 +4,10 @@ package cognito
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"net/url" "net/url"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity/oauth" "github.com/pomerium/pomerium/internal/identity/oauth"
pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc" pom_oidc "github.com/pomerium/pomerium/internal/identity/oidc"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -51,27 +53,31 @@ func New(ctx context.Context, opts *oauth.Options) (*Provider, error) {
return &p, nil return &p, nil
} }
// GetSignOutURL gets the sign out URL according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html. // SignOut implements sign out according to https://docs.aws.amazon.com/cognito/latest/developerguide/logout-endpoint.html.
func (p *Provider) GetSignOutURL(_, returnToURL string) (string, error) { func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, _, authenticateSignedOutURL, returnToURL string) error {
oa, err := p.GetOauthConfig() oa, err := p.GetOauthConfig()
if err != nil { if err != nil {
return "", fmt.Errorf("error getting cognito oauth config: %w", err) return fmt.Errorf("error getting cognito oauth config: %w", err)
} }
authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL) authURL, err := urlutil.ParseAndValidateURL(oa.Endpoint.AuthURL)
if err != nil { if err != nil {
return "", fmt.Errorf("error getting cognito endpoint auth url: %w", err) return fmt.Errorf("error getting cognito endpoint auth url: %w", err)
} }
logOutQuery := url.Values{ logOutQuery := url.Values{
"client_id": []string{oa.ClientID}, "client_id": []string{oa.ClientID},
} }
if authenticateSignedOutURL != "" {
logOutQuery.Set("logout_uri", authenticateSignedOutURL)
}
if returnToURL != "" { if returnToURL != "" {
logOutQuery.Set("logout_uri", returnToURL) httputil.SetSignedOutRedirectURICookie(w, returnToURL)
} }
logOutURL := authURL.ResolveReference(&url.URL{ logOutURL := authURL.ResolveReference(&url.URL{
Path: "/logout", Path: "/logout",
RawQuery: logOutQuery.Encode(), RawQuery: logOutQuery.Encode(),
}) })
return logOutURL.String(), nil httputil.Redirect(w, r, logOutURL.String(), http.StatusFound)
return nil
} }

View file

@ -53,9 +53,19 @@ func TestProvider(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, p) require.NotNil(t, p)
t.Run("GetSignOutURL", func(t *testing.T) { t.Run("SignOut", func(t *testing.T) {
signOutURL, err := p.GetSignOutURL("", "https://www.example.com?a=b") w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodGet, "https://authenticate.example.com/.pomerium/sign_out", nil)
err := p.SignOut(w, r, "", "https://authenticate.example.com/.pomerium/signed_out", "https://www.example.com?a=b")
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fwww.example.com%3Fa%3Db", signOutURL) assert.Equal(t, srv.URL+"/logout?client_id=CLIENT_ID&logout_uri=https%3A%2F%2Fauthenticate.example.com%2F.pomerium%2Fsigned_out", w.Header().Get("Location"))
assert.Equal(t, []*http.Cookie{{
Name: "_pomerium_signed_out_redirect_uri",
Value: "https://www.example.com?a=b",
MaxAge: 300,
Secure: true,
HttpOnly: true,
Raw: "_pomerium_signed_out_redirect_uri=https://www.example.com?a=b; Max-Age=300; HttpOnly; Secure",
}}, w.Result().Cookies())
}) })
} }

View file

@ -96,56 +96,26 @@ func New(ctx context.Context, o *oauth.Options, options ...Option) (*Provider, e
return p, nil return p, nil
} }
// GetSignInURL returns the url of the provider's OAuth 2.0 consent page // SignIn redirects to the url of the provider's OAuth 2.0 consent page
// that asks for permissions for the required scopes explicitly. // that asks for permissions for the required scopes explicitly.
// //
// State is a token to protect the user from CSRF attacks. You must // State is a token to protect the user from CSRF attacks. You must
// always provide a non-empty string and validate that it matches the // always provide a non-empty string and validate that it matches the
// the state query parameter on your redirect callback. // the state query parameter on your redirect callback.
// See http://tools.ietf.org/html/rfc6749#section-10.12 for more info. // See http://tools.ietf.org/html/rfc6749#section-10.12 for more info.
func (p *Provider) GetSignInURL(state string) (string, error) { func (p *Provider) SignIn(w http.ResponseWriter, r *http.Request, state string) error {
oa, err := p.GetOauthConfig() oa, err := p.GetOauthConfig()
if err != nil { if err != nil {
return "", err return err
} }
opts := defaultAuthCodeOptions opts := defaultAuthCodeOptions
for k, v := range p.AuthCodeOptions { for k, v := range p.AuthCodeOptions {
opts = append(opts, oauth2.SetAuthURLParam(k, v)) opts = append(opts, oauth2.SetAuthURLParam(k, v))
} }
return oa.AuthCodeURL(state, opts...), nil signInURL := oa.AuthCodeURL(state, opts...)
} httputil.Redirect(w, r, signInURL, http.StatusFound)
return nil
// GetSignOutURL returns the EndSessionURL endpoint to allow a logout
// session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) GetSignOutURL(idTokenHint, redirectToURL string) (string, error) {
_, err := p.GetProvider()
if err != nil {
return "", err
}
if p.EndSessionURL == "" {
return "", ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return "", err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
}
endSessionURL.RawQuery = params.Encode()
return endSessionURL.String(), nil
} }
// Authenticate converts an authorization code returned from the identity // Authenticate converts an authorization code returned from the identity
@ -340,3 +310,38 @@ func (p *Provider) GetOauthConfig() (*oauth2.Config, error) {
} }
return p.cfg.getOauthConfig(pp), nil return p.cfg.getOauthConfig(pp), nil
} }
// SignOut uses the EndSessionURL endpoint to allow a logout session to be initiated.
// https://openid.net/specs/openid-connect-frontchannel-1_0.html#RPInitiated
func (p *Provider) SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error {
_, err := p.GetProvider()
if err != nil {
return err
}
if p.EndSessionURL == "" {
return ErrSignoutNotImplemented
}
endSessionURL, err := urlutil.ParseAndValidateURL(p.EndSessionURL)
if err != nil {
return err
}
params := endSessionURL.Query()
if idTokenHint != "" {
params.Add("id_token_hint", idTokenHint)
}
if oa, err := p.GetOauthConfig(); err == nil {
params.Add("client_id", oa.ClientID)
}
if redirectToURL != "" {
params.Add("post_logout_redirect_uri", redirectToURL)
} else if authenticateSignedOutURL != "" {
params.Add("post_logout_redirect_uri", authenticateSignedOutURL)
}
endSessionURL.RawQuery = params.Encode()
httputil.Redirect(w, r, endSessionURL.String(), http.StatusFound)
return nil
}

View file

@ -5,6 +5,7 @@ package identity
import ( import (
"context" "context"
"fmt" "fmt"
"net/http"
"golang.org/x/oauth2" "golang.org/x/oauth2"
@ -28,10 +29,11 @@ type Authenticator interface {
Authenticate(context.Context, string, identity.State) (*oauth2.Token, error) Authenticate(context.Context, string, identity.State) (*oauth2.Token, error)
Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error) Refresh(context.Context, *oauth2.Token, identity.State) (*oauth2.Token, error)
Revoke(context.Context, *oauth2.Token) error Revoke(context.Context, *oauth2.Token) error
GetSignInURL(state string) (string, error)
GetSignOutURL(idTokenHint, redirectToURL string) (string, error)
Name() string Name() string
UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error UpdateUserInfo(ctx context.Context, t *oauth2.Token, v interface{}) error
SignIn(w http.ResponseWriter, r *http.Request, state string) error
SignOut(w http.ResponseWriter, r *http.Request, idTokenHint, authenticateSignedOutURL, redirectToURL string) error
} }
// NewAuthenticator returns a new identity provider based on its name. // NewAuthenticator returns a new identity provider based on its name.