diff --git a/authorize/access_tracker.go b/authorize/access_tracker.go index fb141450d..79e13bfd5 100644 --- a/authorize/access_tracker.go +++ b/authorize/access_tracker.go @@ -2,11 +2,13 @@ package authorize import ( "context" + "fmt" "sync/atomic" "time" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/log" @@ -158,13 +160,12 @@ func (tracker *AccessTracker) updateSession( ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout) defer clearTimeout() - s, err := session.Get(ctx, client, sessionID) - if status.Code(err) == codes.NotFound { - return nil - } else if err != nil { - return err + s := &session.Session{Id: sessionID, AccessedAt: timestamppb.Now()} + m, err := fieldmaskpb.New(s, "accessed_at") + if err != nil { + return fmt.Errorf("internal error: %w", err) } - s.AccessedAt = timestamppb.Now() - _, err = session.Put(ctx, client, s) + + _, err = session.Patch(ctx, client, s, m) return err } diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 925e4c52d..90b177dab 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -13,6 +13,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/structpb" "github.com/pomerium/pomerium/authorize/evaluator" @@ -21,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/storage" ) const certPEM = ` @@ -289,6 +292,37 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put return m.put(ctx, in, opts...) } +// Patch emulates the patch operation using Get and Put. (This is not atomic.) +func (m mockDataBrokerServiceClient) Patch(ctx context.Context, in *databroker.PatchRequest, opts ...grpc.CallOption) (*databroker.PatchResponse, error) { + var records []*databroker.Record + for _, record := range in.GetRecords() { + getResponse, err := m.Get(ctx, &databroker.GetRequest{ + Type: record.GetType(), + Id: record.GetId(), + }, opts...) + if storage.IsNotFound(err) { + continue + } else if err != nil { + return nil, err + } + + existing := getResponse.GetRecord() + if err := storage.PatchRecord(existing, record, in.GetFieldMask()); err != nil { + return nil, status.Errorf(codes.Unknown, err.Error()) + } + + records = append(records, record) + } + putResponse, err := m.Put(ctx, &databroker.PutRequest{Records: records}, opts...) + if err != nil { + return nil, err + } + return &databroker.PatchResponse{ + ServerVersion: putResponse.GetServerVersion(), + Records: putResponse.GetRecords(), + }, nil +} + func mustParseURL(rawURL string) url.URL { u, err := url.Parse(rawURL) if err != nil { diff --git a/internal/handlers/webauthn/webauthn.go b/internal/handlers/webauthn/webauthn.go index ec9d35924..45e0f95a5 100644 --- a/internal/handlers/webauthn/webauthn.go +++ b/internal/handlers/webauthn/webauthn.go @@ -13,6 +13,7 @@ import ( "github.com/google/uuid" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/httputil" @@ -405,8 +406,13 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat } func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error { + fm, err := fieldmaskpb.New(state.Session, "device_credentials") + if err != nil { + return fmt.Errorf("internal error: %w", err) + } + // save the session to the databroker - res, err := session.Put(r.Context(), state.Client, state.Session) + res, err := session.Patch(r.Context(), state.Client, state.Session, fm) if err != nil { return err } diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index 51056b70c..7c5837671 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -12,6 +12,7 @@ import ( "golang.org/x/sync/errgroup" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/atomicutil" @@ -288,7 +289,13 @@ func (mgr *Manager) refreshSessionInternal( return false } - if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { + fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims") + if err != nil { + log.Error(ctx).Err(err).Msg("internal error") + return false + } + + if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil { log.Error(ctx).Err(err). Str("user_id", s.GetUserId()). Str("session_id", s.GetId()). diff --git a/internal/identity/manager/manager_test.go b/internal/identity/manager/manager_test.go index 29e49280b..3fb374e4a 100644 --- a/internal/identity/manager/manager_test.go +++ b/internal/identity/manager/manager_test.go @@ -13,6 +13,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/events" @@ -241,12 +242,17 @@ func TestManager_refreshSession(t *testing.T) { ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), RefreshToken: "new-refresh-token", } - client.EXPECT().Put(gomock.Any(), - objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{ - Type: "type.googleapis.com/session.Session", - Id: "session-id", - Data: protoutil.NewAny(expectedSession), - }}}}). + client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{ + &databroker.PatchRequest{ + Records: []*databroker.Record{{ + Type: "type.googleapis.com/session.Session", + Id: "session-id", + Data: protoutil.NewAny(expectedSession), + }}, + FieldMask: &fieldmaskpb.FieldMask{ + Paths: []string{"oauth_token", "id_token", "claims"}, + }, + }}). Return(nil /* this result is currently unused */, nil) mgr.refreshSession(context.Background(), "user-id", "session-id") diff --git a/pkg/grpc/session/session.go b/pkg/grpc/session/session.go index edea20f14..3f839b121 100644 --- a/pkg/grpc/session/session.go +++ b/pkg/grpc/session/session.go @@ -7,6 +7,7 @@ import ( "time" "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" @@ -63,6 +64,24 @@ func Put(ctx context.Context, client databroker.DataBrokerServiceClient, s *Sess return res, err } +// Patch updates specific fields of an existing session in the databroker. +func Patch( + ctx context.Context, client databroker.DataBrokerServiceClient, + s *Session, fields *fieldmaskpb.FieldMask, +) (*databroker.PatchResponse, error) { + s = proto.Clone(s).(*Session) + data := protoutil.NewAny(s) + res, err := client.Patch(ctx, &databroker.PatchRequest{ + Records: []*databroker.Record{{ + Type: data.GetTypeUrl(), + Id: s.Id, + Data: data, + }}, + FieldMask: fields, + }) + return res, err +} + // AddClaims adds the flattened claims to the session. func (x *Session) AddClaims(claims identity.FlattenedClaims) { if x.Claims == nil {