mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-10 15:47:36 +02:00
authenticate: add databroker versions to session cookie (#2709)
* authenticate: add databroker versions to session cookie authorize: wait for databroker synchronization on updated sessions * fix test
This commit is contained in:
parent
b2c76c3816
commit
d390e80b30
6 changed files with 192 additions and 32 deletions
|
@ -186,19 +186,13 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
|
||||
jwtAudience := []string{state.redirectURL.Host, redirectURL.Host}
|
||||
|
||||
var callbackURL *url.URL
|
||||
// if the callback is explicitly set, set it and add an additional audience
|
||||
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
|
||||
callbackURL, err = urlutil.ParseAndValidateURL(callbackStr)
|
||||
callbackURL, err := urlutil.ParseAndValidateURL(callbackStr)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
jwtAudience = append(jwtAudience, callbackURL.Host)
|
||||
} else {
|
||||
// otherwise, assume callback is the same host as redirect
|
||||
callbackURL, _ = urlutil.DeepCopy(redirectURL)
|
||||
callbackURL.Path = "/.pomerium/callback/"
|
||||
callbackURL.RawQuery = ""
|
||||
}
|
||||
|
||||
// add an additional claim for the forward-auth host, if set
|
||||
|
@ -219,11 +213,8 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
callbackParams := callbackURL.Query()
|
||||
|
||||
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||
newSession.Programmatic = true
|
||||
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
||||
}
|
||||
|
||||
// sign the route session, as a JWT
|
||||
|
@ -237,10 +228,10 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
|
|||
// base64 our encrypted payload for URL-friendlyness
|
||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||
|
||||
// add our encoded and encrypted route-session JWT to a query param
|
||||
callbackParams.Set(urlutil.QuerySessionEncrypted, encodedJWT)
|
||||
callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||
callbackURL.RawQuery = callbackParams.Encode()
|
||||
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT)
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusBadRequest, err)
|
||||
}
|
||||
|
||||
// build our hmac-d redirect URL with our session, pointing back to the
|
||||
// proxy's callback URL which is responsible for setting our new route-session
|
||||
|
@ -665,10 +656,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
|||
return nil, err
|
||||
}
|
||||
|
||||
ss, err := a.getSessionFromCtx(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &webauthn.State{
|
||||
SharedKey: state.sharedKey,
|
||||
Client: state.dataBrokerClient,
|
||||
Session: s,
|
||||
SessionState: ss,
|
||||
SessionStore: state.sessionStore,
|
||||
RelyingParty: state.webauthnRelyingParty,
|
||||
}, nil
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ package webauthn
|
|||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -19,10 +20,13 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/encoding/jws"
|
||||
"github.com/pomerium/pomerium/internal/frontend"
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/middleware"
|
||||
"github.com/pomerium/pomerium/internal/sessions"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/device"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -42,6 +46,8 @@ type State struct {
|
|||
SharedKey []byte
|
||||
Client databroker.DataBrokerServiceClient
|
||||
Session *session.Session
|
||||
SessionState *sessions.State
|
||||
SessionStore sessions.SessionStore
|
||||
RelyingParty *webauthn.RelyingParty
|
||||
}
|
||||
|
||||
|
@ -177,21 +183,14 @@ func (h *Handler) handleAuthenticate(w http.ResponseWriter, r *http.Request, sta
|
|||
}
|
||||
}
|
||||
|
||||
// save the session
|
||||
// update the session
|
||||
state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{
|
||||
TypeId: deviceType.GetId(),
|
||||
Credential: &session.Session_DeviceCredential_Id{
|
||||
Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID),
|
||||
},
|
||||
})
|
||||
_, err = session.Put(ctx, state.Client, state.Session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// redirect
|
||||
httputil.Redirect(w, r, redirectURIParam, http.StatusFound)
|
||||
return nil
|
||||
return h.saveSessionAndRedirect(w, r, state, redirectURIParam)
|
||||
}
|
||||
|
||||
func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
|
@ -286,21 +285,14 @@ func (h *Handler) handleRegister(w http.ResponseWriter, r *http.Request, state *
|
|||
return err
|
||||
}
|
||||
|
||||
// save the session
|
||||
// update the session
|
||||
state.Session.DeviceCredentials = append(state.Session.DeviceCredentials, &session.Session_DeviceCredential{
|
||||
TypeId: deviceType.GetId(),
|
||||
Credential: &session.Session_DeviceCredential_Id{
|
||||
Id: webauthnutil.GetDeviceCredentialID(serverCredential.ID),
|
||||
},
|
||||
})
|
||||
_, err = session.Put(ctx, state.Client, state.Session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// redirect
|
||||
httputil.Redirect(w, r, redirectURIParam, http.StatusFound)
|
||||
return nil
|
||||
return h.saveSessionAndRedirect(w, r, state, redirectURIParam)
|
||||
}
|
||||
|
||||
func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *State) error {
|
||||
|
@ -351,6 +343,52 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
|
|||
return err
|
||||
}
|
||||
|
||||
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
|
||||
// save the session to the databroker
|
||||
res, err := session.Put(r.Context(), state.Client, state.Session)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add databroker versions to the session cookie and save
|
||||
state.SessionState.Version = sessions.Version(fmt.Sprint(res.GetServerVersion()))
|
||||
state.SessionState.DatabrokerServerVersion = res.GetServerVersion()
|
||||
state.SessionState.DatabrokerRecordVersion = res.GetRecord().GetVersion()
|
||||
err = state.SessionStore.SaveSession(w, r, state.SessionState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// sign+encrypt the session JWT
|
||||
encoder, err := jws.NewHS256Signer(state.SharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedJWT, err := encoder.Marshal(state.SessionState)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cipher, err := cryptutil.NewAEADCipher(state.SharedKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
encryptedJWT := cryptutil.Encrypt(cipher, signedJWT, nil)
|
||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||
|
||||
// redirect to the proxy callback URL with the session
|
||||
callbackURL, err := urlutil.GetCallbackURL(r, encodedJWT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
signedCallbackURL := urlutil.NewSignedURL(state.SharedKey, callbackURL)
|
||||
httputil.Redirect(w, r, signedCallbackURL.String(), http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
func getKnownDeviceCredentials(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
|
|
|
@ -72,6 +72,12 @@ func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionO
|
|||
if ss == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
// if the session state has databroker versions, wait for those to finish syncing
|
||||
if ss.DatabrokerServerVersion != 0 && ss.DatabrokerRecordVersion != 0 {
|
||||
a.forceSyncToVersion(ctx, ss.DatabrokerServerVersion, ss.DatabrokerRecordVersion)
|
||||
}
|
||||
|
||||
s := a.forceSyncSession(ctx, ss.ID)
|
||||
if s == nil {
|
||||
return nil, nil, errors.New("session not found")
|
||||
|
@ -80,6 +86,29 @@ func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) (sessionO
|
|||
return s, u, nil
|
||||
}
|
||||
|
||||
func (a *Authorize) forceSyncToVersion(ctx context.Context, serverVersion, recordVersion uint64) (ready bool) {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncToVersion")
|
||||
defer span.End()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, forceSyncRecordMaxWait)
|
||||
defer clearTimeout()
|
||||
|
||||
ticker := time.NewTicker(time.Millisecond * 50)
|
||||
for {
|
||||
currentServerVersion, currentRecordVersion := a.store.GetDataBrokerVersions()
|
||||
// check if the local record version is up to date with the expected record version
|
||||
if currentServerVersion == serverVersion && currentRecordVersion >= recordVersion {
|
||||
return true
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
case <-ticker.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *Authorize) forceSyncSession(ctx context.Context, sessionID string) sessionOrServiceAccount {
|
||||
ctx, span := trace.StartSpan(ctx, "authorize.forceSyncSession")
|
||||
defer span.End()
|
||||
|
|
|
@ -19,6 +19,43 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestAuthorize_forceSyncToVersion(t *testing.T) {
|
||||
o := &config.Options{
|
||||
AuthenticateURLString: "https://authN.example.com",
|
||||
DataBrokerURLString: "https://databroker.example.com",
|
||||
SharedKey: "gXK6ggrlIW2HyKyUF9rUO4azrDgxhDPWqw9y+lJU7B8=",
|
||||
Policies: testPolicies(t),
|
||||
}
|
||||
a, err := New(&config.Config{Options: o})
|
||||
require.NoError(t, err)
|
||||
|
||||
a.store.UpdateRecord(1, &databroker.Record{
|
||||
Version: 1,
|
||||
})
|
||||
t.Run("ready", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
assert.True(t, a.forceSyncToVersion(ctx, 1, 1))
|
||||
})
|
||||
t.Run("not ready", func(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
assert.False(t, a.forceSyncToVersion(ctx, 1, 2))
|
||||
})
|
||||
t.Run("becomes ready", func(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
go func() {
|
||||
<-time.After(time.Millisecond * 100)
|
||||
a.store.UpdateRecord(1, &databroker.Record{
|
||||
Version: 2,
|
||||
})
|
||||
}()
|
||||
assert.True(t, a.forceSyncToVersion(ctx, 1, 2))
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthorize_waitForRecordSync(t *testing.T) {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
|
|
@ -63,6 +63,13 @@ type State struct {
|
|||
// Programmatic whether this state is used for machine-to-machine
|
||||
// programmatic access.
|
||||
Programmatic bool `json:"programmatic"`
|
||||
|
||||
// DatabrokerServerVersion tracks the last referenced databroker server version
|
||||
// for the saved session.
|
||||
DatabrokerServerVersion uint64 `json:"databroker_server_version,omitempty"`
|
||||
// DatabrokerRecordVersion tracks the last referenced databroker record version
|
||||
// for the saved session.
|
||||
DatabrokerRecordVersion uint64 `json:"databroker_record_version,omitempty"`
|
||||
}
|
||||
|
||||
// NewSession updates issuer, audience, and issuance timestamps but keeps
|
||||
|
|
51
internal/urlutil/proxy.go
Normal file
51
internal/urlutil/proxy.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package urlutil
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
// ErrMissingRedirectURI indicates the pomerium_redirect_uri was missing from the query string.
|
||||
var ErrMissingRedirectURI = errors.New("missing " + QueryRedirectURI)
|
||||
|
||||
// GetCallbackURL gets the proxy's callback URL from a request and a base64url encoded + encrypted session state JWT.
|
||||
func GetCallbackURL(r *http.Request, encodedSessionJWT string) (*url.URL, error) {
|
||||
rawRedirectURI := r.FormValue(QueryRedirectURI)
|
||||
if rawRedirectURI == "" {
|
||||
return nil, ErrMissingRedirectURI
|
||||
}
|
||||
|
||||
redirectURI, err := ParseAndValidateURL(rawRedirectURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var callbackURI *url.URL
|
||||
if callbackStr := r.FormValue(QueryCallbackURI); callbackStr != "" {
|
||||
callbackURI, err = ParseAndValidateURL(callbackStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// otherwise, assume callback is the same host as redirect
|
||||
callbackURI, err = DeepCopy(redirectURI)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
callbackURI.Path = "/.pomerium/callback/"
|
||||
callbackURI.RawQuery = ""
|
||||
}
|
||||
|
||||
callbackParams := callbackURI.Query()
|
||||
|
||||
if r.FormValue(QueryIsProgrammatic) == "true" {
|
||||
callbackParams.Set(QueryIsProgrammatic, "true")
|
||||
}
|
||||
// add our encoded and encrypted route-session JWT to a query param
|
||||
callbackParams.Set(QuerySessionEncrypted, encodedSessionJWT)
|
||||
callbackParams.Set(QueryRedirectURI, redirectURI.String())
|
||||
callbackURI.RawQuery = callbackParams.Encode()
|
||||
|
||||
return callbackURI, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue