rework session updates to use new patch method (#4705)

Update the AccessTracker, WebAuthn handlers, and identity manager
refresh loop to perform their session record updates using the
databroker Patch() method.

This should prevent any of these updates from conflicting.
This commit is contained in:
Kenneth Jenkins 2023-11-06 09:43:07 -08:00 committed by GitHub
parent 2771a5ae87
commit ab104a643a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 88 additions and 15 deletions

View file

@ -13,6 +13,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/authorize/evaluator"
@ -21,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
const certPEM = `
@ -289,6 +292,37 @@ func (m mockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.Put
return m.put(ctx, in, opts...)
}
// Patch emulates the patch operation using Get and Put. (This is not atomic.)
func (m mockDataBrokerServiceClient) Patch(ctx context.Context, in *databroker.PatchRequest, opts ...grpc.CallOption) (*databroker.PatchResponse, error) {
var records []*databroker.Record
for _, record := range in.GetRecords() {
getResponse, err := m.Get(ctx, &databroker.GetRequest{
Type: record.GetType(),
Id: record.GetId(),
}, opts...)
if storage.IsNotFound(err) {
continue
} else if err != nil {
return nil, err
}
existing := getResponse.GetRecord()
if err := storage.PatchRecord(existing, record, in.GetFieldMask()); err != nil {
return nil, status.Errorf(codes.Unknown, err.Error())
}
records = append(records, record)
}
putResponse, err := m.Put(ctx, &databroker.PutRequest{Records: records}, opts...)
if err != nil {
return nil, err
}
return &databroker.PatchResponse{
ServerVersion: putResponse.GetServerVersion(),
Records: putResponse.GetRecords(),
}, nil
}
func mustParseURL(rawURL string) url.URL {
u, err := url.Parse(rawURL)
if err != nil {