pomerium/pkg/storage/inmemory/backend.go
Joe Kralicky 332932b7a8
Replace usages of x/exp/maps + bump golang.org/x/exp (#5221)
Bump golang.org/x/exp; replace usages of x/exp/maps with stdlib equivalents
2024-08-15 17:49:24 -04:00

417 lines
10 KiB
Go

// Package inmemory contains an in-memory implementation of the databroker backend.
package inmemory
import (
"context"
"fmt"
"maps"
"slices"
"sync"
"sync/atomic"
"time"
"github.com/google/btree"
"github.com/rs/zerolog"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/health"
"github.com/pomerium/pomerium/pkg/storage"
)
// lease describes the current holder of a named lease and when that hold
// lapses. Values are stored in Backend.leases keyed by lease name.
type lease struct {
// id identifies the holder; Lease compares it with the caller's leaseID.
id string
// expiry is the instant after which Lease treats this entry as expired and
// allows another holder to take over.
expiry time.Time
}
// recordChange wraps a record so that it can be stored as an item in the
// Backend.changes btree, where items are ordered by record version.
type recordChange struct {
// record is the copy of the record captured at the time of the change.
record *databroker.Record
}
// Less orders recordChange items by ascending record version. It reports true
// only when item is itself a recordChange whose record version is strictly
// greater than this one's; for any other item type it reports false.
func (change recordChange) Less(item btree.Item) bool {
	other, isChange := item.(recordChange)
	return isChange && change.record.GetVersion() < other.record.GetVersion()
}
// A Backend stores data in-memory.
//
// It keeps the latest copy of every record grouped by record type, an ordered
// log of changes keyed by record version, per-type capacity limits and named
// leases. Instances are created with New.
type Backend struct {
// cfg holds the options resolved by getConfig in New.
cfg *config
// onChange is broadcast at the end of Put and Patch.
onChange *signal.Signal
// serverVersion is chosen once in New and checked by Sync against the
// version supplied by the caller.
serverVersion uint64
// lastVersion is the most recently assigned record version; it is advanced
// atomically by nextVersion.
lastVersion uint64
// closeOnce ensures the body of Close runs at most once.
closeOnce sync.Once
// closed is closed by Close and stops the expiry goroutine started in New.
closed chan struct{}
// mu is held by the methods that read or write the fields below.
mu sync.RWMutex
// lookup maps a record type to the collection of its current records.
lookup map[string]*RecordCollection
// capacity maps a record type to its maximum record count; a missing or nil
// entry means no limit is enforced.
capacity map[string]*uint64
// changes holds recordChange items ordered by record version.
changes *btree.BTree
// leases maps a lease name to its current holder and expiry.
leases map[string]*lease
}
// New creates a new in-memory backend storage.
//
// The supplied options are resolved through getConfig. When the resulting
// expiry is non-zero, a goroutine is started that wakes once per second and
// drops change-log entries modified more than expiry ago; it exits when the
// backend is closed. Before returning, the storage backend is reported as OK
// to the health package.
func New(options ...Option) *Backend {
	cfg := getConfig(options...)

	b := &Backend{
		cfg:           cfg,
		onChange:      signal.New(),
		serverVersion: cryptutil.NewRandomUInt64(),
		closed:        make(chan struct{}),
		lookup:        make(map[string]*RecordCollection),
		capacity:      make(map[string]*uint64),
		changes:       btree.New(cfg.degree),
		leases:        make(map[string]*lease),
	}

	// expireLoop prunes old changes on every tick until Close is called.
	expireLoop := func() {
		ticker := time.NewTicker(time.Second)
		defer ticker.Stop()

		for {
			select {
			case <-b.closed:
				return
			case <-ticker.C:
				b.removeChangesBefore(time.Now().Add(-cfg.expiry))
			}
		}
	}
	if cfg.expiry != 0 {
		go expireLoop()
	}

	health.ReportOK(health.StorageBackend, health.StrAttr("backend", "in-memory"))
	return b
}
// removeChangesBefore drops entries from the front of the change log while the
// smallest entry's ModifiedAt time is before cutoff. It takes the write lock
// for the duration of the call and panics if the btree holds an item that is
// not a recordChange.
func (backend *Backend) removeChangesBefore(cutoff time.Time) {
	backend.mu.Lock()
	defer backend.mu.Unlock()

	for oldest := backend.changes.Min(); oldest != nil; oldest = backend.changes.Min() {
		entry, isChange := oldest.(recordChange)
		if !isChange {
			panic(fmt.Sprintf("invalid type in changes btree: %T", oldest))
		}

		if !entry.record.GetModifiedAt().AsTime().Before(cutoff) {
			// the smallest remaining entry is recent enough; stop here
			return
		}

		_ = backend.changes.DeleteMin()
	}
}
// Close closes the in-memory store and erases any stored data.
//
// On the first call the closed channel is closed and, under the write lock,
// the record lookup, capacity table and change log are replaced with empty
// ones. Later calls do nothing. The returned error is always nil.
func (backend *Backend) Close() error {
	shutdown := func() {
		close(backend.closed)

		backend.mu.Lock()
		defer backend.mu.Unlock()

		backend.changes = btree.New(backend.cfg.degree)
		backend.capacity = make(map[string]*uint64)
		backend.lookup = make(map[string]*RecordCollection)
	}
	backend.closeOnce.Do(shutdown)
	return nil
}
// Get gets a record from the in-memory store.
//
// It returns a copy of the record stored under recordType and id, or
// storage.ErrNotFound when there is none. The context is unused.
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	found := backend.get(recordType, id)
	if found == nil {
		return nil, storage.ErrNotFound
	}
	return found, nil
}
// get gets a record from the in-memory store, assuming the RWMutex is held.
// It returns nil when the type or the id is unknown, and otherwise a copy of
// the stored record.
func (backend *Backend) get(recordType, id string) *databroker.Record {
	collection := backend.lookup[recordType]
	if collection == nil {
		return nil
	}

	if stored := collection.Get(id); stored != nil {
		return dup(stored)
	}
	return nil
}
// GetOptions returns the options for a type in the in-memory store.
//
// The result is always non-nil; its Capacity is set only when a capacity has
// been stored for recordType. The returned error is always nil.
func (backend *Backend) GetOptions(_ context.Context, recordType string) (*databroker.Options, error) {
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	result := new(databroker.Options)
	limit := backend.capacity[recordType]
	if limit == nil {
		return result, nil
	}

	// hand out a fresh pointer so callers cannot alter the stored limit
	result.Capacity = proto.Uint64(*limit)
	return result, nil
}
// Lease acquires or renews a lease.
//
// The boolean result reports whether leaseID holds leaseName after the call:
//   - no lease exists, or the existing one has expired: a new lease is stored
//     for leaseID with the given ttl and true is returned;
//   - a live lease is held by a different id: false is returned;
//   - the caller holds the lease and ttl <= 0: the lease is removed and false
//     is returned;
//   - otherwise the expiry is pushed out by ttl and true is returned.
//
// The returned error is always nil.
func (backend *Backend) Lease(_ context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
	backend.mu.Lock()
	defer backend.mu.Unlock()

	current, held := backend.leases[leaseName]
	switch {
	case !held || current.expiry.Before(time.Now()):
		// nothing is held, or it has lapsed: hand out a fresh lease
		backend.leases[leaseName] = &lease{
			id:     leaseID,
			expiry: time.Now().Add(ttl),
		}
		return true, nil
	case current.id != leaseID:
		// somebody else still holds it
		return false, nil
	case ttl <= 0:
		// the holder asked to release the lease
		delete(backend.leases, leaseName)
		return false, nil
	}

	// the holder renews: extend the expiry
	current.expiry = time.Now().Add(ttl)
	return true, nil
}
// ListTypes lists the record types.
//
// It returns the keys of the record lookup in sorted order. The returned
// error is always nil.
func (backend *Backend) ListTypes(_ context.Context) ([]string, error) {
	// maps.Keys returns a lazy iterator: the map is only walked while
	// slices.Sorted consumes it, so the lock has to stay held until the sorted
	// slice has been built. A read lock is enough because nothing is modified.
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	return slices.Sorted(maps.Keys(backend.lookup)), nil
}
// Put puts a record into the in-memory store.
//
// Records are applied in order under the write lock. A nil entry stops
// processing with an error; records applied before it stay applied. After all
// records are applied, capacity is enforced once for every record type that
// was touched. The backend's server version is returned in every case, and
// onChange is broadcast on return using the context that was passed in.
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
	backend.mu.Lock()
	defer backend.mu.Unlock()
	defer backend.onChange.Broadcast(ctx)

	touched := make(map[string]struct{})
	for i := range records {
		record := records[i]
		if record == nil {
			return backend.serverVersion, fmt.Errorf("records cannot be nil")
		}

		ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
			c = c.Str("db_op", "put")
			c = c.Str("db_id", record.Id)
			return c.Str("db_type", record.Type)
		})

		backend.update(record)
		touched[record.GetType()] = struct{}{}
	}

	for recordType := range touched {
		backend.enforceCapacity(recordType)
	}

	return backend.serverVersion, nil
}
// update stores a record into the in-memory store, assuming the RWMutex is held.
//
// The change is first appended to the change log, which also stamps the
// record with a new version and modification time. The collection for the
// record's type is created on demand; a record carrying a DeletedAt timestamp
// is removed from it, any other record is stored as a copy.
func (backend *Backend) update(record *databroker.Record) {
	backend.recordChange(record)

	recordType := record.GetType()
	collection, exists := backend.lookup[recordType]
	if !exists {
		collection = NewRecordCollection()
		backend.lookup[recordType] = collection
	}

	if record.GetDeletedAt() == nil {
		collection.Put(dup(record))
		return
	}
	collection.Delete(record.GetId())
}
// Patch updates the specified fields of existing record(s).
//
// Each record is patched in order under the write lock. Records that do not
// currently exist are skipped; any other failure aborts the call and returns
// the records patched so far together with the error. On success the returned
// slice holds the patched records in input order. onChange is broadcast on
// return.
func (backend *Backend) Patch(
	ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask,
) (serverVersion uint64, patchedRecords []*databroker.Record, err error) {
	backend.mu.Lock()
	defer backend.mu.Unlock()
	defer backend.onChange.Broadcast(ctx)

	version := backend.serverVersion
	patched := make([]*databroker.Record, 0, len(records))
	for _, record := range records {
		patchErr := backend.patch(record, fields)
		switch {
		case storage.IsNotFound(patchErr):
			// nothing to patch for a record that does not exist
		case patchErr != nil:
			return version, patched, patchErr
		default:
			patched = append(patched, record)
		}
	}

	return version, patched, nil
}
// patch updates the specified fields of an existing record, assuming the RWMutex is held.
//
// It fails for a nil record, returns storage.ErrNotFound when no record with
// the same type and id is stored, and otherwise hands a copy of the stored
// record, the incoming record and the field mask to storage.PatchRecord
// before storing the incoming record through update.
func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
	if record == nil {
		return fmt.Errorf("cannot patch using a nil record")
	}

	current := backend.get(record.GetType(), record.GetId())
	if current == nil {
		return storage.ErrNotFound
	}

	err := storage.PatchRecord(current, record, fields)
	if err == nil {
		backend.update(record)
	}
	return err
}
// SetOptions sets the options for a type in the in-memory store.
//
// When options carries a capacity, it is stored for recordType and enforced
// immediately, which may evict existing records. When options is nil or has
// no capacity, any stored capacity for recordType is removed. The returned
// error is always nil.
func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error {
	backend.mu.Lock()
	defer backend.mu.Unlock()

	// A nil message carries no capacity; treat it like an unset one rather
	// than dereferencing it.
	if options == nil || options.Capacity == nil {
		delete(backend.capacity, recordType)
		return nil
	}

	backend.capacity[recordType] = proto.Uint64(options.GetCapacity())
	backend.enforceCapacity(recordType)
	return nil
}
// Sync returns a record stream for any changes after recordVersion.
//
// The supplied serverVersion must equal the backend's own server version;
// otherwise storage.ErrInvalidServerVersion is returned and no stream is
// created.
func (backend *Backend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (storage.RecordStream, error) {
	backend.mu.RLock()
	sameServer := serverVersion == backend.serverVersion
	backend.mu.RUnlock()

	if !sameServer {
		return nil, storage.ErrInvalidServerVersion
	}
	return newSyncRecordStream(ctx, backend, recordType, recordVersion), nil
}
// SyncLatest returns a record stream for all the records.
//
// The server version and the latest record version are read together under
// the read lock before the stream is created, and are returned alongside the
// stream and any error from creating it.
func (backend *Backend) SyncLatest(
	ctx context.Context,
	recordType string,
	expr storage.FilterExpression,
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
	backend.mu.RLock()
	serverVersion, recordVersion = backend.serverVersion, backend.lastVersion
	backend.mu.RUnlock()

	stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
	return serverVersion, recordVersion, stream, err
}
// recordChange stamps record with the current time and the next version, then
// adds a copy of it to the change log. The passed-in record is modified in
// place. Callers in this file invoke it with the write lock held.
func (backend *Backend) recordChange(record *databroker.Record) {
	record.ModifiedAt = timestamppb.Now()
	record.Version = backend.nextVersion()

	// log a private copy so later edits to record do not alter the entry
	snapshot := dup(record)
	backend.changes.ReplaceOrInsert(recordChange{record: snapshot})
}
// enforceCapacity trims the collection for recordType down to its configured
// capacity, assuming the write lock is held. It does nothing when the type
// has no collection or no capacity. Surplus records are taken from the front
// of collection.List() (presumably oldest first — confirm against
// RecordCollection); each one is logged as a deletion in the change log and
// removed from the collection.
func (backend *Backend) enforceCapacity(recordType string) {
	collection, ok := backend.lookup[recordType]
	if !ok {
		return
	}

	limit := backend.capacity[recordType]
	if limit == nil {
		return
	}
	capacity := *limit

	// Compare in uint64: converting a large capacity to int can wrap to a
	// negative number, which would make every record look surplus and run the
	// loop below past the end of the slice.
	if uint64(collection.Len()) <= capacity {
		return
	}

	records := collection.List()
	for uint64(len(records)) > capacity {
		// log a deletion for the record at the front and drop it
		record := dup(records[0])
		record.DeletedAt = timestamppb.Now()
		backend.recordChange(record)
		collection.Delete(record.GetId())

		// move forward
		records = records[1:]
	}
}
// getSince returns copies of the change-log entries whose version is greater
// than version, in ascending version order. When recordType is non-empty only
// entries of that type are returned. The result is nil when nothing matches.
// It takes the read lock for the duration of the call and panics if the btree
// holds an item that is not a recordChange.
func (backend *Backend) getSince(recordType string, version uint64) []*databroker.Record {
	backend.mu.RLock()
	defer backend.mu.RUnlock()

	var records []*databroker.Record
	pivot := recordChange{record: &databroker.Record{Version: version}}
	backend.changes.AscendGreaterOrEqual(pivot, func(item btree.Item) bool {
		change, ok := item.(recordChange)
		if !ok {
			panic(fmt.Sprintf("invalid type in changes btree: %T", item))
		}
		record := change.record

		// skip the pivoting version as we only want records after it
		if record.GetVersion() == version {
			return true
		}
		// filter by type before cloning so discarded entries are never copied
		if recordType != "" && record.GetType() != recordType {
			return true
		}

		records = append(records, dup(record))
		return true
	})
	return records
}
// nextVersion atomically increments lastVersion by one and returns the new
// value.
func (backend *Backend) nextVersion() uint64 {
	next := atomic.AddUint64(&backend.lastVersion, 1)
	return next
}
// dup returns a copy of record made with proto.Clone, so the caller and the
// store never share the same message.
func dup(record *databroker.Record) *databroker.Record {
	cloned := proto.Clone(record)
	return cloned.(*databroker.Record)
}