mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
Add a new Patch() method that updates specific fields of an existing record's data, based on a field mask. Extract some logic from the existing Get() and Put() methods so it can be shared with the new Patch() method.
414 lines
10 KiB
Go
414 lines
10 KiB
Go
// Package inmemory contains an in-memory implementation of the databroker backend.
|
|
package inmemory
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sort"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/btree"
|
|
"github.com/rs/zerolog"
|
|
"golang.org/x/exp/maps"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
|
"google.golang.org/protobuf/types/known/timestamppb"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/internal/signal"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
"github.com/pomerium/pomerium/pkg/storage"
|
|
)
|
|
|
|
// lease is a named lock, held by the caller identified by id until expiry.
type lease struct {
	id     string    // identifier supplied by the current holder
	expiry time.Time // once this instant has passed, Lease lets anyone take it
}
|
|
|
|
type recordChange struct {
|
|
record *databroker.Record
|
|
}
|
|
|
|
func (change recordChange) Less(item btree.Item) bool {
|
|
that, ok := item.(recordChange)
|
|
if !ok {
|
|
return false
|
|
}
|
|
|
|
return change.record.GetVersion() < that.record.GetVersion()
|
|
}
|
|
|
|
// A Backend stores data in-memory.
|
|
type Backend struct {
|
|
cfg *config
|
|
onChange *signal.Signal
|
|
serverVersion uint64
|
|
|
|
lastVersion uint64
|
|
closeOnce sync.Once
|
|
closed chan struct{}
|
|
|
|
mu sync.RWMutex
|
|
lookup map[string]*RecordCollection
|
|
capacity map[string]*uint64
|
|
changes *btree.BTree
|
|
leases map[string]*lease
|
|
}
|
|
|
|
// New creates a new in-memory backend storage.
|
|
func New(options ...Option) *Backend {
|
|
cfg := getConfig(options...)
|
|
backend := &Backend{
|
|
cfg: cfg,
|
|
onChange: signal.New(),
|
|
serverVersion: cryptutil.NewRandomUInt64(),
|
|
closed: make(chan struct{}),
|
|
lookup: make(map[string]*RecordCollection),
|
|
capacity: map[string]*uint64{},
|
|
changes: btree.New(cfg.degree),
|
|
leases: make(map[string]*lease),
|
|
}
|
|
if cfg.expiry != 0 {
|
|
go func() {
|
|
ticker := time.NewTicker(time.Second)
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-backend.closed:
|
|
return
|
|
case <-ticker.C:
|
|
}
|
|
|
|
backend.removeChangesBefore(time.Now().Add(-cfg.expiry))
|
|
}
|
|
}()
|
|
}
|
|
return backend
|
|
}
|
|
|
|
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
for {
|
|
item := backend.changes.Min()
|
|
if item == nil {
|
|
break
|
|
}
|
|
change, ok := item.(recordChange)
|
|
if !ok {
|
|
panic(fmt.Sprintf("invalid type in changes btree: %T", item))
|
|
}
|
|
if change.record.GetModifiedAt().AsTime().Before(cutoff) {
|
|
_ = backend.changes.DeleteMin()
|
|
continue
|
|
}
|
|
|
|
// nothing left to remove
|
|
break
|
|
}
|
|
}
|
|
|
|
// Close closes the in-memory store and erases any stored data.
|
|
func (backend *Backend) Close() error {
|
|
backend.closeOnce.Do(func() {
|
|
close(backend.closed)
|
|
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
backend.lookup = map[string]*RecordCollection{}
|
|
backend.capacity = map[string]*uint64{}
|
|
backend.changes = btree.New(backend.cfg.degree)
|
|
})
|
|
return nil
|
|
}
|
|
|
|
// Get gets a record from the in-memory store.
|
|
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
|
|
backend.mu.RLock()
|
|
defer backend.mu.RUnlock()
|
|
if record := backend.get(recordType, id); record != nil {
|
|
return record, nil
|
|
}
|
|
return nil, storage.ErrNotFound
|
|
}
|
|
|
|
// get gets a record from the in-memory store, assuming the RWMutex is held.
|
|
func (backend *Backend) get(recordType, id string) *databroker.Record {
|
|
records := backend.lookup[recordType]
|
|
if records == nil {
|
|
return nil
|
|
}
|
|
|
|
record := records.Get(id)
|
|
if record == nil {
|
|
return nil
|
|
}
|
|
|
|
return dup(record)
|
|
}
|
|
|
|
// GetOptions returns the options for a type in the in-memory store.
|
|
func (backend *Backend) GetOptions(_ context.Context, recordType string) (*databroker.Options, error) {
|
|
backend.mu.RLock()
|
|
defer backend.mu.RUnlock()
|
|
|
|
options := new(databroker.Options)
|
|
if capacity := backend.capacity[recordType]; capacity != nil {
|
|
options.Capacity = proto.Uint64(*capacity)
|
|
}
|
|
|
|
return options, nil
|
|
}
|
|
|
|
// Lease acquires or renews a lease.
|
|
func (backend *Backend) Lease(_ context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
l, ok := backend.leases[leaseName]
|
|
// if there is no lease, or its expired, acquire a new one.
|
|
if !ok || l.expiry.Before(time.Now()) {
|
|
backend.leases[leaseName] = &lease{
|
|
id: leaseID,
|
|
expiry: time.Now().Add(ttl),
|
|
}
|
|
return true, nil
|
|
}
|
|
|
|
// if the lease doesn't match, we can't acquire it
|
|
if l.id != leaseID {
|
|
return false, nil
|
|
}
|
|
|
|
// release the lease
|
|
if ttl <= 0 {
|
|
delete(backend.leases, leaseName)
|
|
return false, nil
|
|
}
|
|
|
|
// update the expiry (renew the lease)
|
|
l.expiry = time.Now().Add(ttl)
|
|
return true, nil
|
|
}
|
|
|
|
// ListTypes lists the record types.
|
|
func (backend *Backend) ListTypes(_ context.Context) ([]string, error) {
|
|
backend.mu.Lock()
|
|
keys := maps.Keys(backend.lookup)
|
|
backend.mu.Unlock()
|
|
|
|
sort.Strings(keys)
|
|
return keys, nil
|
|
}
|
|
|
|
// Put puts a record into the in-memory store.
|
|
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
defer backend.onChange.Broadcast(ctx)
|
|
|
|
recordTypes := map[string]struct{}{}
|
|
for _, record := range records {
|
|
if record == nil {
|
|
return backend.serverVersion, fmt.Errorf("records cannot be nil")
|
|
}
|
|
|
|
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
|
|
return c.Str("db_op", "put").
|
|
Str("db_id", record.Id).
|
|
Str("db_type", record.Type)
|
|
})
|
|
|
|
backend.update(record)
|
|
|
|
recordTypes[record.GetType()] = struct{}{}
|
|
}
|
|
for recordType := range recordTypes {
|
|
backend.enforceCapacity(recordType)
|
|
}
|
|
|
|
return backend.serverVersion, nil
|
|
}
|
|
|
|
// update stores a record into the in-memory store, assuming the RWMutex is held.
|
|
func (backend *Backend) update(record *databroker.Record) {
|
|
backend.recordChange(record)
|
|
|
|
c, ok := backend.lookup[record.GetType()]
|
|
if !ok {
|
|
c = NewRecordCollection()
|
|
backend.lookup[record.GetType()] = c
|
|
}
|
|
|
|
if record.GetDeletedAt() != nil {
|
|
c.Delete(record.GetId())
|
|
} else {
|
|
c.Put(dup(record))
|
|
}
|
|
}
|
|
|
|
// Patch updates the specified fields of existing record(s).
|
|
func (backend *Backend) Patch(
|
|
ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask,
|
|
) (serverVersion uint64, patchedRecords []*databroker.Record, err error) {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
defer backend.onChange.Broadcast(ctx)
|
|
|
|
serverVersion = backend.serverVersion
|
|
patchedRecords = make([]*databroker.Record, 0, len(records))
|
|
|
|
for _, record := range records {
|
|
err = backend.patch(record, fields)
|
|
if storage.IsNotFound(err) {
|
|
// Skip any record that does not currently exist.
|
|
continue
|
|
} else if err != nil {
|
|
return
|
|
}
|
|
patchedRecords = append(patchedRecords, record)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// patch updates the specified fields of an existing record, assuming the RWMutex is held.
|
|
func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
|
|
if record == nil {
|
|
return fmt.Errorf("cannot patch using a nil record")
|
|
}
|
|
|
|
existing := backend.get(record.GetType(), record.GetId())
|
|
if existing == nil {
|
|
return storage.ErrNotFound
|
|
}
|
|
|
|
if err := storage.PatchRecord(existing, record, fields); err != nil {
|
|
return err
|
|
}
|
|
|
|
backend.update(record)
|
|
|
|
return nil
|
|
}
|
|
|
|
// SetOptions sets the options for a type in the in-memory store.
|
|
func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error {
|
|
backend.mu.Lock()
|
|
defer backend.mu.Unlock()
|
|
|
|
if options.Capacity == nil {
|
|
delete(backend.capacity, recordType)
|
|
} else {
|
|
backend.capacity[recordType] = proto.Uint64(options.GetCapacity())
|
|
backend.enforceCapacity(recordType)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Sync returns a record stream for any changes after recordVersion.
|
|
func (backend *Backend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (storage.RecordStream, error) {
|
|
backend.mu.RLock()
|
|
currentServerVersion := backend.serverVersion
|
|
backend.mu.RUnlock()
|
|
|
|
if serverVersion != currentServerVersion {
|
|
return nil, storage.ErrInvalidServerVersion
|
|
}
|
|
return newSyncRecordStream(ctx, backend, recordType, recordVersion), nil
|
|
}
|
|
|
|
// SyncLatest returns a record stream for all the records.
|
|
func (backend *Backend) SyncLatest(
|
|
ctx context.Context,
|
|
recordType string,
|
|
expr storage.FilterExpression,
|
|
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
|
backend.mu.RLock()
|
|
serverVersion = backend.serverVersion
|
|
recordVersion = backend.lastVersion
|
|
backend.mu.RUnlock()
|
|
|
|
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
|
|
return serverVersion, recordVersion, stream, err
|
|
}
|
|
|
|
func (backend *Backend) recordChange(record *databroker.Record) {
|
|
record.ModifiedAt = timestamppb.Now()
|
|
record.Version = backend.nextVersion()
|
|
backend.changes.ReplaceOrInsert(recordChange{record: dup(record)})
|
|
}
|
|
|
|
func (backend *Backend) enforceCapacity(recordType string) {
|
|
collection, ok := backend.lookup[recordType]
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
ptr := backend.capacity[recordType]
|
|
if ptr == nil {
|
|
return
|
|
}
|
|
capacity := *ptr
|
|
|
|
if collection.Len() <= int(capacity) {
|
|
return
|
|
}
|
|
|
|
records := collection.List()
|
|
for len(records) > int(capacity) {
|
|
// delete the record
|
|
record := dup(records[0])
|
|
record.DeletedAt = timestamppb.Now()
|
|
backend.recordChange(record)
|
|
collection.Delete(record.GetId())
|
|
|
|
// move forward
|
|
records = records[1:]
|
|
}
|
|
}
|
|
|
|
func (backend *Backend) getSince(recordType string, version uint64) []*databroker.Record {
|
|
backend.mu.RLock()
|
|
defer backend.mu.RUnlock()
|
|
|
|
var records []*databroker.Record
|
|
pivot := recordChange{record: &databroker.Record{Version: version}}
|
|
backend.changes.AscendGreaterOrEqual(pivot, func(item btree.Item) bool {
|
|
change, ok := item.(recordChange)
|
|
if !ok {
|
|
panic(fmt.Sprintf("invalid type in changes btree: %T", item))
|
|
}
|
|
record := change.record
|
|
// skip the pivoting version as we only want records after it
|
|
if record.GetVersion() != version {
|
|
records = append(records, dup(record))
|
|
}
|
|
return true
|
|
})
|
|
|
|
if recordType != "" {
|
|
var filtered []*databroker.Record
|
|
for _, record := range records {
|
|
if record.GetType() == recordType {
|
|
filtered = append(filtered, record)
|
|
}
|
|
}
|
|
records = filtered
|
|
}
|
|
return records
|
|
}
|
|
|
|
func (backend *Backend) nextVersion() uint64 {
|
|
return atomic.AddUint64(&backend.lastVersion, 1)
|
|
}
|
|
|
|
func dup(record *databroker.Record) *databroker.Record {
|
|
return proto.Clone(record).(*databroker.Record)
|
|
}
|