mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-01 18:33:19 +02:00
databroker: add support for field masks on Put (#3210)
* databroker: add support for field masks on Put * return errors * clean up go.mod
This commit is contained in:
parent
8fc5dbf4c5
commit
2dc778035d
15 changed files with 381 additions and 134 deletions
|
@ -10,7 +10,8 @@ import (
|
|||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -20,6 +21,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
|
@ -117,23 +119,7 @@ func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *data
|
|||
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
||||
|
||||
key, field := getHashKey(recordType, id)
|
||||
cmd := backend.client.HGet(ctx, key, field)
|
||||
raw, err := cmd.Result()
|
||||
if err == redis.Nil {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(raw), &record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
return backend.get(ctx, backend.client, recordType, id)
|
||||
}
|
||||
|
||||
// GetAll gets all the records from redis.
|
||||
|
@ -241,7 +227,11 @@ func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, tt
|
|||
}
|
||||
|
||||
// Put puts a record into redis.
|
||||
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (serverVersion uint64, err error) {
|
||||
func (backend *Backend) Put(
|
||||
ctx context.Context,
|
||||
record *databroker.Record,
|
||||
mask *fieldmaskpb.FieldMask,
|
||||
) (serverVersion uint64, err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
||||
|
@ -251,7 +241,7 @@ func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (ser
|
|||
return serverVersion, err
|
||||
}
|
||||
|
||||
err = backend.put(ctx, record)
|
||||
err = backend.put(ctx, record, mask)
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
|
@ -294,19 +284,68 @@ func (backend *Backend) Sync(ctx context.Context, serverVersion, recordVersion u
|
|||
return newRecordStream(ctx, backend, serverVersion, recordVersion), nil
|
||||
}
|
||||
|
||||
func (backend *Backend) put(ctx context.Context, record *databroker.Record) error {
|
||||
func (backend *Backend) get(
|
||||
ctx context.Context,
|
||||
cmdable redis.Cmdable,
|
||||
recordType, recordID string,
|
||||
) (*databroker.Record, error) {
|
||||
key, field := getHashKey(recordType, recordID)
|
||||
cmd := cmdable.HGet(ctx, key, field)
|
||||
raw, err := cmd.Result()
|
||||
if err == redis.Nil {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(raw), &record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
func (backend *Backend) put(
|
||||
ctx context.Context,
|
||||
record *databroker.Record,
|
||||
mask *fieldmaskpb.FieldMask,
|
||||
) error {
|
||||
var oldRecord *databroker.Record
|
||||
return backend.incrementVersion(ctx,
|
||||
func(tx *redis.Tx, version uint64) error {
|
||||
if mask != nil {
|
||||
var err error
|
||||
oldRecord, err = backend.get(ctx, tx, record.GetType(), record.GetId())
|
||||
if errors.Is(err, storage.ErrNotFound) {
|
||||
// ignore
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = version
|
||||
return nil
|
||||
},
|
||||
func(p redis.Pipeliner, version uint64) error {
|
||||
if oldRecord != nil {
|
||||
var err error
|
||||
record.Data, err = protoutil.MergeAnyWithFieldMask(
|
||||
oldRecord.GetData(),
|
||||
record.GetData(),
|
||||
mask,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
bs, err := proto.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, field := getHashKey(record.GetType(), record.GetId())
|
||||
if record.DeletedAt != nil {
|
||||
p.HDel(ctx, key, field)
|
||||
|
@ -354,7 +393,7 @@ func (backend *Backend) enforceOptions(ctx context.Context, recordType string) e
|
|||
if err == nil {
|
||||
// mark the record as deleted and re-submit
|
||||
record.DeletedAt = timestamppb.Now()
|
||||
err = backend.put(ctx, record)
|
||||
err = backend.put(ctx, record, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue