rework session updates to use new patch method (#4705)

Update the AccessTracker, WebAuthn handlers, and identity manager
refresh loop to perform their session record updates using the
databroker Patch() method.

This should prevent any of these updates from conflicting.
This commit is contained in:
Kenneth Jenkins 2023-11-06 09:43:07 -08:00 committed by GitHub
parent 2771a5ae87
commit ab104a643a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 15 deletions

View file

@ -2,11 +2,13 @@ package authorize
import ( import (
"context" "context"
"fmt"
"sync/atomic" "sync/atomic"
"time" "time"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -158,13 +160,12 @@ func (tracker *AccessTracker) updateSession(
ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout) ctx, clearTimeout := context.WithTimeout(ctx, accessTrackerUpdateTimeout)
defer clearTimeout() defer clearTimeout()
s, err := session.Get(ctx, client, sessionID) s := &session.Session{Id: sessionID, AccessedAt: timestamppb.Now()}
if status.Code(err) == codes.NotFound { m, err := fieldmaskpb.New(s, "accessed_at")
return nil if err != nil {
} else if err != nil { return fmt.Errorf("internal error: %w", err)
return err
} }
s.AccessedAt = timestamppb.Now()
_, err = session.Put(ctx, client, s) _, err = session.Patch(ctx, client, s, m)
return err return err
} }

View file

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
@ -21,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil" "github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
) )
const certPEM = ` const certPEM = `
@ -289,6 +292,37 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put
return m.put(ctx, in, opts...) return m.put(ctx, in, opts...)
} }
// Patch emulates the patch operation using Get and Put. (This is not atomic.)
func (m mockDataBrokerServiceClient) Patch(ctx context.Context, in *databroker.PatchRequest, opts ...grpc.CallOption) (*databroker.PatchResponse, error) {
var records []*databroker.Record
for _, record := range in.GetRecords() {
getResponse, err := m.Get(ctx, &databroker.GetRequest{
Type: record.GetType(),
Id: record.GetId(),
}, opts...)
if storage.IsNotFound(err) {
continue
} else if err != nil {
return nil, err
}
existing := getResponse.GetRecord()
if err := storage.PatchRecord(existing, record, in.GetFieldMask()); err != nil {
return nil, status.Errorf(codes.Unknown, err.Error())
}
records = append(records, record)
}
putResponse, err := m.Put(ctx, &databroker.PutRequest{Records: records}, opts...)
if err != nil {
return nil, err
}
return &databroker.PatchResponse{
ServerVersion: putResponse.GetServerVersion(),
Records: putResponse.GetRecords(),
}, nil
}
func mustParseURL(rawURL string) url.URL { func mustParseURL(rawURL string) url.URL {
u, err := url.Parse(rawURL) u, err := url.Parse(rawURL)
if err != nil { if err != nil {

View file

@ -13,6 +13,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
@ -405,8 +406,13 @@ func (h *Handler) handleView(w http.ResponseWriter, r *http.Request, state *Stat
} }
func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error { func (h *Handler) saveSessionAndRedirect(w http.ResponseWriter, r *http.Request, state *State, rawRedirectURI string) error {
fm, err := fieldmaskpb.New(state.Session, "device_credentials")
if err != nil {
return fmt.Errorf("internal error: %w", err)
}
// save the session to the databroker // save the session to the databroker
res, err := session.Put(r.Context(), state.Client, state.Session) res, err := session.Patch(r.Context(), state.Client, state.Session, fm)
if err != nil { if err != nil {
return err return err
} }

View file

@ -12,6 +12,7 @@ import (
"golang.org/x/sync/errgroup" "golang.org/x/sync/errgroup"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/atomicutil" "github.com/pomerium/pomerium/internal/atomicutil"
@ -288,7 +289,13 @@ func (mgr *Manager) refreshSessionInternal(
return false return false
} }
if _, err := session.Put(ctx, mgr.cfg.Load().dataBrokerClient, s.Session); err != nil { fm, err := fieldmaskpb.New(s.Session, "oauth_token", "id_token", "claims")
if err != nil {
log.Error(ctx).Err(err).Msg("internal error")
return false
}
if _, err := session.Patch(ctx, mgr.cfg.Load().dataBrokerClient, s.Session, fm); err != nil {
log.Error(ctx).Err(err). log.Error(ctx).Err(err).
Str("user_id", s.GetUserId()). Str("user_id", s.GetUserId()).
Str("session_id", s.GetId()). Str("session_id", s.GetId()).

View file

@ -13,6 +13,7 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/events" "github.com/pomerium/pomerium/internal/events"
@ -241,12 +242,17 @@ func TestManager_refreshSession(t *testing.T) {
ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)), ExpiresAt: timestamppb.New(now.Add(5 * time.Minute)),
RefreshToken: "new-refresh-token", RefreshToken: "new-refresh-token",
} }
client.EXPECT().Put(gomock.Any(), client.EXPECT().Patch(gomock.Any(), objectsAreEqualMatcher{
objectsAreEqualMatcher{&databroker.PutRequest{Records: []*databroker.Record{{ &databroker.PatchRequest{
Type: "type.googleapis.com/session.Session", Records: []*databroker.Record{{
Id: "session-id", Type: "type.googleapis.com/session.Session",
Data: protoutil.NewAny(expectedSession), Id: "session-id",
}}}}). Data: protoutil.NewAny(expectedSession),
}},
FieldMask: &fieldmaskpb.FieldMask{
Paths: []string{"oauth_token", "id_token", "claims"},
},
}}).
Return(nil /* this result is currently unused */, nil) Return(nil /* this result is currently unused */, nil)
mgr.refreshSession(context.Background(), "user-id", "session-id") mgr.refreshSession(context.Background(), "user-id", "session-id")

View file

@ -7,6 +7,7 @@ import (
"time" "time"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb" "google.golang.org/protobuf/types/known/timestamppb"
@ -63,6 +64,24 @@ func Put(ctx context.Context, client databroker.DataBrokerServiceClient, s *Sess
return res, err return res, err
} }
// Patch updates specific fields of an existing session in the databroker.
func Patch(
ctx context.Context, client databroker.DataBrokerServiceClient,
s *Session, fields *fieldmaskpb.FieldMask,
) (*databroker.PatchResponse, error) {
s = proto.Clone(s).(*Session)
data := protoutil.NewAny(s)
res, err := client.Patch(ctx, &databroker.PatchRequest{
Records: []*databroker.Record{{
Type: data.GetTypeUrl(),
Id: s.Id,
Data: data,
}},
FieldMask: fields,
})
return res, err
}
// AddClaims adds the flattened claims to the session. // AddClaims adds the flattened claims to the session.
func (x *Session) AddClaims(claims identity.FlattenedClaims) { func (x *Session) AddClaims(claims identity.FlattenedClaims) {
if x.Claims == nil { if x.Claims == nil {