mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-25 15:07:33 +02:00
authenticateflow: move stateless flow logic (#4820)
Consolidate all logic specific to the stateless authenticate flow into a a new Stateless type in a new package internal/authenticateflow. This is in preparation for adding a new Stateful type implementing the older stateful authenticate flow (from Pomerium v0.20 and previous). This change is intended as a pure refactoring of existing logic, with no changes in functionality.
This commit is contained in:
parent
3b2bdd059a
commit
b7896b3153
18 changed files with 823 additions and 461 deletions
|
@ -45,15 +45,16 @@ type Authenticate struct {
|
||||||
|
|
||||||
// New validates and creates a new authenticate service from a set of Options.
|
// New validates and creates a new authenticate service from a set of Options.
|
||||||
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||||
|
authenticateConfig := getAuthenticateConfig(options...)
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(options...),
|
cfg: authenticateConfig,
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
state: atomicutil.NewValue(newAuthenticateState()),
|
state: atomicutil.NewValue(newAuthenticateState()),
|
||||||
}
|
}
|
||||||
|
|
||||||
a.options.Store(cfg.Options)
|
a.options.Store(cfg.Options)
|
||||||
|
|
||||||
state, err := newAuthenticateStateFromConfig(cfg)
|
state, err := newAuthenticateStateFromConfig(cfg, authenticateConfig)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -69,7 +70,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
}
|
}
|
||||||
|
|
||||||
a.options.Store(cfg.Options)
|
a.options.Store(cfg.Options)
|
||||||
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
if state, err := newAuthenticateStateFromConfig(cfg, a.cfg); err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
||||||
} else {
|
} else {
|
||||||
a.state.Store(state)
|
a.state.Store(state)
|
||||||
|
|
|
@ -2,6 +2,7 @@ package authenticate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
)
|
)
|
||||||
|
@ -9,7 +10,7 @@ import (
|
||||||
type authenticateConfig struct {
|
type authenticateConfig struct {
|
||||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
||||||
profileTrimFn func(*identitypb.Profile)
|
profileTrimFn func(*identitypb.Profile)
|
||||||
authEventFn AuthEventFn
|
authEventFn authenticateflow.AuthEventFn
|
||||||
}
|
}
|
||||||
|
|
||||||
// An Option customizes the Authenticate config.
|
// An Option customizes the Authenticate config.
|
||||||
|
@ -39,7 +40,7 @@ func WithProfileTrimFn(profileTrimFn func(*identitypb.Profile)) Option {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithOnAuthenticationEventHook sets the authEventFn function in the config
|
// WithOnAuthenticationEventHook sets the authEventFn function in the config
|
||||||
func WithOnAuthenticationEventHook(fn AuthEventFn) Option {
|
func WithOnAuthenticationEventHook(fn authenticateflow.AuthEventFn) Option {
|
||||||
return func(cfg *authenticateConfig) {
|
return func(cfg *authenticateConfig) {
|
||||||
cfg.authEventFn = fn
|
cfg.authEventFn = fn
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,6 @@ package authenticate
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
@ -14,9 +13,9 @@ import (
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/rs/cors"
|
"github.com/rs/cors"
|
||||||
"golang.org/x/oauth2"
|
|
||||||
|
|
||||||
"github.com/pomerium/csrf"
|
"github.com/pomerium/csrf"
|
||||||
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
|
@ -27,7 +26,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler returns the authenticate service's handler chain.
|
// Handler returns the authenticate service's handler chain.
|
||||||
|
@ -76,7 +74,7 @@ func (a *Authenticate) mountDashboard(r *mux.Router) {
|
||||||
c := cors.New(cors.Options{
|
c := cors.New(cors.Options{
|
||||||
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), state.sharedKey)
|
err := state.flow.VerifyAuthenticateSignature(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked")
|
log.FromRequest(r).Info().Err(err).Msg("authenticate: origin blocked")
|
||||||
}
|
}
|
||||||
|
@ -139,21 +137,11 @@ func (a *Authenticate) VerifySession(next http.Handler) http.Handler {
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
|
if err := state.flow.VerifySession(ctx, r, sessionState); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.FromRequest(r).Info().
|
log.FromRequest(r).Info().
|
||||||
Err(err).
|
Err(err).
|
||||||
Str("idp_id", idpID).
|
Str("idp_id", idpID).
|
||||||
Msg("authenticate: identity profile load error")
|
Msg("authenticate: couldn't verify session")
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
err = a.validateIdentityProfile(ctx, profile)
|
|
||||||
if err != nil {
|
|
||||||
log.FromRequest(r).Info().
|
|
||||||
Err(err).
|
|
||||||
Str("idp_id", idpID).
|
|
||||||
Msg("authenticate: invalid identity profile")
|
|
||||||
return a.reauthenticateOrFail(w, r, err)
|
return a.reauthenticateOrFail(w, r, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -176,62 +164,20 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
|
|
||||||
|
|
||||||
s, err := a.getSessionFromCtx(ctx)
|
s, err := a.getSessionFromCtx(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
state.sessionStore.ClearSession(w, r)
|
state.sessionStore.ClearSession(w, r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start over if this is a different identity provider
|
return state.flow.SignIn(w, r, s)
|
||||||
if s == nil || s.IdentityProviderID != idpID {
|
|
||||||
s = sessions.NewState(idpID)
|
|
||||||
}
|
|
||||||
|
|
||||||
// re-persist the session, useful when session was evicted from session
|
|
||||||
if err := state.sessionStore.SaveSession(w, r, s); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
profile, err := a.loadIdentityProfile(r, state.cookieCipher)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if a.cfg.profileTrimFn != nil {
|
|
||||||
a.cfg.profileTrimFn(profile)
|
|
||||||
}
|
|
||||||
|
|
||||||
a.logAuthenticateEvent(r, profile)
|
|
||||||
|
|
||||||
encryptURLValues := hpke.EncryptURLValuesV1
|
|
||||||
if hpke.IsEncryptedURLV2(r.Form) {
|
|
||||||
encryptURLValues = hpke.EncryptURLValuesV2
|
|
||||||
}
|
|
||||||
|
|
||||||
redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
httputil.Redirect(w, r, redirectTo, http.StatusFound)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignOut signs the user out and attempts to revoke the user's identity session
|
// SignOut signs the user out and attempts to revoke the user's identity session
|
||||||
// Handles both GET and POST.
|
// Handles both GET and POST.
|
||||||
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
||||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -319,7 +265,7 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
a.logAuthenticateEvent(r, nil)
|
state.flow.LogAuthenticateEvent(r)
|
||||||
|
|
||||||
state.sessionStore.ClearSession(w, r)
|
state.sessionStore.ClearSession(w, r)
|
||||||
redirectURL := state.redirectURL.ResolveReference(r.URL)
|
redirectURL := state.redirectURL.ResolveReference(r.URL)
|
||||||
|
@ -418,7 +364,7 @@ Or contact your administrator.
|
||||||
`, redirectURL.String(), redirectURL.String()))
|
`, redirectURL.String(), redirectURL.String()))
|
||||||
}
|
}
|
||||||
|
|
||||||
idpID := a.getIdentityProviderIDForURLValues(redirectURL.Query())
|
idpID := state.flow.GetIdentityProviderIDForURLValues(redirectURL.Query())
|
||||||
|
|
||||||
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
authenticator, err := a.cfg.getIdentityProvider(options, idpID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -445,13 +391,9 @@ Or contact your administrator.
|
||||||
newState.Audience = append(newState.Audience, nextRedirectURL.Hostname())
|
newState.Audience = append(newState.Audience, nextRedirectURL.Hostname())
|
||||||
}
|
}
|
||||||
|
|
||||||
// save the session and access token to the databroker
|
// save the session and access token to the databroker/cookie store
|
||||||
profile, err := a.buildIdentityProfile(r, claims, accessToken)
|
if err := state.flow.PersistSession(ctx, w, &newState, claims, accessToken); err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("failed saving new session: %w", err)
|
||||||
return nil, httputil.NewError(http.StatusInternalServerError, err)
|
|
||||||
}
|
|
||||||
if err := a.storeIdentityProfile(w, state.cookieCipher, profile); err != nil {
|
|
||||||
log.Error(r.Context()).Err(err).Msg("failed to store identity profile")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ... and the user state to local storage.
|
// ... and the user state to local storage.
|
||||||
|
@ -478,11 +420,12 @@ func (a *Authenticate) getSessionFromCtx(ctx context.Context) (*sessions.State,
|
||||||
func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
|
func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo")
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
r = r.WithContext(ctx)
|
|
||||||
r = a.getExternalRequest(r)
|
|
||||||
|
|
||||||
options := a.options.Load()
|
options := a.options.Load()
|
||||||
|
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
r = authenticateflow.GetExternalAuthenticateRequest(r, options)
|
||||||
|
|
||||||
// if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC
|
// if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC
|
||||||
if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" {
|
if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" {
|
||||||
u := urlutil.GetAbsoluteURL(r)
|
u := urlutil.GetAbsoluteURL(r)
|
||||||
|
@ -509,14 +452,9 @@ func (a *Authenticate) getUserInfoData(r *http.Request) handlers.UserInfoData {
|
||||||
s.ID = uuid.New().String()
|
s.ID = uuid.New().String()
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, _ := a.loadIdentityProfile(r, state.cookieCipher)
|
data := state.flow.GetUserInfoData(r, s)
|
||||||
|
data.CSRFToken = csrf.Token(r)
|
||||||
data := handlers.UserInfoData{
|
data.BrandingOptions = a.options.Load().BrandingOptions
|
||||||
CSRFToken: csrf.Token(r),
|
|
||||||
Profile: profile,
|
|
||||||
|
|
||||||
BrandingOptions: a.options.Load().BrandingOptions,
|
|
||||||
}
|
|
||||||
return data
|
return data
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -537,18 +475,7 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter,
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
profile, err := a.loadIdentityProfile(r, a.state.Load().cookieCipher)
|
return state.flow.RevokeSession(ctx, r, authenticator, nil)
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
oauthToken := new(oauth2.Token)
|
|
||||||
_ = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
|
||||||
if err := authenticator.Revoke(ctx, oauthToken); err != nil {
|
|
||||||
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
|
||||||
}
|
|
||||||
|
|
||||||
return string(profile.GetIdToken())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Callback handles the result of a successful call to the authenticate service
|
// Callback handles the result of a successful call to the authenticate service
|
||||||
|
@ -603,19 +530,5 @@ func (a *Authenticate) getIdentityProviderIDForRequest(r *http.Request) string {
|
||||||
if err := r.ParseForm(); err != nil {
|
if err := r.ParseForm(); err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
return a.getIdentityProviderIDForURLValues(r.Form)
|
return a.state.Load().flow.GetIdentityProviderIDForURLValues(r.Form)
|
||||||
}
|
|
||||||
|
|
||||||
func (a *Authenticate) getIdentityProviderIDForURLValues(vs url.Values) string {
|
|
||||||
state := a.state.Load()
|
|
||||||
idpID := ""
|
|
||||||
if _, requestParams, err := hpke.DecryptURLValues(state.hpkePrivateKey, vs); err == nil {
|
|
||||||
if idpID == "" {
|
|
||||||
idpID = requestParams.Get(urlutil.QueryIdentityProviderID)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if idpID == "" {
|
|
||||||
idpID = vs.Get(urlutil.QueryIdentityProviderID)
|
|
||||||
}
|
|
||||||
return idpID
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -24,6 +23,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/atomicutil"
|
"github.com/pomerium/pomerium/internal/atomicutil"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/identity/oidc"
|
"github.com/pomerium/pomerium/internal/identity/oidc"
|
||||||
|
@ -40,6 +40,7 @@ func testAuthenticate() *Authenticate {
|
||||||
auth.state = atomicutil.NewValue(&authenticateState{
|
auth.state = atomicutil.NewValue(&authenticateState{
|
||||||
redirectURL: redirectURL,
|
redirectURL: redirectURL,
|
||||||
cookieSecret: cryptutil.NewKey(),
|
cookieSecret: cryptutil.NewKey(),
|
||||||
|
flow: new(stubFlow),
|
||||||
})
|
})
|
||||||
auth.options = config.NewAtomicOptions()
|
auth.options = config.NewAtomicOptions()
|
||||||
auth.options.Store(&config.Options{
|
auth.options.Store(&config.Options{
|
||||||
|
@ -205,6 +206,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
sessionStore: tt.sessionStore,
|
sessionStore: tt.sessionStore,
|
||||||
sharedEncoder: mock.Encoder{},
|
sharedEncoder: mock.Encoder{},
|
||||||
|
flow: new(stubFlow),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
}
|
}
|
||||||
|
@ -301,6 +303,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
redirectURL: authURL,
|
redirectURL: authURL,
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
|
flow: new(stubFlow),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
}
|
}
|
||||||
|
@ -414,6 +417,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
sessionStore: tt.session,
|
sessionStore: tt.session,
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
sharedEncoder: signer,
|
sharedEncoder: signer,
|
||||||
|
flow: new(stubFlow),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
options: config.NewAtomicOptions(),
|
||||||
}
|
}
|
||||||
|
@ -452,6 +456,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
var a Authenticate
|
var a Authenticate
|
||||||
a.state = atomicutil.NewValue(&authenticateState{
|
a.state = atomicutil.NewValue(&authenticateState{
|
||||||
cookieSecret: cryptutil.NewKey(),
|
cookieSecret: cryptutil.NewKey(),
|
||||||
|
flow: new(stubFlow),
|
||||||
})
|
})
|
||||||
a.options = config.NewAtomicOptions()
|
a.options = config.NewAtomicOptions()
|
||||||
a.options.Store(&config.Options{
|
a.options.Store(&config.Options{
|
||||||
|
@ -468,35 +473,31 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
url *url.URL
|
url string
|
||||||
method string
|
validSignature bool
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
wantCode int
|
wantCode int
|
||||||
wantBody string
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"good",
|
"not a redirect",
|
||||||
mustParseURL("/"),
|
"/",
|
||||||
http.MethodGet,
|
true,
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||||
http.StatusOK,
|
http.StatusOK,
|
||||||
"",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"missing signature",
|
"signed redirect",
|
||||||
mustParseURL("/?pomerium_redirect_uri=http://example.com"),
|
"/?pomerium_redirect_uri=http://example.com",
|
||||||
http.MethodGet,
|
true,
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||||
http.StatusBadRequest,
|
http.StatusFound,
|
||||||
"",
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"bad signature",
|
"invalid redirect",
|
||||||
urlutil.NewSignedURL([]byte("BAD KEY"), mustParseURL("/?pomerium_redirect_uri=http://example.com")).Sign(),
|
"/?pomerium_redirect_uri=http://example.com",
|
||||||
http.MethodGet,
|
false,
|
||||||
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
&mstore.Store{Encrypted: true, Session: &sessions.State{ID: "SESSION_ID", IssuedAt: jwt.NewNumericDate(now)}},
|
||||||
http.StatusBadRequest,
|
http.StatusBadRequest,
|
||||||
"",
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
@ -513,14 +514,19 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
||||||
SharedKey: "SHARED KEY",
|
SharedKey: "SHARED KEY",
|
||||||
})
|
})
|
||||||
|
f := new(stubFlow)
|
||||||
|
if !tt.validSignature {
|
||||||
|
f.verifySignatureErr = errors.New("bad signature")
|
||||||
|
}
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
options: o,
|
options: o,
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
sessionStore: tt.sessionStore,
|
sessionStore: tt.sessionStore,
|
||||||
sharedEncoder: signer,
|
sharedEncoder: signer,
|
||||||
|
flow: f,
|
||||||
}),
|
}),
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
|
r := httptest.NewRequest(http.MethodGet, tt.url, nil)
|
||||||
state, err := tt.sessionStore.LoadSession(r)
|
state, err := tt.sessionStore.LoadSession(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
|
@ -535,10 +541,6 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
if status := w.Code; status != tt.wantCode {
|
if status := w.Code; status != tt.wantCode {
|
||||||
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
t.Errorf("handler returned wrong status code: got %v want %v", status, tt.wantCode)
|
||||||
}
|
}
|
||||||
body := w.Body.String()
|
|
||||||
if !strings.Contains(body, tt.wantBody) {
|
|
||||||
t.Errorf("Unexpected body, contains: %s, got: %s", tt.wantBody, body)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -565,3 +567,42 @@ func mustParseURL(rawurl string) *url.URL {
|
||||||
}
|
}
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// stubFlow is a stub implementation of the flow interface.
|
||||||
|
type stubFlow struct {
|
||||||
|
verifySignatureErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *stubFlow) VerifyAuthenticateSignature(*http.Request) error {
|
||||||
|
return f.verifySignatureErr
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) SignIn(http.ResponseWriter, *http.Request, *sessions.State) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) PersistSession(
|
||||||
|
context.Context, http.ResponseWriter, *sessions.State, identity.SessionClaims, *oauth2.Token,
|
||||||
|
) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) VerifySession(context.Context, *http.Request, *sessions.State) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) RevokeSession(
|
||||||
|
context.Context, *http.Request, identity.Authenticator, *sessions.State,
|
||||||
|
) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) GetUserInfoData(*http.Request, *sessions.State) handlers.UserInfoData {
|
||||||
|
return handlers.UserInfoData{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*stubFlow) LogAuthenticateEvent(*http.Request) {}
|
||||||
|
|
||||||
|
func (*stubFlow) GetIdentityProviderIDForURLValues(url.Values) string {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
|
@ -4,7 +4,6 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -13,7 +12,7 @@ import (
|
||||||
func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc) http.Handler {
|
func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
if r.FormValue(urlutil.QueryRedirectURI) != "" || r.FormValue(urlutil.QueryHmacSignature) != "" {
|
if r.FormValue(urlutil.QueryRedirectURI) != "" || r.FormValue(urlutil.QueryHmacSignature) != "" {
|
||||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
}
|
}
|
||||||
|
@ -25,26 +24,10 @@ func (a *Authenticate) requireValidSignatureOnRedirect(next httputil.HandlerFunc
|
||||||
// requireValidSignature validates the pomerium_signature.
|
// requireValidSignature validates the pomerium_signature.
|
||||||
func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Handler {
|
func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Handler {
|
||||||
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error {
|
||||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
err := a.state.Load().flow.VerifyAuthenticateSignature(r)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return next(w, r)
|
return next(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
|
|
||||||
options := a.options.Load()
|
|
||||||
|
|
||||||
externalURL, err := options.GetAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
internalURL, err := options.GetInternalAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
return urlutil.GetExternalRequest(internalURL, externalURL, r)
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,23 +1,41 @@
|
||||||
package authenticate
|
package authenticate
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3"
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
"github.com/pomerium/pomerium/internal/sessions/cookie"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type flow interface {
|
||||||
|
VerifyAuthenticateSignature(r *http.Request) error
|
||||||
|
SignIn(w http.ResponseWriter, r *http.Request, sessionState *sessions.State) error
|
||||||
|
PersistSession(ctx context.Context, w http.ResponseWriter, sessionState *sessions.State, claims identity.SessionClaims, accessToken *oauth2.Token) error
|
||||||
|
VerifySession(ctx context.Context, r *http.Request, sessionState *sessions.State) error
|
||||||
|
RevokeSession(ctx context.Context, r *http.Request, authenticator identity.Authenticator, sessionState *sessions.State) string
|
||||||
|
GetUserInfoData(r *http.Request, sessionState *sessions.State) handlers.UserInfoData
|
||||||
|
LogAuthenticateEvent(r *http.Request)
|
||||||
|
GetIdentityProviderIDForURLValues(url.Values) string
|
||||||
|
}
|
||||||
|
|
||||||
type authenticateState struct {
|
type authenticateState struct {
|
||||||
|
flow flow
|
||||||
|
|
||||||
redirectURL *url.URL
|
redirectURL *url.URL
|
||||||
// sharedEncoder is the encoder to use to serialize data to be consumed
|
// sharedEncoder is the encoder to use to serialize data to be consumed
|
||||||
// by other services
|
// by other services
|
||||||
|
@ -35,7 +53,6 @@ type authenticateState struct {
|
||||||
// sessionLoaders are a collection of session loaders to attempt to pull
|
// sessionLoaders are a collection of session loaders to attempt to pull
|
||||||
// a user's session state from
|
// a user's session state from
|
||||||
sessionLoader sessions.SessionLoader
|
sessionLoader sessions.SessionLoader
|
||||||
hpkePrivateKey *hpke.PrivateKey
|
|
||||||
|
|
||||||
jwk *jose.JSONWebKeySet
|
jwk *jose.JSONWebKeySet
|
||||||
}
|
}
|
||||||
|
@ -46,7 +63,9 @@ func newAuthenticateState() *authenticateState {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, error) {
|
func newAuthenticateStateFromConfig(
|
||||||
|
cfg *config.Config, authenticateConfig *authenticateConfig,
|
||||||
|
) (*authenticateState, error) {
|
||||||
err := ValidateOptions(cfg.Options)
|
err := ValidateOptions(cfg.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -125,12 +144,16 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
sharedKey, err := cfg.Options.GetSharedKey()
|
state.flow, err = authenticateflow.NewStateless(
|
||||||
|
cfg,
|
||||||
|
cookieStore,
|
||||||
|
authenticateConfig.getIdentityProvider,
|
||||||
|
authenticateConfig.profileTrimFn,
|
||||||
|
authenticateConfig.authEventFn,
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
|
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,6 +26,7 @@ func TestNew(t *testing.T) {
|
||||||
config.Options{
|
config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: "2p/Wi2Q6bYDfzmoSEbKqYKtg+DUoLWTEHHs7vOhvL7w=",
|
SharedKey: "2p/Wi2Q6bYDfzmoSEbKqYKtg+DUoLWTEHHs7vOhvL7w=",
|
||||||
Policies: policies,
|
Policies: policies,
|
||||||
},
|
},
|
||||||
|
@ -36,6 +37,7 @@ func TestNew(t *testing.T) {
|
||||||
config.Options{
|
config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
||||||
Policies: policies,
|
Policies: policies,
|
||||||
},
|
},
|
||||||
|
@ -46,6 +48,7 @@ func TestNew(t *testing.T) {
|
||||||
config.Options{
|
config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: "sup",
|
SharedKey: "sup",
|
||||||
Policies: policies,
|
Policies: policies,
|
||||||
},
|
},
|
||||||
|
@ -56,6 +59,7 @@ func TestNew(t *testing.T) {
|
||||||
config.Options{
|
config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
||||||
Policies: policies,
|
Policies: policies,
|
||||||
},
|
},
|
||||||
|
@ -67,6 +71,7 @@ func TestNew(t *testing.T) {
|
||||||
config.Options{
|
config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "BAD",
|
DataBrokerURLString: "BAD",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
SharedKey: "AZA85podM73CjLCjViDNz1EUvvejKpWp7Hysr0knXA==",
|
||||||
Policies: policies,
|
Policies: policies,
|
||||||
},
|
},
|
||||||
|
@ -105,6 +110,7 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
||||||
o := &config.Options{
|
o := &config.Options{
|
||||||
AuthenticateURLString: "https://authN.example.com",
|
AuthenticateURLString: "https://authN.example.com",
|
||||||
DataBrokerURLString: "https://databroker.example.com",
|
DataBrokerURLString: "https://databroker.example.com",
|
||||||
|
CookieSecret: "15WXae6fvK9Hal0RGZ600JlCaflYHtNy9bAyOLTlvmc=",
|
||||||
SharedKey: tc.SharedKey,
|
SharedKey: tc.SharedKey,
|
||||||
Policies: tc.Policies,
|
Policies: tc.Policies,
|
||||||
}
|
}
|
||||||
|
|
|
@ -192,32 +192,17 @@ func (a *Authorize) requireLoginResponse(
|
||||||
return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil)
|
return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticateURL, err := options.GetAuthenticateURL()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
idp, err := options.GetIdentityProviderForPolicy(request.Policy)
|
idp, err := options.GetIdentityProviderForPolicy(request.Policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticateHPKEPublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// always assume https scheme
|
// always assume https scheme
|
||||||
checkRequestURL := getCheckRequestURL(in)
|
checkRequestURL := getCheckRequestURL(in)
|
||||||
checkRequestURL.Scheme = "https"
|
checkRequestURL.Scheme = "https"
|
||||||
|
|
||||||
redirectTo, err := urlutil.SignInURL(
|
redirectTo, err := state.authenticateFlow.AuthenticateSignInURL(
|
||||||
state.hpkePrivateKey,
|
ctx, nil, &checkRequestURL, idp.GetId())
|
||||||
authenticateHPKEPublicKey,
|
|
||||||
authenticateURL,
|
|
||||||
&checkRequestURL,
|
|
||||||
idp.GetId(),
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,20 +3,25 @@ package authorize
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
googlegrpc "google.golang.org/grpc"
|
googlegrpc "google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/authorize/evaluator"
|
"github.com/pomerium/pomerium/authorize/evaluator"
|
||||||
"github.com/pomerium/pomerium/authorize/internal/store"
|
"github.com/pomerium/pomerium/authorize/internal/store"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
|
||||||
|
type authenticateFlow interface {
|
||||||
|
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
type authorizeState struct {
|
type authorizeState struct {
|
||||||
sharedKey []byte
|
sharedKey []byte
|
||||||
evaluator *evaluator.Evaluator
|
evaluator *evaluator.Evaluator
|
||||||
|
@ -24,8 +29,7 @@ type authorizeState struct {
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
auditEncryptor *protoutil.Encryptor
|
auditEncryptor *protoutil.Encryptor
|
||||||
sessionStore *config.SessionStore
|
sessionStore *config.SessionStore
|
||||||
hpkePrivateKey *hpke.PrivateKey
|
authenticateFlow authenticateFlow
|
||||||
authenticateKeyFetcher hpke.KeyFetcher
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAuthorizeStateFromConfig(
|
func newAuthorizeStateFromConfig(
|
||||||
|
@ -79,10 +83,9 @@ func newAuthorizeStateFromConfig(
|
||||||
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
|
return nil, fmt.Errorf("authorize: invalid session store: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
state.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
|
state.authenticateFlow, err = authenticateflow.NewStateless(cfg, nil, nil, nil, nil)
|
||||||
state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
|
|
31
internal/authenticateflow/authenticateflow.go
Normal file
31
internal/authenticateflow/authenticateflow.go
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
// Package authenticateflow implements the core authentication flow. This
|
||||||
|
// includes creating and parsing sign-in redirect URLs, storing and retrieving
|
||||||
|
// session data, and handling authentication callback URLs.
|
||||||
|
package authenticateflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
)
|
||||||
|
|
||||||
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
|
||||||
|
func populateUserFromClaims(u *user.User, claims map[string]interface{}) {
|
||||||
|
if v, ok := claims["name"]; ok {
|
||||||
|
u.Name = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
if v, ok := claims["email"]; ok {
|
||||||
|
u.Email = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
if u.Claims == nil {
|
||||||
|
u.Claims = make(map[string]*structpb.ListValue)
|
||||||
|
}
|
||||||
|
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||||
|
u.Claims[k] = vs
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package authenticate
|
package authenticateflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -8,7 +8,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -45,14 +45,15 @@ type AuthEvent struct {
|
||||||
// AuthEventFn is a function that handles an authentication event
|
// AuthEventFn is a function that handles an authentication event
|
||||||
type AuthEventFn func(context.Context, AuthEvent)
|
type AuthEventFn func(context.Context, AuthEvent)
|
||||||
|
|
||||||
func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.Profile) {
|
// TODO: move into stateless.go; this is here for now just so that Git will
|
||||||
if a.cfg.authEventFn == nil {
|
// track the file history as a rename from authenticate/events.go.
|
||||||
|
func (s *Stateless) logAuthenticateEvent(r *http.Request, profile *identitypb.Profile) {
|
||||||
|
if s.authEventFn == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
state := a.state.Load()
|
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
pub, params, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
pub, params, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params")
|
log.Warn(ctx).Err(err).Msg("log authenticate event: failed to decrypt request params")
|
||||||
}
|
}
|
||||||
|
@ -82,20 +83,5 @@ func (a *Authenticate) logAuthenticateEvent(r *http.Request, profile *identity.P
|
||||||
evt.Domain = &domain
|
evt.Domain = &domain
|
||||||
}
|
}
|
||||||
|
|
||||||
a.cfg.authEventFn(ctx, evt)
|
s.authEventFn(ctx, evt)
|
||||||
}
|
|
||||||
|
|
||||||
func getUserClaim(profile *identity.Profile, field string) *string {
|
|
||||||
if profile == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
if profile.Claims == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
val, ok := profile.Claims.Fields[field]
|
|
||||||
if !ok || val == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
txt := val.GetStringValue()
|
|
||||||
return &txt
|
|
||||||
}
|
}
|
|
@ -1,4 +1,4 @@
|
||||||
package authenticate
|
package authenticateflow
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
@ -7,27 +7,36 @@ import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// An "identity profile" is an alternative to a session, used in the stateless
|
||||||
|
// authenticate flow. An identity profile contains an IdP ID (to distinguish
|
||||||
|
// between different IdP's or between different clients of the same IdP), a
|
||||||
|
// user ID token, and an OAuth2 token.
|
||||||
|
|
||||||
var cookieChunker = httputil.NewCookieChunker()
|
var cookieChunker = httputil.NewCookieChunker()
|
||||||
|
|
||||||
func (a *Authenticate) buildIdentityProfile(
|
// buildIdentityProfile populates an identity profile.
|
||||||
r *http.Request,
|
func buildIdentityProfile(
|
||||||
|
idpID string,
|
||||||
claims identity.SessionClaims,
|
claims identity.SessionClaims,
|
||||||
oauthToken *oauth2.Token,
|
oauthToken *oauth2.Token,
|
||||||
) (*identitypb.Profile, error) {
|
) (*identitypb.Profile, error) {
|
||||||
idpID := r.FormValue(urlutil.QueryIdentityProviderID)
|
|
||||||
|
|
||||||
rawIDToken := []byte(claims.RawIDToken)
|
rawIDToken := []byte(claims.RawIDToken)
|
||||||
rawOAuthToken, err := json.Marshal(oauthToken)
|
rawOAuthToken, err := json.Marshal(oauthToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -46,7 +55,8 @@ func (a *Authenticate) buildIdentityProfile(
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
|
// loadIdentityProfile loads an identity profile from a chunked set of cookies.
|
||||||
|
func loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*identitypb.Profile, error) {
|
||||||
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
|
cookie, err := cookieChunker.LoadCookie(r, urlutil.QueryIdentityProfile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
|
return nil, fmt.Errorf("authenticate: error loading identity profile cookie: %w", err)
|
||||||
|
@ -70,30 +80,35 @@ func (a *Authenticate) loadIdentityProfile(r *http.Request, aead cipher.AEAD) (*
|
||||||
return &profile, nil
|
return &profile, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) storeIdentityProfile(w http.ResponseWriter, aead cipher.AEAD, profile *identitypb.Profile) error {
|
// storeIdentityProfile writes the identity profile to a chunked set of cookies.
|
||||||
options := a.options.Load()
|
func storeIdentityProfile(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
cookie *http.Cookie,
|
||||||
|
aead cipher.AEAD,
|
||||||
|
profile *identitypb.Profile,
|
||||||
|
) error {
|
||||||
decrypted, err := protojson.Marshal(profile)
|
decrypted, err := protojson.Marshal(profile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// this shouldn't happen
|
// this shouldn't happen
|
||||||
panic(fmt.Errorf("error marshaling message: %w", err))
|
panic(fmt.Errorf("error marshaling message: %w", err))
|
||||||
}
|
}
|
||||||
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
|
encrypted := cryptutil.Encrypt(aead, decrypted, nil)
|
||||||
cookie := options.NewCookie()
|
|
||||||
cookie.Name = urlutil.QueryIdentityProfile
|
cookie.Name = urlutil.QueryIdentityProfile
|
||||||
cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted)
|
cookie.Value = base64.RawURLEncoding.EncodeToString(encrypted)
|
||||||
cookie.Path = "/"
|
cookie.Path = "/"
|
||||||
return cookieChunker.SetCookie(w, cookie)
|
return cookieChunker.SetCookie(w, cookie)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *identitypb.Profile) error {
|
// validateIdentityProfile checks expirations timestamps for the ID token and
|
||||||
authenticator, err := a.cfg.getIdentityProvider(a.options.Load(), profile.GetProviderId())
|
// OAuth2 token, and makes a user info request to the IdP in order to determine
|
||||||
if err != nil {
|
// whether the OAuth2 token is still valid.
|
||||||
return err
|
func validateIdentityProfile(
|
||||||
}
|
ctx context.Context,
|
||||||
|
authenticator identity.Authenticator,
|
||||||
|
profile *identitypb.Profile,
|
||||||
|
) error {
|
||||||
oauthToken := new(oauth2.Token)
|
oauthToken := new(oauth2.Token)
|
||||||
err = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
err := json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("invalid oauth token in profile: %w", err)
|
return fmt.Errorf("invalid oauth token in profile: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -110,3 +125,48 @@ func (a *Authenticate) validateIdentityProfile(ctx context.Context, profile *ide
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
|
||||||
|
claims := p.GetClaims().AsMap()
|
||||||
|
|
||||||
|
ss := sessions.NewState(p.GetProviderId())
|
||||||
|
|
||||||
|
// set the subject
|
||||||
|
if v, ok := claims["sub"]; ok {
|
||||||
|
ss.Subject = fmt.Sprint(v)
|
||||||
|
} else if v, ok := claims["user"]; ok {
|
||||||
|
ss.Subject = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// set the oid
|
||||||
|
if v, ok := claims["oid"]; ok {
|
||||||
|
ss.OID = fmt.Sprint(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ss
|
||||||
|
}
|
||||||
|
|
||||||
|
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
|
||||||
|
claims := p.GetClaims().AsMap()
|
||||||
|
oauthToken := new(oauth2.Token)
|
||||||
|
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
|
||||||
|
|
||||||
|
s.UserId = ss.UserID()
|
||||||
|
s.IssuedAt = timestamppb.Now()
|
||||||
|
s.AccessedAt = timestamppb.Now()
|
||||||
|
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
|
||||||
|
s.IdToken = &session.IDToken{
|
||||||
|
Issuer: ss.Issuer,
|
||||||
|
Subject: ss.Subject,
|
||||||
|
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
|
||||||
|
IssuedAt: timestamppb.Now(),
|
||||||
|
Raw: string(p.GetIdToken()),
|
||||||
|
}
|
||||||
|
s.OauthToken = manager.ToOAuthToken(oauthToken)
|
||||||
|
if s.Claims == nil {
|
||||||
|
s.Claims = make(map[string]*structpb.ListValue)
|
||||||
|
}
|
||||||
|
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
||||||
|
s.Claims[k] = vs
|
||||||
|
}
|
||||||
|
}
|
36
internal/authenticateflow/request.go
Normal file
36
internal/authenticateflow/request.go
Normal file
|
@ -0,0 +1,36 @@
|
||||||
|
package authenticateflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
type signatureVerifier struct {
|
||||||
|
options *config.Options
|
||||||
|
sharedKey []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyAuthenticateSignature checks that the provided request has a valid
|
||||||
|
// signature (for the authenticate service).
|
||||||
|
func (v signatureVerifier) VerifyAuthenticateSignature(r *http.Request) error {
|
||||||
|
return middleware.ValidateRequestURL(GetExternalAuthenticateRequest(r, v.options), v.sharedKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetExternalAuthenticateRequest canonicalizes an authenticate request URL
|
||||||
|
// based on the provided configuration options.
|
||||||
|
func GetExternalAuthenticateRequest(r *http.Request, options *config.Options) *http.Request {
|
||||||
|
externalURL, err := options.GetAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
internalURL, err := options.GetInternalAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
return urlutil.GetExternalRequest(internalURL, externalURL, r)
|
||||||
|
}
|
58
internal/authenticateflow/request_test.go
Normal file
58
internal/authenticateflow/request_test.go
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
package authenticateflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestVerifyAuthenticateSignature(t *testing.T) {
|
||||||
|
options := &config.Options{
|
||||||
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
|
AuthenticateInternalURLString: "https://authenticate.internal",
|
||||||
|
}
|
||||||
|
key := []byte("SHARED KEY--(must be 32 bytes)--")
|
||||||
|
v := signatureVerifier{options, key}
|
||||||
|
|
||||||
|
t.Run("Valid", func(t *testing.T) {
|
||||||
|
u := mustParseURL("https://example.com/")
|
||||||
|
r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(key, u).Sign()}
|
||||||
|
err := v.VerifyAuthenticateSignature(r)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
t.Run("NoSignature", func(t *testing.T) {
|
||||||
|
r := &http.Request{Host: "example.com", URL: mustParseURL("https://example.com/")}
|
||||||
|
err := v.VerifyAuthenticateSignature(r)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("DifferentKey", func(t *testing.T) {
|
||||||
|
zeros := make([]byte, 32)
|
||||||
|
u := mustParseURL("https://example.com/")
|
||||||
|
r := &http.Request{Host: "example.com", URL: urlutil.NewSignedURL(zeros, u).Sign()}
|
||||||
|
err := v.VerifyAuthenticateSignature(r)
|
||||||
|
assert.Error(t, err)
|
||||||
|
})
|
||||||
|
t.Run("InternalDomain", func(t *testing.T) {
|
||||||
|
// A request with the internal authenticate service URL should first be
|
||||||
|
// canonicalized to use the external authenticate service URL before
|
||||||
|
// validating the request signature.
|
||||||
|
u := urlutil.NewSignedURL(key, mustParseURL("https://authenticate.example.com/")).Sign()
|
||||||
|
u.Host = "authenticate.internal"
|
||||||
|
r := &http.Request{Host: "authenticate.internal", URL: u}
|
||||||
|
err := v.VerifyAuthenticateSignature(r)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParseURL(rawurl string) *url.URL {
|
||||||
|
u, err := url.Parse(rawurl)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return u
|
||||||
|
}
|
449
internal/authenticateflow/stateless.go
Normal file
449
internal/authenticateflow/stateless.go
Normal file
|
@ -0,0 +1,449 @@
|
||||||
|
package authenticateflow
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/cipher"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/go-jose/go-jose/v3"
|
||||||
|
"golang.org/x/oauth2"
|
||||||
|
"google.golang.org/protobuf/encoding/protojson"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/hpke"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Stateless implements the stateless authentication flow. In this flow, the
|
||||||
|
// authenticate service has no direct access to the databroker and instead
|
||||||
|
// stores profile information in a cookie.
|
||||||
|
type Stateless struct {
|
||||||
|
signatureVerifier
|
||||||
|
|
||||||
|
// sharedEncoder is the encoder to use to serialize data to be consumed
|
||||||
|
// by other services
|
||||||
|
sharedEncoder encoding.MarshalUnmarshaler
|
||||||
|
// cookieCipher is the cipher to use to encrypt/decrypt session data
|
||||||
|
cookieCipher cipher.AEAD
|
||||||
|
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
|
||||||
|
hpkePrivateKey *hpke.PrivateKey
|
||||||
|
authenticateKeyFetcher hpke.KeyFetcher
|
||||||
|
|
||||||
|
jwk *jose.JSONWebKeySet
|
||||||
|
|
||||||
|
authenticateURL *url.URL
|
||||||
|
|
||||||
|
options *config.Options
|
||||||
|
|
||||||
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
|
|
||||||
|
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
||||||
|
profileTrimFn func(*identitypb.Profile)
|
||||||
|
authEventFn AuthEventFn
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewStateless initializes the authentication flow for the given
|
||||||
|
// configuration, session store, and additional options.
|
||||||
|
func NewStateless(
|
||||||
|
cfg *config.Config,
|
||||||
|
sessionStore sessions.SessionStore,
|
||||||
|
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error),
|
||||||
|
profileTrimFn func(*identitypb.Profile),
|
||||||
|
authEventFn AuthEventFn,
|
||||||
|
) (*Stateless, error) {
|
||||||
|
s := &Stateless{
|
||||||
|
options: cfg.Options,
|
||||||
|
sessionStore: sessionStore,
|
||||||
|
getIdentityProvider: getIdentityProvider,
|
||||||
|
profileTrimFn: profileTrimFn,
|
||||||
|
authEventFn: authEventFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
s.authenticateURL, err = cfg.Options.GetAuthenticateURL()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// shared cipher to encrypt data before passing data between services
|
||||||
|
sharedKey, err := cfg.Options.GetSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// shared state encoder setup
|
||||||
|
s.sharedEncoder, err = jws.NewHS256Signer(sharedKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// private state encoder setup, used to encrypt oauth2 tokens
|
||||||
|
cookieSecret, err := cfg.Options.GetCookieSecret()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.cookieCipher, err = cryptutil.NewAEADCipher(cookieSecret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.jwk = new(jose.JSONWebKeySet)
|
||||||
|
signingKey, err := cfg.Options.GetSigningKey()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if len(signingKey) > 0 {
|
||||||
|
ks, err := cryptutil.PublicJWKsFromBytes(signingKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("authenticate: failed to convert jwks: %w", err)
|
||||||
|
}
|
||||||
|
for _, k := range ks {
|
||||||
|
s.jwk.Keys = append(s.jwk.Keys, *k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
s.signatureVerifier = signatureVerifier{cfg.Options, sharedKey}
|
||||||
|
|
||||||
|
s.hpkePrivateKey = hpke.DerivePrivateKey(sharedKey)
|
||||||
|
|
||||||
|
s.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dataBrokerConn, err := outboundGRPCConnection.Get(context.Background(), &grpc.OutboundOptions{
|
||||||
|
OutboundPort: cfg.OutboundPort,
|
||||||
|
InstallationID: cfg.Options.InstallationID,
|
||||||
|
ServiceName: cfg.Options.Services,
|
||||||
|
SignedJWTKey: sharedKey,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
s.dataBrokerClient = databroker.NewDataBrokerServiceClient(dataBrokerConn)
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySession checks that an existing session is still valid.
|
||||||
|
func (s *Stateless) VerifySession(ctx context.Context, r *http.Request, _ *sessions.State) error {
|
||||||
|
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("identity profile load error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := s.getIdentityProvider(s.options, profile.GetProviderId())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("couldn't get identity provider: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := validateIdentityProfile(ctx, authenticator, profile); err != nil {
|
||||||
|
return fmt.Errorf("invalid identity profile: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// SignIn redirects to a route callback URL, if the provided request and
|
||||||
|
// session state are valid.
|
||||||
|
func (s *Stateless) SignIn(
|
||||||
|
w http.ResponseWriter,
|
||||||
|
r *http.Request,
|
||||||
|
sessionState *sessions.State,
|
||||||
|
) error {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
proxyPublicKey, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
idpID := requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||||
|
|
||||||
|
// start over if this is a different identity provider
|
||||||
|
if sessionState == nil || sessionState.IdentityProviderID != idpID {
|
||||||
|
sessionState = sessions.NewState(idpID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// re-persist the session, useful when session was evicted from session store
|
||||||
|
if err := s.sessionStore.SaveSession(w, r, sessionState); err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.profileTrimFn != nil {
|
||||||
|
s.profileTrimFn(profile)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.logAuthenticateEvent(r, profile)
|
||||||
|
|
||||||
|
encryptURLValues := hpke.EncryptURLValuesV1
|
||||||
|
if hpke.IsEncryptedURLV2(r.Form) {
|
||||||
|
encryptURLValues = hpke.EncryptURLValuesV2
|
||||||
|
}
|
||||||
|
|
||||||
|
redirectTo, err := urlutil.CallbackURL(s.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
httputil.Redirect(w, r, redirectTo, http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// PersistSession stores session data in a cookie.
|
||||||
|
func (s *Stateless) PersistSession(
|
||||||
|
ctx context.Context,
|
||||||
|
w http.ResponseWriter,
|
||||||
|
sessionState *sessions.State,
|
||||||
|
claims identity.SessionClaims,
|
||||||
|
accessToken *oauth2.Token,
|
||||||
|
) error {
|
||||||
|
idpID := sessionState.IdentityProviderID
|
||||||
|
profile, err := buildIdentityProfile(idpID, claims, accessToken)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = storeIdentityProfile(w, s.options.NewCookie(), s.cookieCipher, profile)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("failed to store identity profile")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetUserInfoData returns user info data associated with the given request (if
|
||||||
|
// any).
|
||||||
|
func (s *Stateless) GetUserInfoData(r *http.Request, _ *sessions.State) handlers.UserInfoData {
|
||||||
|
profile, _ := loadIdentityProfile(r, s.cookieCipher)
|
||||||
|
return handlers.UserInfoData{
|
||||||
|
Profile: profile,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RevokeSession revokes the session associated with the provided request,
|
||||||
|
// returning the ID token from the revoked session.
|
||||||
|
func (s *Stateless) RevokeSession(
|
||||||
|
ctx context.Context, r *http.Request, authenticator identity.Authenticator, _ *sessions.State,
|
||||||
|
) string {
|
||||||
|
profile, err := loadIdentityProfile(r, s.cookieCipher)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
oauthToken := new(oauth2.Token)
|
||||||
|
_ = json.Unmarshal(profile.GetOauthToken(), oauthToken)
|
||||||
|
if err := authenticator.Revoke(ctx, oauthToken); err != nil {
|
||||||
|
log.Ctx(ctx).Warn().Err(err).Msg("authenticate: failed to revoke access token")
|
||||||
|
}
|
||||||
|
|
||||||
|
return string(profile.GetIdToken())
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIdentityProviderIDForURLValues returns the identity provider ID
|
||||||
|
// associated with the given URL values.
|
||||||
|
func (s *Stateless) GetIdentityProviderIDForURLValues(vs url.Values) string {
|
||||||
|
idpID := ""
|
||||||
|
if _, requestParams, err := hpke.DecryptURLValues(s.hpkePrivateKey, vs); err == nil {
|
||||||
|
if idpID == "" {
|
||||||
|
idpID = requestParams.Get(urlutil.QueryIdentityProviderID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if idpID == "" {
|
||||||
|
idpID = vs.Get(urlutil.QueryIdentityProviderID)
|
||||||
|
}
|
||||||
|
return idpID
|
||||||
|
}
|
||||||
|
|
||||||
|
// LogAuthenticateEvent logs an authenticate service event.
|
||||||
|
func (s *Stateless) LogAuthenticateEvent(r *http.Request) {
|
||||||
|
s.logAuthenticateEvent(r, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserClaim(profile *identitypb.Profile, field string) *string {
|
||||||
|
if profile == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if profile.Claims == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
val, ok := profile.Claims.Fields[field]
|
||||||
|
if !ok || val == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
txt := val.GetStringValue()
|
||||||
|
return &txt
|
||||||
|
}
|
||||||
|
|
||||||
|
// AuthenticateSignInURL returns a URL to redirect the user to the authenticate
|
||||||
|
// domain.
|
||||||
|
func (s *Stateless) AuthenticateSignInURL(
|
||||||
|
ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string,
|
||||||
|
) (string, error) {
|
||||||
|
authenticateHPKEPublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticateURLWithParams := *s.authenticateURL
|
||||||
|
q := authenticateURLWithParams.Query()
|
||||||
|
for k, v := range queryParams {
|
||||||
|
q[k] = v
|
||||||
|
}
|
||||||
|
authenticateURLWithParams.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
return urlutil.SignInURL(
|
||||||
|
s.hpkePrivateKey,
|
||||||
|
authenticateHPKEPublicKey,
|
||||||
|
&authenticateURLWithParams,
|
||||||
|
redirectURL,
|
||||||
|
idpID,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Callback handles a redirect to a route domain once signed in.
|
||||||
|
func (s *Stateless) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
|
if err := r.ParseForm(); err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// decrypt the URL values
|
||||||
|
senderPublicKey, values, err := hpke.DecryptURLValues(s.hpkePrivateKey, r.Form)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// confirm this request came from the authenticate service
|
||||||
|
err = s.validateSenderPublicKey(r.Context(), senderPublicKey)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// validate that the request has not expired
|
||||||
|
err = urlutil.ValidateTimeParameters(values)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
profile, err := getProfileFromValues(values)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
ss := newSessionStateFromProfile(profile)
|
||||||
|
sess, err := session.Get(r.Context(), s.dataBrokerClient, ss.ID)
|
||||||
|
if err != nil {
|
||||||
|
sess = &session.Session{Id: ss.ID}
|
||||||
|
}
|
||||||
|
populateSessionFromProfile(sess, profile, ss, s.options.CookieExpire)
|
||||||
|
u, err := user.Get(r.Context(), s.dataBrokerClient, ss.UserID())
|
||||||
|
if err != nil {
|
||||||
|
u = &user.User{Id: ss.UserID()}
|
||||||
|
}
|
||||||
|
populateUserFromClaims(u, profile.GetClaims().AsMap())
|
||||||
|
|
||||||
|
redirectURI, err := getRedirectURIFromValues(values)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// save the records
|
||||||
|
res, err := s.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
|
||||||
|
Records: []*databroker.Record{
|
||||||
|
databroker.NewRecord(sess),
|
||||||
|
databroker.NewRecord(u),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
|
||||||
|
}
|
||||||
|
ss.DatabrokerServerVersion = res.GetServerVersion()
|
||||||
|
for _, record := range res.GetRecords() {
|
||||||
|
if record.GetVersion() > ss.DatabrokerRecordVersion {
|
||||||
|
ss.DatabrokerRecordVersion = record.GetVersion()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// save the session state
|
||||||
|
rawJWT, err := s.sharedEncoder.Marshal(ss)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
|
||||||
|
}
|
||||||
|
if err = s.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
// if programmatic, encode the session jwt as a query param
|
||||||
|
if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
||||||
|
q := redirectURI.Query()
|
||||||
|
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
||||||
|
redirectURI.RawQuery = q.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
// redirect
|
||||||
|
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Stateless) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
|
||||||
|
authenticatePublicKey, err := s.authenticateKeyFetcher.FetchPublicKey(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if !authenticatePublicKey.Equals(senderPublicKey) {
|
||||||
|
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getProfileFromValues(values url.Values) (*identitypb.Profile, error) {
|
||||||
|
rawProfile := values.Get(urlutil.QueryIdentityProfile)
|
||||||
|
if rawProfile == "" {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
|
||||||
|
}
|
||||||
|
|
||||||
|
var profile identitypb.Profile
|
||||||
|
err := protojson.Unmarshal([]byte(rawProfile), &profile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
|
||||||
|
}
|
||||||
|
return &profile, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
|
||||||
|
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
|
||||||
|
if rawRedirectURI == "" {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
|
||||||
|
}
|
||||||
|
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
|
||||||
|
if err != nil {
|
||||||
|
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
|
||||||
|
}
|
||||||
|
return redirectURI, nil
|
||||||
|
}
|
|
@ -1,7 +1,6 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -10,17 +9,11 @@ import (
|
||||||
|
|
||||||
"github.com/go-jose/go-jose/v3/jwt"
|
"github.com/go-jose/go-jose/v3/jwt"
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"google.golang.org/protobuf/encoding/protojson"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/handlers"
|
"github.com/pomerium/pomerium/internal/handlers"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/middleware"
|
"github.com/pomerium/pomerium/internal/middleware"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/identity"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// registerDashboardHandlers returns the proxy service's ServeMux
|
// registerDashboardHandlers returns the proxy service's ServeMux
|
||||||
|
@ -110,89 +103,7 @@ func (p *Proxy) deviceEnrolled(w http.ResponseWriter, r *http.Request) error {
|
||||||
// Callback handles the result of a successful call to the authenticate service
|
// Callback handles the result of a successful call to the authenticate service
|
||||||
// and is responsible setting per-route sessions.
|
// and is responsible setting per-route sessions.
|
||||||
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error {
|
||||||
state := p.state.Load()
|
return p.state.Load().authenticateFlow.Callback(w, r)
|
||||||
options := p.currentOptions.Load()
|
|
||||||
|
|
||||||
if err := r.ParseForm(); err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// decrypt the URL values
|
|
||||||
senderPublicKey, values, err := hpke.DecryptURLValues(state.hpkePrivateKey, r.Form)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid encrypted query string: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// confirm this request came from the authenticate service
|
|
||||||
err = p.validateSenderPublicKey(r.Context(), senderPublicKey)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// validate that the request has not expired
|
|
||||||
err = urlutil.ValidateTimeParameters(values)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
profile, err := getProfileFromValues(values)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ss := newSessionStateFromProfile(profile)
|
|
||||||
s, err := session.Get(r.Context(), state.dataBrokerClient, ss.ID)
|
|
||||||
if err != nil {
|
|
||||||
s = &session.Session{Id: ss.ID}
|
|
||||||
}
|
|
||||||
populateSessionFromProfile(s, profile, ss, options.CookieExpire)
|
|
||||||
u, err := user.Get(r.Context(), state.dataBrokerClient, ss.UserID())
|
|
||||||
if err != nil {
|
|
||||||
u = &user.User{Id: ss.UserID()}
|
|
||||||
}
|
|
||||||
populateUserFromProfile(u, profile, ss)
|
|
||||||
|
|
||||||
redirectURI, err := getRedirectURIFromValues(values)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// save the records
|
|
||||||
res, err := state.dataBrokerClient.Put(r.Context(), &databroker.PutRequest{
|
|
||||||
Records: []*databroker.Record{
|
|
||||||
databroker.NewRecord(s),
|
|
||||||
databroker.NewRecord(u),
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving databroker records: %w", err))
|
|
||||||
}
|
|
||||||
ss.DatabrokerServerVersion = res.GetServerVersion()
|
|
||||||
for _, record := range res.GetRecords() {
|
|
||||||
if record.GetVersion() > ss.DatabrokerRecordVersion {
|
|
||||||
ss.DatabrokerRecordVersion = record.GetVersion()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// save the session state
|
|
||||||
rawJWT, err := state.encoder.Marshal(ss)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error marshaling session state: %w", err))
|
|
||||||
}
|
|
||||||
if err = state.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("proxy: error saving session state: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
// if programmatic, encode the session jwt as a query param
|
|
||||||
if isProgrammatic := values.Get(urlutil.QueryIsProgrammatic); isProgrammatic == "true" {
|
|
||||||
q := redirectURI.Query()
|
|
||||||
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
|
||||||
redirectURI.RawQuery = q.Encode()
|
|
||||||
}
|
|
||||||
|
|
||||||
// redirect
|
|
||||||
httputil.Redirect(w, r, redirectURI.String(), http.StatusFound)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgrammaticLogin returns a signed url that can be used to login
|
// ProgrammaticLogin returns a signed url that can be used to login
|
||||||
|
@ -215,20 +126,14 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
hpkeAuthenticateKey, err := state.authenticateKeyFetcher.FetchPublicKey(r.Context())
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
signinURL := *state.authenticateSigninURL
|
|
||||||
callbackURI := urlutil.GetAbsoluteURL(r)
|
callbackURI := urlutil.GetAbsoluteURL(r)
|
||||||
callbackURI.Path = dashboardPath + "/callback/"
|
callbackURI.Path = dashboardPath + "/callback/"
|
||||||
q := signinURL.Query()
|
q := url.Values{}
|
||||||
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
|
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
|
||||||
q.Set(urlutil.QueryIsProgrammatic, "true")
|
q.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
signinURL.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
rawURL, err := urlutil.SignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId())
|
rawURL, err := state.authenticateFlow.AuthenticateSignInURL(
|
||||||
|
r.Context(), q, redirectURI, idp.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
@ -263,44 +168,3 @@ func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error {
|
||||||
_, _ = io.WriteString(w, rawAssertionJWT)
|
_, _ = io.WriteString(w, rawAssertionJWT)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) validateSenderPublicKey(ctx context.Context, senderPublicKey *hpke.PublicKey) error {
|
|
||||||
state := p.state.Load()
|
|
||||||
|
|
||||||
authenticatePublicKey, err := state.authenticateKeyFetcher.FetchPublicKey(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return httputil.NewError(http.StatusInternalServerError, fmt.Errorf("hpke: error retrieving authenticate service public key: %w", err))
|
|
||||||
}
|
|
||||||
|
|
||||||
if !authenticatePublicKey.Equals(senderPublicKey) {
|
|
||||||
return httputil.NewError(http.StatusBadRequest, fmt.Errorf("hpke: invalid authenticate service public key"))
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getProfileFromValues(values url.Values) (*identity.Profile, error) {
|
|
||||||
rawProfile := values.Get(urlutil.QueryIdentityProfile)
|
|
||||||
if rawProfile == "" {
|
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryIdentityProfile))
|
|
||||||
}
|
|
||||||
|
|
||||||
var profile identity.Profile
|
|
||||||
err := protojson.Unmarshal([]byte(rawProfile), &profile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryIdentityProfile, err))
|
|
||||||
}
|
|
||||||
return &profile, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRedirectURIFromValues(values url.Values) (*url.URL, error) {
|
|
||||||
rawRedirectURI := values.Get(urlutil.QueryRedirectURI)
|
|
||||||
if rawRedirectURI == "" {
|
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("missing %s", urlutil.QueryRedirectURI))
|
|
||||||
}
|
|
||||||
redirectURI, err := urlutil.ParseAndValidateURL(rawRedirectURI)
|
|
||||||
if err != nil {
|
|
||||||
return nil, httputil.NewError(http.StatusBadRequest, fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err))
|
|
||||||
}
|
|
||||||
return redirectURI, nil
|
|
||||||
}
|
|
||||||
|
|
|
@ -1,79 +0,0 @@
|
||||||
package proxy
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"golang.org/x/oauth2"
|
|
||||||
"google.golang.org/protobuf/types/known/structpb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
|
||||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
|
||||||
identitypb "github.com/pomerium/pomerium/pkg/grpc/identity"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newSessionStateFromProfile(p *identitypb.Profile) *sessions.State {
|
|
||||||
claims := p.GetClaims().AsMap()
|
|
||||||
|
|
||||||
ss := sessions.NewState(p.GetProviderId())
|
|
||||||
|
|
||||||
// set the subject
|
|
||||||
if v, ok := claims["sub"]; ok {
|
|
||||||
ss.Subject = fmt.Sprint(v)
|
|
||||||
} else if v, ok := claims["user"]; ok {
|
|
||||||
ss.Subject = fmt.Sprint(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
// set the oid
|
|
||||||
if v, ok := claims["oid"]; ok {
|
|
||||||
ss.OID = fmt.Sprint(v)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ss
|
|
||||||
}
|
|
||||||
|
|
||||||
func populateSessionFromProfile(s *session.Session, p *identitypb.Profile, ss *sessions.State, cookieExpire time.Duration) {
|
|
||||||
claims := p.GetClaims().AsMap()
|
|
||||||
oauthToken := new(oauth2.Token)
|
|
||||||
_ = json.Unmarshal(p.GetOauthToken(), oauthToken)
|
|
||||||
|
|
||||||
s.UserId = ss.UserID()
|
|
||||||
s.IssuedAt = timestamppb.Now()
|
|
||||||
s.AccessedAt = timestamppb.Now()
|
|
||||||
s.ExpiresAt = timestamppb.New(time.Now().Add(cookieExpire))
|
|
||||||
s.IdToken = &session.IDToken{
|
|
||||||
Issuer: ss.Issuer,
|
|
||||||
Subject: ss.Subject,
|
|
||||||
ExpiresAt: timestamppb.New(time.Now().Add(cookieExpire)),
|
|
||||||
IssuedAt: timestamppb.Now(),
|
|
||||||
Raw: string(p.GetIdToken()),
|
|
||||||
}
|
|
||||||
s.OauthToken = manager.ToOAuthToken(oauthToken)
|
|
||||||
if s.Claims == nil {
|
|
||||||
s.Claims = make(map[string]*structpb.ListValue)
|
|
||||||
}
|
|
||||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
|
||||||
s.Claims[k] = vs
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func populateUserFromProfile(u *user.User, p *identitypb.Profile, _ *sessions.State) {
|
|
||||||
claims := p.GetClaims().AsMap()
|
|
||||||
if v, ok := claims["name"]; ok {
|
|
||||||
u.Name = fmt.Sprint(v)
|
|
||||||
}
|
|
||||||
if v, ok := claims["email"]; ok {
|
|
||||||
u.Email = fmt.Sprint(v)
|
|
||||||
}
|
|
||||||
if u.Claims == nil {
|
|
||||||
u.Claims = make(map[string]*structpb.ListValue)
|
|
||||||
}
|
|
||||||
for k, vs := range identity.Claims(claims).Flatten().ToPB() {
|
|
||||||
u.Claims[k] = vs
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -3,10 +3,11 @@ package proxy
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/cipher"
|
"crypto/cipher"
|
||||||
"fmt"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/authenticateflow"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
@ -14,11 +15,15 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc"
|
"github.com/pomerium/pomerium/pkg/grpc"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/hpke"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
var outboundGRPCConnection = new(grpc.CachedOutboundGRPClientConn)
|
||||||
|
|
||||||
|
type authenticateFlow interface {
|
||||||
|
AuthenticateSignInURL(ctx context.Context, queryParams url.Values, redirectURL *url.URL, idpID string) (string, error)
|
||||||
|
Callback(w http.ResponseWriter, r *http.Request) error
|
||||||
|
}
|
||||||
|
|
||||||
type proxyState struct {
|
type proxyState struct {
|
||||||
sharedKey []byte
|
sharedKey []byte
|
||||||
sharedCipher cipher.AEAD
|
sharedCipher cipher.AEAD
|
||||||
|
@ -32,12 +37,12 @@ type proxyState struct {
|
||||||
cookieSecret []byte
|
cookieSecret []byte
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
jwtClaimHeaders config.JWTClaimHeaders
|
jwtClaimHeaders config.JWTClaimHeaders
|
||||||
hpkePrivateKey *hpke.PrivateKey
|
|
||||||
authenticateKeyFetcher hpke.KeyFetcher
|
|
||||||
|
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
|
|
||||||
programmaticRedirectDomainWhitelist []string
|
programmaticRedirectDomainWhitelist []string
|
||||||
|
|
||||||
|
authenticateFlow authenticateFlow
|
||||||
}
|
}
|
||||||
|
|
||||||
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
|
@ -53,16 +58,6 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
state.hpkePrivateKey, err = cfg.Options.GetHPKEPrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
state.authenticateKeyFetcher, err = cfg.GetAuthenticateKeyFetcher()
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("authorize: get authenticate JWKS key fetcher: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
|
state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -119,5 +114,11 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
|
||||||
|
|
||||||
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
state.programmaticRedirectDomainWhitelist = cfg.Options.ProgrammaticRedirectDomainWhitelist
|
||||||
|
|
||||||
|
state.authenticateFlow, err = authenticateflow.NewStateless(
|
||||||
|
cfg, state.sessionStore, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
return state, nil
|
return state, nil
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue