From d390e80b30e451f9e7966c2ffddba4785c0a2719 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Tue, 26 Oct 2021 14:45:53 -0600 Subject: [PATCH] authenticate: add databroker versions to session cookie (#2709) * authenticate: add databroker versions to session cookie authorize: wait for databroker synchronization on updated sessions * fix test --- authenticate/handlers.go | 26 ++++---- authenticate/handlers/webauthn/webauthn.go | 74 ++++++++++++++++------ authorize/sync.go | 29 +++++++++ authorize/sync_test.go | 37 +++++++++++ internal/sessions/state.go | 7 ++ internal/urlutil/proxy.go | 51 +++++++++++++++ 6 files changed, 192 insertions(+), 32 deletions(-) create mode 100644 internal/urlutil/proxy.go diff --git a/authenticate/handlers.go b/authenticate/handlers.go index e449a0982..5b6c52586 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -186,19 +186,13 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { jwtAudience := []string{state.redirectURL.Host, redirectURL.Host} - var callbackURL *url.URL // if the callback is explicitly set, set it and add an additional audience if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" { - callbackURL, err = urlutil.ParseAndValidateURL(callbackStr) + callbackURL, err := urlutil.ParseAndValidateURL(callbackStr) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } jwtAudience = append(jwtAudience, callbackURL.Host) - } else { - // otherwise, assume callback is the same host as redirect - callbackURL, _ = urlutil.DeepCopy(redirectURL) - callbackURL.Path = "/.pomerium/callback/" - callbackURL.RawQuery = "" } // add an additional claim for the forward-auth host, if set @@ -219,11 +213,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { return httputil.NewError(http.StatusBadRequest, err) } - callbackParams := callbackURL.Query() - if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { newSession.Programmatic = true - callbackParams.Set(urlutil.QueryIsProgrammatic, "true") } // sign the route session, as a JWT @@ -237,10 +228,10 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { // base64 our encrypted payload for URL-friendlyness encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) - // add our encoded and encrypted route-session JWT to a query param - callbackParams.Set(urlutil.QuerySessionEncrypted, encodedJWT) - callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) - callbackURL.RawQuery = callbackParams.Encode() + callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } // build our hmac-d redirect URL with our session, pointing back to the // proxy's callback URL which is responsible for setting our new route-session @@ -665,10 +656,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e return nil, err } + ss, err := a.getSessionFromCtx(ctx) + if err != nil { + return nil, err + } + return &webauthn.State{ SharedKey: state.sharedKey, Client: state.dataBrokerClient, Session: s, + SessionState: ss, + SessionStore: state.sessionStore, RelyingParty: state.webauthnRelyingParty, }, nil } diff --git a/authenticate/handlers/webauthn/webauthn.go b/authenticate/handlers/webauthn/webauthn.go index 75c0d30e2..9d453bb12 100644 --- a/authenticate/handlers/webauthn/webauthn.go +++ b/authenticate/handlers/webauthn/webauthn.go @@ -4,6 +4,7 @@ package webauthn import ( "bytes" "context" + "encoding/base64" "encoding/json" "errors" "fmt" @@ -19,10 +20,13 @@ import ( "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/internal/encoding/jws" "github.com/pomerium/pomerium/internal/frontend" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" + "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/device" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -42,6 +46,8 @@ type State struct { SharedKey []byte Client databroker.DataBrokerServiceClient Session *session.Session + SessionState *sessions.State + SessionStore sessions.SessionStore RelyingParty *webauthn.RelyingParty } @@ -177,21 +183,14 @@ func (h *Handler) handleAuthenticate(w http.ResponseWriter, r *http.Request, sta } } - // save the session + // update the session state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{ TypeId: deviceType.GetId(), Credential: &session.Session_DeviceCredential_Id{ Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID), }, }) - _, err = session.Put(ctx, state.Client, state.Session) - if err != nil { - return err - } - - // redirect - httputil.Redirect(w, r, redirectURIParam, http.StatusFound) - return nil + return h.saveSessionAndRedirect(w, r, state, redirectURIParam) } func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *State) error { @@ -286,21 +285,14 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state * return err } - // save the session + // update the session state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{ TypeId: deviceType.GetId(), Credential: &session.Session_DeviceCredential_Id{ Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID), }, }) - _, err = session.Put(ctx, state.Client, state.Session) - if err != nil { - return err - } - - // redirect - httputil.Redirect(w, r, redirectURIParam, http.StatusFound) - return nil + return h.saveSessionAndRedirect(w, r, state, redirectURIParam) } func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *State) error { @@ -351,6 +343,52 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat return err } +func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error { + // save the session to the databroker + res, err := session.Put(r.Context(), state.Client, state.Session) + if err != nil { + return err + } + + // add databroker versions to the session cookie and save + state.SessionState.Version = sessions.Version(fmt.Sprint(res.GetServerVersion())) + state.SessionState.DatabrokerServerVersion = res.GetServerVersion() + state.SessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion() + err = state.SessionStore.SaveSession(w, r, state.SessionState) + if err != nil { + return err + } + + // sign+encrypt the session JWT + encoder, err := jws.NewHS256Signer(state.SharedKey) + if err != nil { + return err + } + + signedJWT, err := encoder.Marshal(state.SessionState) + if err != nil { + return err + } + + cipher, err := cryptutil.NewAEADCipher(state.SharedKey) + if err != nil { + return err + } + + encryptedJWT := cryptutil.Encrypt(cipher, signedJWT, nil) + encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) + + // redirect to the proxy callback URL with the session + callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT) + if err != nil { + return err + } + + signedCallbackURL := urlutil.NewSignedURL(state.SharedKey, callbackURL) + httputil.Redirect(w, r, signedCallbackURL.String(), http.StatusFound) + return nil +} + func getKnownDeviceCredentials( ctx context.Context, client databroker.DataBrokerServiceClient, diff --git a/authorize/sync.go b/authorize/sync.go index 49edebf8a..61d95f013 100644 --- a/authorize/sync.go +++ b/authorize/sync.go @@ -72,6 +72,12 @@ func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionO if ss == nil { return nil, nil, nil } + + // if the session state has databroker versions, wait for those to finish syncing + if ss.DatabrokerServerVersion != 0 && ss.DatabrokerRecordVersion != 0 { + a.forceSyncToVersion(ctx, ss.DatabrokerServerVersion, ss.DatabrokerRecordVersion) + } + s := a.forceSyncSession(ctx, ss.ID) if s == nil { return nil, nil, errors.New("session not found") @@ -80,6 +86,29 @@ func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionO return s, u, nil } +func (a *Authorize) forceSyncToVersion(ctx context.Context, serverVersion, recordVersion uint64) (ready bool) { + ctx, span := trace.StartSpan(ctx, "authorize.forceSyncToVersion") + defer span.End() + + ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait) + defer clearTimeout() + + ticker := time.NewTicker(time.Millisecond * 50) + for { + currentServerVersion, currentRecordVersion := a.store.GetDataBrokerVersions() + // check if the local record version is up to date with the expected record version + if currentServerVersion == serverVersion && currentRecordVersion >= recordVersion { + return true + } + + select { + case <-ctx.Done(): + return false + case <-ticker.C: + } + } +} + func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount { ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession") defer span.End() diff --git a/authorize/sync_test.go b/authorize/sync_test.go index f59481e51..50843e07e 100644 --- a/authorize/sync_test.go +++ b/authorize/sync_test.go @@ -19,6 +19,43 @@ import ( "github.com/pomerium/pomerium/pkg/protoutil" ) +func TestAuthorize_forceSyncToVersion(t *testing.T) { + o := &config.Options{ + AuthenticateURLString: "https://authN.example.com", + DataBrokerURLString: "https://databroker.example.com", + SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=", + Policies: testPolicies(t), + } + a, err := New(&config.Config{Options: o}) + require.NoError(t, err) + + a.store.UpdateRecord(1, &databroker.Record{ + Version: 1, + }) + t.Run("ready", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.True(t, a.forceSyncToVersion(ctx, 1, 1)) + }) + t.Run("not ready", func(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + assert.False(t, a.forceSyncToVersion(ctx, 1, 2)) + }) + t.Run("becomes ready", func(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + go func() { + <-time.After(time.Millisecond * 100) + a.store.UpdateRecord(1, &databroker.Record{ + Version: 2, + }) + }() + assert.True(t, a.forceSyncToVersion(ctx, 1, 2)) + }) +} + func TestAuthorize_waitForRecordSync(t *testing.T) { ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) defer clearTimeout() diff --git a/internal/sessions/state.go b/internal/sessions/state.go index 950761499..abec60faf 100644 --- a/internal/sessions/state.go +++ b/internal/sessions/state.go @@ -63,6 +63,13 @@ type State struct { // Programmatic whether this state is used for machine-to-machine // programmatic access. Programmatic bool `json:"programmatic"` + + // DatabrokerServerVersion tracks the last referenced databroker server version + // for the saved session. + DatabrokerServerVersion uint64 `json:"databroker_server_version,omitempty"` + // DatabrokerRecordVersion tracks the last referenced databroker record version + // for the saved session. + DatabrokerRecordVersion uint64 `json:"databroker_record_version,omitempty"` } // NewSession updates issuer, audience, and issuance timestamps but keeps diff --git a/internal/urlutil/proxy.go b/internal/urlutil/proxy.go new file mode 100644 index 000000000..c4e64f65f --- /dev/null +++ b/internal/urlutil/proxy.go @@ -0,0 +1,51 @@ +package urlutil + +import ( + "errors" + "net/http" + "net/url" +) + +// ErrMissingRedirectURI indicates the pomerium_redirect_uri was missing from the query string. +var ErrMissingRedirectURI = errors.New("missing " + QueryRedirectURI) + +// GetCallbackURL gets the proxy's callback URL from a request and a base64url encoded + encrypted session state JWT. +func GetCallbackURL(r *http.Request, encodedSessionJWT string) (*url.URL, error) { + rawRedirectURI := r.FormValue(QueryRedirectURI) + if rawRedirectURI == "" { + return nil, ErrMissingRedirectURI + } + + redirectURI, err := ParseAndValidateURL(rawRedirectURI) + if err != nil { + return nil, err + } + + var callbackURI *url.URL + if callbackStr := r.FormValue(QueryCallbackURI); callbackStr != "" { + callbackURI, err = ParseAndValidateURL(callbackStr) + if err != nil { + return nil, err + } + } else { + // otherwise, assume callback is the same host as redirect + callbackURI, err = DeepCopy(redirectURI) + if err != nil { + return nil, err + } + callbackURI.Path = "/.pomerium/callback/" + callbackURI.RawQuery = "" + } + + callbackParams := callbackURI.Query() + + if r.FormValue(QueryIsProgrammatic) == "true" { + callbackParams.Set(QueryIsProgrammatic, "true") + } + // add our encoded and encrypted route-session JWT to a query param + callbackParams.Set(QuerySessionEncrypted, encodedSessionJWT) + callbackParams.Set(QueryRedirectURI, redirectURI.String()) + callbackURI.RawQuery = callbackParams.Encode() + + return callbackURI, nil +}