databroker: add support for field masks on Put (#3210)

* databroker: add support for field masks on Put

* return errors

* clean up go.mod
This commit is contained in:
Caleb Doxsey 2022-03-29 16:36:40 -06:00 committed by GitHub
parent 8fc5dbf4c5
commit 2dc778035d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
15 changed files with 381 additions and 134 deletions

View file

@ -10,7 +10,8 @@ import (
"github.com/cenkalti/backoff/v4"
"github.com/go-redis/redis/v8"
"github.com/golang/protobuf/proto"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
@ -20,6 +21,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
@ -117,23 +119,7 @@ func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *data
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
key, field := getHashKey(recordType, id)
cmd := backend.client.HGet(ctx, key, field)
raw, err := cmd.Result()
if err == redis.Nil {
return nil, storage.ErrNotFound
} else if err != nil {
return nil, err
}
var record databroker.Record
err = proto.Unmarshal([]byte(raw), &record)
if err != nil {
return nil, err
}
return &record, nil
return backend.get(ctx, backend.client, recordType, id)
}
// GetAll gets all the records from redis.
@ -241,7 +227,11 @@ func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, tt
}
// Put puts a record into redis.
func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (serverVersion uint64, err error) {
func (backend *Backend) Put(
ctx context.Context,
record *databroker.Record,
mask *fieldmaskpb.FieldMask,
) (serverVersion uint64, err error) {
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
@ -251,7 +241,7 @@ func (backend *Backend) Put(ctx context.Context, record *databroker.Record) (ser
return serverVersion, err
}
err = backend.put(ctx, record)
err = backend.put(ctx, record, mask)
if err != nil {
return serverVersion, err
}
@ -294,19 +284,68 @@ func (backend *Backend) Sync(ctx context.Context, serverVersion, recordVersion u
return newRecordStream(ctx, backend, serverVersion, recordVersion), nil
}
func (backend *Backend) put(ctx context.Context, record *databroker.Record) error {
func (backend *Backend) get(
ctx context.Context,
cmdable redis.Cmdable,
recordType, recordID string,
) (*databroker.Record, error) {
key, field := getHashKey(recordType, recordID)
cmd := cmdable.HGet(ctx, key, field)
raw, err := cmd.Result()
if err == redis.Nil {
return nil, storage.ErrNotFound
} else if err != nil {
return nil, err
}
var record databroker.Record
err = proto.Unmarshal([]byte(raw), &record)
if err != nil {
return nil, err
}
return &record, nil
}
func (backend *Backend) put(
ctx context.Context,
record *databroker.Record,
mask *fieldmaskpb.FieldMask,
) error {
var oldRecord *databroker.Record
return backend.incrementVersion(ctx,
func(tx *redis.Tx, version uint64) error {
if mask != nil {
var err error
oldRecord, err = backend.get(ctx, tx, record.GetType(), record.GetId())
if errors.Is(err, storage.ErrNotFound) {
// ignore
} else if err != nil {
return err
}
}
record.ModifiedAt = timestamppb.Now()
record.Version = version
return nil
},
func(p redis.Pipeliner, version uint64) error {
if oldRecord != nil {
var err error
record.Data, err = protoutil.MergeAnyWithFieldMask(
oldRecord.GetData(),
record.GetData(),
mask,
)
if err != nil {
return err
}
}
bs, err := proto.Marshal(record)
if err != nil {
return err
}
key, field := getHashKey(record.GetType(), record.GetId())
if record.DeletedAt != nil {
p.HDel(ctx, key, field)
@ -354,7 +393,7 @@ func (backend *Backend) enforceOptions(ctx context.Context, recordType string) e
if err == nil {
// mark the record as deleted and re-submit
record.DeletedAt = timestamppb.Now()
err = backend.put(ctx, record)
err = backend.put(ctx, record, nil)
if err != nil {
return err
}