mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-04 01:09:36 +02:00
authenticateflow: move stateless flow logic (#4820)
Consolidate all logic specific to the stateless authenticate flow into a a new Stateless type in a new package internal/authenticateflow. This is in preparation for adding a new Stateful type implementing the older stateful authenticate flow (from Pomerium v0.20 and previous). This change is intended as a pure refactoring of existing logic, with no changes in functionality.
This commit is contained in:
parent
3b2bdd059a
commit
b7896b3153
18 changed files with 823 additions and 461 deletions
|
@ -45,15 +45,16 @@ type Authenticate struct {
|
|||
|
||||
// New validates and creates a new authenticate service from a set of Options.
|
||||
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||
authenticateConfig := getAuthenticateConfig(options...)
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(options...),
|
||||
cfg: authenticateConfig,
|
||||
options: config.NewAtomicOptions(),
|
||||
state: atomicutil.NewValue(newAuthenticateState()),
|
||||
}
|
||||
|
||||
a.options.Store(cfg.Options)
|
||||
|
||||
state, err := newAuthenticateStateFromConfig(cfg)
|
||||
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
}
|
||||
|
||||
a.options.Store(cfg.Options)
|
||||
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
||||
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
||||
} else {
|
||||
a.state.Store(state)
|
||||
|
|
|
@ -2,6 +2,7 @@ package authenticate
|
|||
|
||||
import (
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
@ -9,7 +10,7 @@ import (
|
|||
type authenticateConfig struct {
|
||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
||||
profileTrimFn func(*identitypb.Profile)
|
||||
authEventFn AuthEventFn
|
||||
authEventFn authenticateflow.AuthEventFn
|
||||
}
|
||||
|
||||
// An Option customizes the Authenticate config.
|
||||
|
@ -39,7 +40,7 @@ func WithProfileTrimFn(profileTrimFn func(*identitypb.Profile)) Option {
|
|||
}
|
||||
|
||||
// WithOnAuthenticationEventHook sets the authEventFn function in the config
|
||||
func WithOnAuthenticationEventHook(fn AuthEventFn) Option {
|
||||
func WithOnAuthenticationEventHook(fn authenticateflow.AuthEventFn) Option {
|
||||
return func(cfg *authenticateConfig) {
|
||||
cfg.authEventFn = fn
|
||||
}
|
||||
|
|
|
@ -1,101 +0,0 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// AuthEventKind is the type of an authentication event
|
||||
type AuthEventKind string
|
||||
|
||||
const (
|
||||
// AuthEventSignInRequest is an authentication event for a sign in request before IdP redirect
|
||||
AuthEventSignInRequest AuthEventKind = "sign_in_request"
|
||||
// AuthEventSignInComplete is an authentication event for a sign in request after IdP redirect
|
||||
AuthEventSignInComplete AuthEventKind = "sign_in_complete"
|
||||
)
|
||||
|
||||
// AuthEvent is a log event for an authentication event
|
||||
type AuthEvent struct {
|
||||
// Event is the type of authentication event
|
||||
Event AuthEventKind
|
||||
// IP is the IP address of the client
|
||||
IP string
|
||||
// Version is the version of the Pomerium client
|
||||
Version string
|
||||
// RequestUUID is the UUID of the request
|
||||
RequestUUID string
|
||||
// PubKey is the public key of the client
|
||||
PubKey string
|
||||
// UID is the IdP user ID of the user
|
||||
UID *string
|
||||
// Email is the email of the user
|
||||
Email *string
|
||||
// Domain is the domain of the request (for sign in complete events)
|
||||
Domain *string
|
||||
}
|
||||
|
||||
// AuthEventFn is a function that handles an authentication event
|
||||
type AuthEventFn func(context.Context, AuthEvent)
|
||||
|
||||
func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.Profile) {
|
||||
if a.cfg.authEventFn == nil {
|
||||
return
|
||||
}
|
||||
|
||||
state := a.state.Load()
|
||||
ctx := r.Context()
|
||||
pub, params, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params")
|
||||
}
|
||||
|
||||
evt := AuthEvent{
|
||||
IP: httputil.GetClientIP(r),
|
||||
Version: params.Get(urlutil.QueryVersion),
|
||||
RequestUUID: params.Get(urlutil.QueryRequestUUID),
|
||||
PubKey: pub.String(),
|
||||
}
|
||||
|
||||
if uid := getUserClaim(profile, "sub"); uid != nil {
|
||||
evt.UID = uid
|
||||
}
|
||||
if email := getUserClaim(profile, "email"); email != nil {
|
||||
evt.Email = email
|
||||
}
|
||||
|
||||
if evt.UID != nil {
|
||||
evt.Event = AuthEventSignInComplete
|
||||
} else {
|
||||
evt.Event = AuthEventSignInRequest
|
||||
}
|
||||
|
||||
if redirectURL, err := url.Parse(params.Get(urlutil.QueryRedirectURI)); err == nil {
|
||||
domain := redirectURL.Hostname()
|
||||
evt.Domain = &domain
|
||||
}
|
||||
|
||||
a.cfg.authEventFn(ctx, evt)
|
||||
}
|
||||
|
||||
func getUserClaim(profile *identity.Profile, field string) *string {
|
||||
if profile == nil {
|
||||
return nil
|
||||
}
|
||||
if profile.Claims == nil {
|
||||
return nil
|
||||
}
|
||||
val, ok := profile.Claims.Fields[field]
|
||||
if !ok || val == nil {
|
||||
return nil
|
||||
}
|
||||
txt := val.GetStringValue()
|
||||
return &txt
|
||||
}
|
|
@ -3,7 +3,6 @@ package authenticate
|
|||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
@ -14,9 +13,9 @@ import (
|
|||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/rs/cors"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/csrf"
|
||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
|
@ -27,7 +26,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
// Handler returns the authenticate service's handler chain.
|
||||
|
@ -76,7 +74,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
|
|||
c := cors.New(cors.Options{
|
||||
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
||||
state := a.state.Load()
|
||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), state.sharedKey)
|
||||
err := state.flow.VerifyAuthenticateSignature(r)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked")
|
||||
}
|
||||
|
@ -139,21 +137,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
|||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
|
||||
if err != nil {
|
||||
if err := state.flow.VerifySession(ctx, r, sessionState); err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idpID).
|
||||
Msg("authenticate: identity profile load error")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
err = a.validateIdentityProfile(ctx, profile)
|
||||
if err != nil {
|
||||
log.FromRequest(r).Info().
|
||||
Err(err).
|
||||
Str("idp_id", idpID).
|
||||
Msg("authenticate: invalid identity profile")
|
||||
Msg("authenticate: couldn't verify session")
|
||||
return a.reauthenticateOrFail(w, r, err)
|
||||
}
|
||||
|
||||
|
@ -176,62 +164,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
state := a.state.Load()
|
||||
|
||||
if err := r.ParseForm(); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
s, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
state.sessionStore.ClearSession(w, r)
|
||||
return err
|
||||
}
|
||||
|
||||
// start over if this is a different identity provider
|
||||
if s == nil || s.IdentityProviderID != idpID {
|
||||
s = sessions.NewState(idpID)
|
||||
}
|
||||
|
||||
// re-persist the session, useful when session was evicted from session
|
||||
if err := state.sessionStore.SaveSession(w, r, s); err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
if a.cfg.profileTrimFn != nil {
|
||||
a.cfg.profileTrimFn(profile)
|
||||
}
|
||||
|
||||
a.logAuthenticateEvent(r, profile)
|
||||
|
||||
encryptURLValues := hpke.EncryptURLValuesV1
|
||||
if hpke.IsEncryptedURLV2(r.Form) {
|
||||
encryptURLValues = hpke.EncryptURLValuesV2
|
||||
}
|
||||
|
||||
redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
||||
httputil.Redirect(w, r, redirectTo, http.StatusFound)
|
||||
return nil
|
||||
return state.flow.SignIn(w, r, s)
|
||||
}
|
||||
|
||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||
// Handles both GET and POST.
|
||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
||||
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||
if err != nil {
|
||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
||||
if err != nil {
|
||||
|
@ -319,7 +265,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
return err
|
||||
}
|
||||
|
||||
a.logAuthenticateEvent(r, nil)
|
||||
state.flow.LogAuthenticateEvent(r)
|
||||
|
||||
state.sessionStore.ClearSession(w, r)
|
||||
redirectURL := state.redirectURL.ResolveReference(r.URL)
|
||||
|
@ -418,7 +364,7 @@ Or contact your administrator.
|
|||
`, redirectURL.String(), redirectURL.String()))
|
||||
}
|
||||
|
||||
idpID := a.getIdentityProviderIDForURLValues(redirectURL.Query())
|
||||
idpID := state.flow.GetIdentityProviderIDForURLValues(redirectURL.Query())
|
||||
|
||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
if err != nil {
|
||||
|
@ -445,13 +391,9 @@ Or contact your administrator.
|
|||
newState.Audience = append(newState.Audience, nextRedirectURL.Hostname())
|
||||
}
|
||||
|
||||
// save the session and access token to the databroker
|
||||
profile, err := a.buildIdentityProfile(r, claims, accessToken)
|
||||
if err != nil {
|
||||
return nil, httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
if err := a.storeIdentityProfile(w, state.cookieCipher, profile); err != nil {
|
||||
log.Error(r.Context()).Err(err).Msg("failed to store identity profile")
|
||||
// save the session and access token to the databroker/cookie store
|
||||
if err := state.flow.PersistSession(ctx, w, &newState, claims, accessToken); err != nil {
|
||||
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||
}
|
||||
|
||||
// ... and the user state to local storage.
|
||||
|
@ -478,11 +420,12 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State,
|
|||
func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
|
||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo")
|
||||
defer span.End()
|
||||
r = r.WithContext(ctx)
|
||||
r = a.getExternalRequest(r)
|
||||
|
||||
options := a.options.Load()
|
||||
|
||||
r = r.WithContext(ctx)
|
||||
r = authenticateflow.GetExternalAuthenticateRequest(r, options)
|
||||
|
||||
// if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC
|
||||
if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" {
|
||||
u := urlutil.GetAbsoluteURL(r)
|
||||
|
@ -509,14 +452,9 @@ func (a *Authenticate) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
|||
s.ID = uuid.New().String()
|
||||
}
|
||||
|
||||
profile, _ := a.loadIdentityProfile(r, state.cookieCipher)
|
||||
|
||||
data := handlers.UserInfoData{
|
||||
CSRFToken: csrf.Token(r),
|
||||
Profile: profile,
|
||||
|
||||
BrandingOptions: a.options.Load().BrandingOptions,
|
||||
}
|
||||
data := state.flow.GetUserInfoData(r, s)
|
||||
data.CSRFToken = csrf.Token(r)
|
||||
data.BrandingOptions = a.options.Load().BrandingOptions
|
||||
return data
|
||||
}
|
||||
|
||||
|
@ -537,18 +475,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
|
|||
return ""
|
||||
}
|
||||
|
||||
profile, err := a.loadIdentityProfile(r, a.state.Load().cookieCipher)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
oauthToken := new(oauth2.Token)
|
||||
_ = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||
if err := authenticator.Revoke(ctx, oauthToken); err != nil {
|
||||
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
||||
}
|
||||
|
||||
return string(profile.GetIdToken())
|
||||
return state.flow.RevokeSession(ctx, r, authenticator, nil)
|
||||
}
|
||||
|
||||
// Callback handles the result of a successful call to the authenticate service
|
||||
|
@ -603,19 +530,5 @@ func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string {
|
|||
if err := r.ParseForm(); err != nil {
|
||||
return ""
|
||||
}
|
||||
return a.getIdentityProviderIDForURLValues(r.Form)
|
||||
}
|
||||
|
||||
func (a *Authenticate) getIdentityProviderIDForURLValues(vs url.Values) string {
|
||||
state := a.state.Load()
|
||||
idpID := ""
|
||||
if _, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, vs); err == nil {
|
||||
if idpID == "" {
|
||||
idpID = requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||
}
|
||||
}
|
||||
if idpID == "" {
|
||||
idpID = vs.Get(urlutil.QueryIdentityProviderID)
|
||||
}
|
||||
return idpID
|
||||
return a.state.Load().flow.GetIdentityProviderIDForURLValues(r.Form)
|
||||
}
|
||||
|
|
|
@ -8,7 +8,6 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -24,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||
|
@ -40,6 +40,7 @@ func testAuthenticate() *Authenticate {
|
|||
auth.state = atomicutil.NewValue(&authenticateState{
|
||||
redirectURL: redirectURL,
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
flow: new(stubFlow),
|
||||
})
|
||||
auth.options = config.NewAtomicOptions()
|
||||
auth.options.Store(&config.Options{
|
||||
|
@ -205,6 +206,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
state: atomicutil.NewValue(&authenticateState{
|
||||
sessionStore: tt.sessionStore,
|
||||
sharedEncoder: mock.Encoder{},
|
||||
flow: new(stubFlow),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -301,6 +303,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
redirectURL: authURL,
|
||||
sessionStore: tt.session,
|
||||
cookieCipher: aead,
|
||||
flow: new(stubFlow),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -414,6 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
sessionStore: tt.session,
|
||||
cookieCipher: aead,
|
||||
sharedEncoder: signer,
|
||||
flow: new(stubFlow),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
|
@ -452,6 +456,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
var a Authenticate
|
||||
a.state = atomicutil.NewValue(&authenticateState{
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
flow: new(stubFlow),
|
||||
})
|
||||
a.options = config.NewAtomicOptions()
|
||||
a.options.Store(&config.Options{
|
||||
|
@ -467,36 +472,32 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
|
||||
now := time.Now()
|
||||
tests := []struct {
|
||||
name string
|
||||
url *url.URL
|
||||
method string
|
||||
sessionStore sessions.SessionStore
|
||||
wantCode int
|
||||
wantBody string
|
||||
name string
|
||||
url string
|
||||
validSignature bool
|
||||
sessionStore sessions.SessionStore
|
||||
wantCode int
|
||||
}{
|
||||
{
|
||||
"good",
|
||||
mustParseURL("/"),
|
||||
http.MethodGet,
|
||||
"not a redirect",
|
||||
"/",
|
||||
true,
|
||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||
http.StatusOK,
|
||||
"",
|
||||
},
|
||||
{
|
||||
"missing signature",
|
||||
mustParseURL("/?pomerium_redirect_uri=http://example.com"),
|
||||
http.MethodGet,
|
||||
"signed redirect",
|
||||
"/?pomerium_redirect_uri=http://example.com",
|
||||
true,
|
||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
http.StatusFound,
|
||||
},
|
||||
{
|
||||
"bad signature",
|
||||
urlutil.NewSignedURL([]byte("BAD KEY"), mustParseURL("/?pomerium_redirect_uri=http://example.com")).Sign(),
|
||||
http.MethodGet,
|
||||
"invalid redirect",
|
||||
"/?pomerium_redirect_uri=http://example.com",
|
||||
false,
|
||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||
http.StatusBadRequest,
|
||||
"",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
|
@ -513,14 +514,19 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
||||
SharedKey: "SHARED KEY",
|
||||
})
|
||||
f := new(stubFlow)
|
||||
if !tt.validSignature {
|
||||
f.verifySignatureErr = errors.New("bad signature")
|
||||
}
|
||||
a := &Authenticate{
|
||||
options: o,
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
sessionStore: tt.sessionStore,
|
||||
sharedEncoder: signer,
|
||||
flow: f,
|
||||
}),
|
||||
}
|
||||
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
|
||||
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||
state, err := tt.sessionStore.LoadSession(r)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -535,10 +541,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
if status := w.Code; status != tt.wantCode {
|
||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||
}
|
||||
body := w.Body.String()
|
||||
if !strings.Contains(body, tt.wantBody) {
|
||||
t.Errorf("Unexpected body, contains: %s, got: %s", tt.wantBody, body)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -565,3 +567,42 @@ func mustParseURL(rawurl string) *url.URL {
|
|||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// stubFlow is a stub implementation of the flow interface.
|
||||
type stubFlow struct {
|
||||
verifySignatureErr error
|
||||
}
|
||||
|
||||
func (f *stubFlow) VerifyAuthenticateSignature(*http.Request) error {
|
||||
return f.verifySignatureErr
|
||||
}
|
||||
|
||||
func (*stubFlow) SignIn(http.ResponseWriter, *http.Request, *sessions.State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*stubFlow) PersistSession(
|
||||
context.Context, http.ResponseWriter, *sessions.State, identity.SessionClaims, *oauth2.Token,
|
||||
) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*stubFlow) VerifySession(context.Context, *http.Request, *sessions.State) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*stubFlow) RevokeSession(
|
||||
context.Context, *http.Request, identity.Authenticator, *sessions.State,
|
||||
) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (*stubFlow) GetUserInfoData(*http.Request, *sessions.State) handlers.UserInfoData {
|
||||
return handlers.UserInfoData{}
|
||||
}
|
||||
|
||||
func (*stubFlow) LogAuthenticateEvent(*http.Request) {}
|
||||
|
||||
func (*stubFlow) GetIdentityProviderIDForURLValues(url.Values) string {
|
||||
return ""
|
||||
}
|
||||
|
|
|
@ -1,112 +0,0 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/oauth2"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||
)
|
||||
|
||||
var cookieChunker = httputil.NewCookieChunker()
|
||||
|
||||
func (a *Authenticate) buildIdentityProfile(
|
||||
r *http.Request,
|
||||
claims identity.SessionClaims,
|
||||
oauthToken *oauth2.Token,
|
||||
) (*identitypb.Profile, error) {
|
||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
||||
|
||||
rawIDToken := []byte(claims.RawIDToken)
|
||||
rawOAuthToken, err := json.Marshal(oauthToken)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error marshaling oauth token: %w", err)
|
||||
}
|
||||
rawClaims, err := structpb.NewStruct(claims.Claims)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error creating claims struct: %w", err)
|
||||
}
|
||||
|
||||
return &identitypb.Profile{
|
||||
ProviderId: idpID,
|
||||
IdToken: rawIDToken,
|
||||
OauthToken: rawOAuthToken,
|
||||
Claims: rawClaims,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
|
||||
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
encrypted, err := base64.RawURLEncoding.DecodeString(cookie.Value)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error decoding identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
decrypted, err := cryptutil.Decrypt(aead, encrypted, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error decrypting identity profile cookie: %w", err)
|
||||
}
|
||||
|
||||
var profile identitypb.Profile
|
||||
err = protojson.Unmarshal(decrypted, &profile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authenticate: error unmarshaling identity profile cookie: %w", err)
|
||||
}
|
||||
return &profile, nil
|
||||
}
|
||||
|
||||
func (a *Authenticate) storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) error {
|
||||
options := a.options.Load()
|
||||
|
||||
decrypted, err := protojson.Marshal(profile)
|
||||
if err != nil {
|
||||
// this shouldn't happen
|
||||
panic(fmt.Errorf("error marshaling message: %w", err))
|
||||
}
|
||||
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
|
||||
cookie := options.NewCookie()
|
||||
cookie.Name = urlutil.QueryIdentityProfile
|
||||
cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted)
|
||||
cookie.Path = "/"
|
||||
return cookieChunker.SetCookie(w, cookie)
|
||||
}
|
||||
|
||||
func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *identitypb.Profile) error {
|
||||
authenticator, err := a.cfg.getIdentityProvider(a.options.Load(), profile.GetProviderId())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
oauthToken := new(oauth2.Token)
|
||||
err = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid oauth token in profile: %w", err)
|
||||
}
|
||||
|
||||
if !oauthToken.Valid() {
|
||||
return fmt.Errorf("invalid oauth token in profile")
|
||||
}
|
||||
|
||||
var claims identity.SessionClaims
|
||||
err = authenticator.UpdateUserInfo(ctx, oauthToken, &claims)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error updating user info from oauth token: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -4,7 +4,6 @@ import (
|
|||
"net/http"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
|
@ -13,7 +12,7 @@ import (
|
|||
func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
if r.FormValue(urlutil.QueryRedirectURI) != "" || r.FormValue(urlutil.QueryHmacSignature) != "" {
|
||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
||||
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
@ -25,26 +24,10 @@ func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc
|
|||
// requireValidSignature validates the pomerium_signature.
|
||||
func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Handler {
|
||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
||||
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return next(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
|
||||
options := a.options.Load()
|
||||
|
||||
externalURL, err := options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
internalURL, err := options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
return urlutil.GetExternalRequest(internalURL, externalURL, r)
|
||||
}
|
||||
|
|
|
@ -1,23 +1,41 @@
|
|||
package authenticate
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/go-jose/go-jose/v3"
|
||||
"golang.org/x/oauth2"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||
"github.com/pomerium/pomerium/internal/encoding"
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/handlers"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/hpke"
|
||||
)
|
||||
|
||||
type flow interface {
|
||||
VerifyAuthenticateSignature(r *http.Request) error
|
||||
SignIn(w http.ResponseWriter, r *http.Request, sessionState *sessions.State) error
|
||||
PersistSession(ctx context.Context, w http.ResponseWriter, sessionState *sessions.State, claims identity.SessionClaims, accessToken *oauth2.Token) error
|
||||
VerifySession(ctx context.Context, r *http.Request, sessionState *sessions.State) error
|
||||
RevokeSession(ctx context.Context, r *http.Request, authenticator identity.Authenticator, sessionState *sessions.State) string
|
||||
GetUserInfoData(r *http.Request, sessionState *sessions.State) handlers.UserInfoData
|
||||
LogAuthenticateEvent(r *http.Request)
|
||||
GetIdentityProviderIDForURLValues(url.Values) string
|
||||
}
|
||||
|
||||
type authenticateState struct {
|
||||
flow flow
|
||||
|
||||
redirectURL *url.URL
|
||||
// sharedEncoder is the encoder to use to serialize data to be consumed
|
||||
// by other services
|
||||
|
@ -34,8 +52,7 @@ type authenticateState struct {
|
|||
sessionStore sessions.SessionStore
|
||||
// sessionLoaders are a collection of session loaders to attempt to pull
|
||||
// a user's session state from
|
||||
sessionLoader sessions.SessionLoader
|
||||
hpkePrivateKey *hpke.PrivateKey
|
||||
sessionLoader sessions.SessionLoader
|
||||
|
||||
jwk *jose.JSONWebKeySet
|
||||
}
|
||||
|
@ -46,7 +63,9 @@ func newAuthenticateState() *authenticateState {
|
|||
}
|
||||
}
|
||||
|
||||
func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) {
|
||||
func newAuthenticateStateFromConfig(
|
||||
cfg *config.Config, authenticateConfig *authenticateConfig,
|
||||
) (*authenticateState, error) {
|
||||
err := ValidateOptions(cfg.Options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -125,12 +144,16 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
|||
}
|
||||
}
|
||||
|
||||
sharedKey, err := cfg.Options.GetSharedKey()
|
||||
state.flow, err = authenticateflow.NewStateless(
|
||||
cfg,
|
||||
cookieStore,
|
||||
authenticateConfig.getIdentityProvider,
|
||||
authenticateConfig.profileTrimFn,
|
||||
authenticateConfig.authEventFn,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
|
||||
|
||||
return state, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue