authenticateflow: move stateless flow logic (#4820)

Consolidate all logic specific to the stateless authenticate flow into a
a new Stateless type in a new package internal/authenticateflow. This is
in preparation for adding a new Stateful type implementing the older
stateful authenticate flow (from Pomerium v0.20 and previous).

This change is intended as a pure refactoring of existing logic, with no
changes in functionality.
This commit is contained in:
Kenneth Jenkins 2023-12-06 16:55:57 -08:00 committed by GitHub
parent 3b2bdd059a
commit b7896b3153
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 823 additions and 461 deletions

View file

@ -45,15 +45,16 @@ type Authenticate struct {
// New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
authenticateConfig := getAuthenticateConfig(options...)
a := &Authenticate{
cfg: getAuthenticateConfig(options...),
cfg: authenticateConfig,
options: config.NewAtomicOptions(),
state: atomicutil.NewValue(newAuthenticateState()),
}
a.options.Store(cfg.Options)
state, err := newAuthenticateStateFromConfig(cfg)
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig)
if err != nil {
return nil, err
}
@ -69,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
}
a.options.Store(cfg.Options)
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil {
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
} else {
a.state.Store(state)

View file

@ -2,6 +2,7 @@ package authenticate
import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/identity"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
)
@ -9,7 +10,7 @@ import (
type authenticateConfig struct {
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
profileTrimFn func(*identitypb.Profile)
authEventFn AuthEventFn
authEventFn authenticateflow.AuthEventFn
}
// An Option customizes the Authenticate config.
@ -39,7 +40,7 @@ func WithProfileTrimFn(profileTrimFn func(*identitypb.Profile)) Option {
}
// WithOnAuthenticationEventHook sets the authEventFn function in the config
func WithOnAuthenticationEventHook(fn AuthEventFn) Option {
func WithOnAuthenticationEventHook(fn authenticateflow.AuthEventFn) Option {
return func(cfg *authenticateConfig) {
cfg.authEventFn = fn
}

View file

@ -1,101 +0,0 @@
package authenticate
import (
"context"
"net/http"
"net/url"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/hpke"
)
// AuthEventKind is the type of an authentication event
type AuthEventKind string
const (
// AuthEventSignInRequest is an authentication event for a sign in request before IdP redirect
AuthEventSignInRequest AuthEventKind = "sign_in_request"
// AuthEventSignInComplete is an authentication event for a sign in request after IdP redirect
AuthEventSignInComplete AuthEventKind = "sign_in_complete"
)
// AuthEvent is a log event for an authentication event
type AuthEvent struct {
// Event is the type of authentication event
Event AuthEventKind
// IP is the IP address of the client
IP string
// Version is the version of the Pomerium client
Version string
// RequestUUID is the UUID of the request
RequestUUID string
// PubKey is the public key of the client
PubKey string
// UID is the IdP user ID of the user
UID *string
// Email is the email of the user
Email *string
// Domain is the domain of the request (for sign in complete events)
Domain *string
}
// AuthEventFn is a function that handles an authentication event
type AuthEventFn func(context.Context, AuthEvent)
func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.Profile) {
if a.cfg.authEventFn == nil {
return
}
state := a.state.Load()
ctx := r.Context()
pub, params, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
if err != nil {
log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params")
}
evt := AuthEvent{
IP: httputil.GetClientIP(r),
Version: params.Get(urlutil.QueryVersion),
RequestUUID: params.Get(urlutil.QueryRequestUUID),
PubKey: pub.String(),
}
if uid := getUserClaim(profile, "sub"); uid != nil {
evt.UID = uid
}
if email := getUserClaim(profile, "email"); email != nil {
evt.Email = email
}
if evt.UID != nil {
evt.Event = AuthEventSignInComplete
} else {
evt.Event = AuthEventSignInRequest
}
if redirectURL, err := url.Parse(params.Get(urlutil.QueryRedirectURI)); err == nil {
domain := redirectURL.Hostname()
evt.Domain = &domain
}
a.cfg.authEventFn(ctx, evt)
}
func getUserClaim(profile *identity.Profile, field string) *string {
if profile == nil {
return nil
}
if profile.Claims == nil {
return nil
}
val, ok := profile.Claims.Fields[field]
if !ok || val == nil {
return nil
}
txt := val.GetStringValue()
return &txt
}

View file

@ -3,7 +3,6 @@ package authenticate
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net/http"
@ -14,9 +13,9 @@ import (
"github.com/google/uuid"
"github.com/gorilla/mux"
"github.com/rs/cors"
"golang.org/x/oauth2"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
@ -27,7 +26,6 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
)
// Handler returns the authenticate service's handler chain.
@ -76,7 +74,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
c := cors.New(cors.Options{
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
state := a.state.Load()
err := middleware.ValidateRequestURL(a.getExternalRequest(r), state.sharedKey)
err := state.flow.VerifyAuthenticateSignature(r)
if err != nil {
log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked")
}
@ -139,21 +137,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
return a.reauthenticateOrFail(w, r, err)
}
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
if err != nil {
if err := state.flow.VerifySession(ctx, r, sessionState); err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Msg("authenticate: identity profile load error")
return a.reauthenticateOrFail(w, r, err)
}
err = a.validateIdentityProfile(ctx, profile)
if err != nil {
log.FromRequest(r).Info().
Err(err).
Str("idp_id", idpID).
Msg("authenticate: invalid identity profile")
Msg("authenticate: couldn't verify session")
return a.reauthenticateOrFail(w, r, err)
}
@ -176,62 +164,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
state := a.state.Load()
if err := r.ParseForm(); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
if err != nil {
return err
}
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
s, err := a.getSessionFromCtx(ctx)
if err != nil {
state.sessionStore.ClearSession(w, r)
return err
}
// start over if this is a different identity provider
if s == nil || s.IdentityProviderID != idpID {
s = sessions.NewState(idpID)
}
// re-persist the session, useful when session was evicted from session
if err := state.sessionStore.SaveSession(w, r, s); err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
if a.cfg.profileTrimFn != nil {
a.cfg.profileTrimFn(profile)
}
a.logAuthenticateEvent(r, profile)
encryptURLValues := hpke.EncryptURLValuesV1
if hpke.IsEncryptedURLV2(r.Form) {
encryptURLValues = hpke.EncryptURLValuesV2
}
redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues)
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}
httputil.Redirect(w, r, redirectTo, http.StatusFound)
return nil
return state.flow.SignIn(w, r, s)
}
// SignOut signs the user out and attempts to revoke the user's identity session
// Handles both GET and POST.
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// check for an HMAC'd URL. If none is found, show a confirmation page.
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
if err != nil {
authenticateURL, err := a.options.Load().GetAuthenticateURL()
if err != nil {
@ -319,7 +265,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return err
}
a.logAuthenticateEvent(r, nil)
state.flow.LogAuthenticateEvent(r)
state.sessionStore.ClearSession(w, r)
redirectURL := state.redirectURL.ResolveReference(r.URL)
@ -418,7 +364,7 @@ Or contact your administrator.
`, redirectURL.String(), redirectURL.String()))
}
idpID := a.getIdentityProviderIDForURLValues(redirectURL.Query())
idpID := state.flow.GetIdentityProviderIDForURLValues(redirectURL.Query())
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
if err != nil {
@ -445,13 +391,9 @@ Or contact your administrator.
newState.Audience = append(newState.Audience, nextRedirectURL.Hostname())
}
// save the session and access token to the databroker
profile, err := a.buildIdentityProfile(r, claims, accessToken)
if err != nil {
return nil, httputil.NewError(http.StatusInternalServerError, err)
}
if err := a.storeIdentityProfile(w, state.cookieCipher, profile); err != nil {
log.Error(r.Context()).Err(err).Msg("failed to store identity profile")
// save the session and access token to the databroker/cookie store
if err := state.flow.PersistSession(ctx, w, &newState, claims, accessToken); err != nil {
return nil, fmt.Errorf("failed saving new session: %w", err)
}
// ... and the user state to local storage.
@ -478,11 +420,12 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State,
func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo")
defer span.End()
r = r.WithContext(ctx)
r = a.getExternalRequest(r)
options := a.options.Load()
r = r.WithContext(ctx)
r = authenticateflow.GetExternalAuthenticateRequest(r, options)
// if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC
if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" {
u := urlutil.GetAbsoluteURL(r)
@ -509,14 +452,9 @@ func (a *Authenticate) getUserInfoData(r *http.Request) handlers.UserInfoData {
s.ID = uuid.New().String()
}
profile, _ := a.loadIdentityProfile(r, state.cookieCipher)
data := handlers.UserInfoData{
CSRFToken: csrf.Token(r),
Profile: profile,
BrandingOptions: a.options.Load().BrandingOptions,
}
data := state.flow.GetUserInfoData(r, s)
data.CSRFToken = csrf.Token(r)
data.BrandingOptions = a.options.Load().BrandingOptions
return data
}
@ -537,18 +475,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
return ""
}
profile, err := a.loadIdentityProfile(r, a.state.Load().cookieCipher)
if err != nil {
return ""
}
oauthToken := new(oauth2.Token)
_ = json.Unmarshal(profile.GetOauthToken(), oauthToken)
if err := authenticator.Revoke(ctx, oauthToken); err != nil {
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
}
return string(profile.GetIdToken())
return state.flow.RevokeSession(ctx, r, authenticator, nil)
}
// Callback handles the result of a successful call to the authenticate service
@ -603,19 +530,5 @@ func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string {
if err := r.ParseForm(); err != nil {
return ""
}
return a.getIdentityProviderIDForURLValues(r.Form)
}
func (a *Authenticate) getIdentityProviderIDForURLValues(vs url.Values) string {
state := a.state.Load()
idpID := ""
if _, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, vs); err == nil {
if idpID == "" {
idpID = requestParams.Get(urlutil.QueryIdentityProviderID)
}
}
if idpID == "" {
idpID = vs.Get(urlutil.QueryIdentityProviderID)
}
return idpID
return a.state.Load().flow.GetIdentityProviderIDForURLValues(r.Form)
}

View file

@ -8,7 +8,6 @@ import (
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
@ -24,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/atomicutil"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/oidc"
@ -40,6 +40,7 @@ func testAuthenticate() *Authenticate {
auth.state = atomicutil.NewValue(&authenticateState{
redirectURL: redirectURL,
cookieSecret: cryptutil.NewKey(),
flow: new(stubFlow),
})
auth.options = config.NewAtomicOptions()
auth.options.Store(&config.Options{
@ -205,6 +206,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
sharedEncoder: mock.Encoder{},
flow: new(stubFlow),
}),
options: config.NewAtomicOptions(),
}
@ -301,6 +303,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
redirectURL: authURL,
sessionStore: tt.session,
cookieCipher: aead,
flow: new(stubFlow),
}),
options: config.NewAtomicOptions(),
}
@ -414,6 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
sessionStore: tt.session,
cookieCipher: aead,
sharedEncoder: signer,
flow: new(stubFlow),
}),
options: config.NewAtomicOptions(),
}
@ -452,6 +456,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
var a Authenticate
a.state = atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(),
flow: new(stubFlow),
})
a.options = config.NewAtomicOptions()
a.options.Store(&config.Options{
@ -467,36 +472,32 @@ func TestAuthenticate_userInfo(t *testing.T) {
now := time.Now()
tests := []struct {
name string
url *url.URL
method string
sessionStore sessions.SessionStore
wantCode int
wantBody string
name string
url string
validSignature bool
sessionStore sessions.SessionStore
wantCode int
}{
{
"good",
mustParseURL("/"),
http.MethodGet,
"not a redirect",
"/",
true,
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
http.StatusOK,
"",
},
{
"missing signature",
mustParseURL("/?pomerium_redirect_uri=http://example.com"),
http.MethodGet,
"signed redirect",
"/?pomerium_redirect_uri=http://example.com",
true,
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
http.StatusBadRequest,
"",
http.StatusFound,
},
{
"bad signature",
urlutil.NewSignedURL([]byte("BAD KEY"), mustParseURL("/?pomerium_redirect_uri=http://example.com")).Sign(),
http.MethodGet,
"invalid redirect",
"/?pomerium_redirect_uri=http://example.com",
false,
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
http.StatusBadRequest,
"",
},
}
for _, tt := range tests {
@ -513,14 +514,19 @@ func TestAuthenticate_userInfo(t *testing.T) {
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
SharedKey: "SHARED KEY",
})
f := new(stubFlow)
if !tt.validSignature {
f.verifySignatureErr = errors.New("bad signature")
}
a := &Authenticate{
options: o,
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
sharedEncoder: signer,
flow: f,
}),
}
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
state, err := tt.sessionStore.LoadSession(r)
if err != nil {
t.Fatal(err)
@ -535,10 +541,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
if status := w.Code; status != tt.wantCode {
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
}
body := w.Body.String()
if !strings.Contains(body, tt.wantBody) {
t.Errorf("Unexpected body, contains: %s, got: %s", tt.wantBody, body)
}
})
}
}
@ -565,3 +567,42 @@ func mustParseURL(rawurl string) *url.URL {
}
return u
}
// stubFlow is a stub implementation of the flow interface.
type stubFlow struct {
verifySignatureErr error
}
func (f *stubFlow) VerifyAuthenticateSignature(*http.Request) error {
return f.verifySignatureErr
}
func (*stubFlow) SignIn(http.ResponseWriter, *http.Request, *sessions.State) error {
return nil
}
func (*stubFlow) PersistSession(
context.Context, http.ResponseWriter, *sessions.State, identity.SessionClaims, *oauth2.Token,
) error {
return nil
}
func (*stubFlow) VerifySession(context.Context, *http.Request, *sessions.State) error {
return nil
}
func (*stubFlow) RevokeSession(
context.Context, *http.Request, identity.Authenticator, *sessions.State,
) string {
return ""
}
func (*stubFlow) GetUserInfoData(*http.Request, *sessions.State) handlers.UserInfoData {
return handlers.UserInfoData{}
}
func (*stubFlow) LogAuthenticateEvent(*http.Request) {}
func (*stubFlow) GetIdentityProviderIDForURLValues(url.Values) string {
return ""
}

View file

@ -1,112 +0,0 @@
package authenticate
import (
"context"
"crypto/cipher"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"golang.org/x/oauth2"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
)
var cookieChunker = httputil.NewCookieChunker()
func (a *Authenticate) buildIdentityProfile(
r *http.Request,
claims identity.SessionClaims,
oauthToken *oauth2.Token,
) (*identitypb.Profile, error) {
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
rawIDToken := []byte(claims.RawIDToken)
rawOAuthToken, err := json.Marshal(oauthToken)
if err != nil {
return nil, fmt.Errorf("authenticate: error marshaling oauth token: %w", err)
}
rawClaims, err := structpb.NewStruct(claims.Claims)
if err != nil {
return nil, fmt.Errorf("authenticate: error creating claims struct: %w", err)
}
return &identitypb.Profile{
ProviderId: idpID,
IdToken: rawIDToken,
OauthToken: rawOAuthToken,
Claims: rawClaims,
}, nil
}
func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
if err != nil {
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
}
encrypted, err := base64.RawURLEncoding.DecodeString(cookie.Value)
if err != nil {
return nil, fmt.Errorf("authenticate: error decoding identity profile cookie: %w", err)
}
decrypted, err := cryptutil.Decrypt(aead, encrypted, nil)
if err != nil {
return nil, fmt.Errorf("authenticate: error decrypting identity profile cookie: %w", err)
}
var profile identitypb.Profile
err = protojson.Unmarshal(decrypted, &profile)
if err != nil {
return nil, fmt.Errorf("authenticate: error unmarshaling identity profile cookie: %w", err)
}
return &profile, nil
}
func (a *Authenticate) storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) error {
options := a.options.Load()
decrypted, err := protojson.Marshal(profile)
if err != nil {
// this shouldn't happen
panic(fmt.Errorf("error marshaling message: %w", err))
}
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
cookie := options.NewCookie()
cookie.Name = urlutil.QueryIdentityProfile
cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted)
cookie.Path = "/"
return cookieChunker.SetCookie(w, cookie)
}
func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *identitypb.Profile) error {
authenticator, err := a.cfg.getIdentityProvider(a.options.Load(), profile.GetProviderId())
if err != nil {
return err
}
oauthToken := new(oauth2.Token)
err = json.Unmarshal(profile.GetOauthToken(), oauthToken)
if err != nil {
return fmt.Errorf("invalid oauth token in profile: %w", err)
}
if !oauthToken.Valid() {
return fmt.Errorf("invalid oauth token in profile")
}
var claims identity.SessionClaims
err = authenticator.UpdateUserInfo(ctx, oauthToken, &claims)
if err != nil {
return fmt.Errorf("error updating user info from oauth token: %w", err)
}
return nil
}

View file

@ -4,7 +4,6 @@ import (
"net/http"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/middleware"
"github.com/pomerium/pomerium/internal/urlutil"
)
@ -13,7 +12,7 @@ import (
func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
if r.FormValue(urlutil.QueryRedirectURI) != "" || r.FormValue(urlutil.QueryHmacSignature) != "" {
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
if err != nil {
return httputil.NewError(http.StatusBadRequest, err)
}
@ -25,26 +24,10 @@ func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc
// requireValidSignature validates the pomerium_signature.
func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Handler {
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
if err != nil {
return err
}
return next(w, r)
})
}
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
options := a.options.Load()
externalURL, err := options.GetAuthenticateURL()
if err != nil {
return r
}
internalURL, err := options.GetInternalAuthenticateURL()
if err != nil {
return r
}
return urlutil.GetExternalRequest(internalURL, externalURL, r)
}

View file

@ -1,23 +1,41 @@
package authenticate
import (
"context"
"crypto/cipher"
"fmt"
"net/http"
"net/url"
"github.com/go-jose/go-jose/v3"
"golang.org/x/oauth2"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/authenticateflow"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/jws"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/sessions/cookie"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/hpke"
)
type flow interface {
VerifyAuthenticateSignature(r *http.Request) error
SignIn(w http.ResponseWriter, r *http.Request, sessionState *sessions.State) error
PersistSession(ctx context.Context, w http.ResponseWriter, sessionState *sessions.State, claims identity.SessionClaims, accessToken *oauth2.Token) error
VerifySession(ctx context.Context, r *http.Request, sessionState *sessions.State) error
RevokeSession(ctx context.Context, r *http.Request, authenticator identity.Authenticator, sessionState *sessions.State) string
GetUserInfoData(r *http.Request, sessionState *sessions.State) handlers.UserInfoData
LogAuthenticateEvent(r *http.Request)
GetIdentityProviderIDForURLValues(url.Values) string
}
type authenticateState struct {
flow flow
redirectURL *url.URL
// sharedEncoder is the encoder to use to serialize data to be consumed
// by other services
@ -34,8 +52,7 @@ type authenticateState struct {
sessionStore sessions.SessionStore
// sessionLoaders are a collection of session loaders to attempt to pull
// a user's session state from
sessionLoader sessions.SessionLoader
hpkePrivateKey *hpke.PrivateKey
sessionLoader sessions.SessionLoader
jwk *jose.JSONWebKeySet
}
@ -46,7 +63,9 @@ func newAuthenticateState() *authenticateState {
}
}
func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) {
func newAuthenticateStateFromConfig(
cfg *config.Config, authenticateConfig *authenticateConfig,
) (*authenticateState, error) {
err := ValidateOptions(cfg.Options)
if err != nil {
return nil, err
@ -125,12 +144,16 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
}
}
sharedKey, err := cfg.Options.GetSharedKey()
state.flow, err = authenticateflow.NewStateless(
cfg,
cookieStore,
authenticateConfig.getIdentityProvider,
authenticateConfig.profileTrimFn,
authenticateConfig.authEventFn,
)
if err != nil {
return nil, err
}
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
return state, nil
}