mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-01 11:26:29 +02:00
storage: support ip address indexing for the in-memory store (#5568)
This commit is contained in:
parent
c7ffb95483
commit
cd731789be
8 changed files with 534 additions and 167 deletions
3
go.mod
3
go.mod
|
@ -1,6 +1,6 @@
|
||||||
module github.com/pomerium/pomerium
|
module github.com/pomerium/pomerium
|
||||||
|
|
||||||
go 1.23.6
|
go 1.23.8
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cloud.google.com/go/storage v1.51.0
|
cloud.google.com/go/storage v1.51.0
|
||||||
|
@ -20,6 +20,7 @@ require (
|
||||||
github.com/envoyproxy/protoc-gen-validate v1.2.1
|
github.com/envoyproxy/protoc-gen-validate v1.2.1
|
||||||
github.com/exaring/otelpgx v0.9.0
|
github.com/exaring/otelpgx v0.9.0
|
||||||
github.com/fsnotify/fsnotify v1.8.0
|
github.com/fsnotify/fsnotify v1.8.0
|
||||||
|
github.com/gaissmai/bart v0.20.3
|
||||||
github.com/go-chi/chi/v5 v5.2.1
|
github.com/go-chi/chi/v5 v5.2.1
|
||||||
github.com/go-jose/go-jose/v3 v3.0.4
|
github.com/go-jose/go-jose/v3 v3.0.4
|
||||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||||
|
|
2
go.sum
2
go.sum
|
@ -222,6 +222,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
|
||||||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||||
github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
|
github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
|
||||||
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
||||||
|
github.com/gaissmai/bart v0.20.3 h1:hZxPDasx5f2rmNsKwvNu5JMV0+SUs/uovkUNKN/807U=
|
||||||
|
github.com/gaissmai/bart v0.20.3/go.mod h1:HRCXF6EPBV4dcRPUTZtjVx384e3RYVHJ5H22ApAqltA=
|
||||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||||
|
|
|
@ -53,7 +53,7 @@ type Backend struct {
|
||||||
closed chan struct{}
|
closed chan struct{}
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
lookup map[string]*RecordCollection
|
lookup map[string]storage.RecordCollection
|
||||||
capacity map[string]*uint64
|
capacity map[string]*uint64
|
||||||
changes *btree.BTree
|
changes *btree.BTree
|
||||||
leases map[string]*lease
|
leases map[string]*lease
|
||||||
|
@ -67,7 +67,7 @@ func New(options ...Option) *Backend {
|
||||||
onChange: signal.New(),
|
onChange: signal.New(),
|
||||||
serverVersion: cryptutil.NewRandomUInt64(),
|
serverVersion: cryptutil.NewRandomUInt64(),
|
||||||
closed: make(chan struct{}),
|
closed: make(chan struct{}),
|
||||||
lookup: make(map[string]*RecordCollection),
|
lookup: make(map[string]storage.RecordCollection),
|
||||||
capacity: map[string]*uint64{},
|
capacity: map[string]*uint64{},
|
||||||
changes: btree.New(cfg.degree),
|
changes: btree.New(cfg.degree),
|
||||||
leases: make(map[string]*lease),
|
leases: make(map[string]*lease),
|
||||||
|
@ -124,7 +124,7 @@ func (backend *Backend) Close() error {
|
||||||
backend.mu.Lock()
|
backend.mu.Lock()
|
||||||
defer backend.mu.Unlock()
|
defer backend.mu.Unlock()
|
||||||
|
|
||||||
backend.lookup = map[string]*RecordCollection{}
|
backend.lookup = map[string]storage.RecordCollection{}
|
||||||
backend.capacity = map[string]*uint64{}
|
backend.capacity = map[string]*uint64{}
|
||||||
backend.changes = btree.New(backend.cfg.degree)
|
backend.changes = btree.New(backend.cfg.degree)
|
||||||
})
|
})
|
||||||
|
@ -148,8 +148,8 @@ func (backend *Backend) get(recordType, id string) *databroker.Record {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
record := records.Get(id)
|
record, ok := records.Get(id)
|
||||||
if record == nil {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -244,15 +244,11 @@ func (backend *Backend) update(record *databroker.Record) {
|
||||||
|
|
||||||
c, ok := backend.lookup[record.GetType()]
|
c, ok := backend.lookup[record.GetType()]
|
||||||
if !ok {
|
if !ok {
|
||||||
c = NewRecordCollection()
|
c = storage.NewRecordCollection()
|
||||||
backend.lookup[record.GetType()] = c
|
backend.lookup[record.GetType()] = c
|
||||||
}
|
}
|
||||||
|
|
||||||
if record.GetDeletedAt() != nil {
|
c.Put(record)
|
||||||
c.Delete(record.GetId())
|
|
||||||
} else {
|
|
||||||
c.Put(dup(record))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Patch updates the specified fields of existing record(s).
|
// Patch updates the specified fields of existing record(s).
|
||||||
|
@ -360,20 +356,14 @@ func (backend *Backend) enforceCapacity(recordType string) {
|
||||||
}
|
}
|
||||||
capacity := *ptr
|
capacity := *ptr
|
||||||
|
|
||||||
if collection.Len() <= int(capacity) {
|
for collection.Len() > int(capacity) {
|
||||||
return
|
r, ok := collection.Oldest()
|
||||||
}
|
if !ok {
|
||||||
|
break
|
||||||
records := collection.List()
|
}
|
||||||
for len(records) > int(capacity) {
|
r.DeletedAt = timestamppb.Now()
|
||||||
// delete the record
|
backend.recordChange(r)
|
||||||
record := dup(records[0])
|
collection.Put(r)
|
||||||
record.DeletedAt = timestamppb.Now()
|
|
||||||
backend.recordChange(record)
|
|
||||||
collection.Delete(record.GetId())
|
|
||||||
|
|
||||||
// move forward
|
|
||||||
records = records[1:]
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,71 +0,0 @@
|
||||||
package inmemory
|
|
||||||
|
|
||||||
import (
|
|
||||||
"container/list"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
)
|
|
||||||
|
|
||||||
type recordCollectionNode struct {
|
|
||||||
*databroker.Record
|
|
||||||
insertionOrderPtr *list.Element
|
|
||||||
}
|
|
||||||
|
|
||||||
// A RecordCollection is a collection of records which supports lookup by (record id) as well as enforcing capacity
|
|
||||||
// by insertion order. The collection is *not* thread safe.
|
|
||||||
type RecordCollection struct {
|
|
||||||
records map[string]recordCollectionNode
|
|
||||||
insertionOrder *list.List
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewRecordCollection creates a new RecordCollection.
|
|
||||||
func NewRecordCollection() *RecordCollection {
|
|
||||||
return &RecordCollection{
|
|
||||||
records: map[string]recordCollectionNode{},
|
|
||||||
insertionOrder: list.New(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete deletes a record from the collection.
|
|
||||||
func (c *RecordCollection) Delete(recordID string) {
|
|
||||||
node, ok := c.records[recordID]
|
|
||||||
if !ok {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
delete(c.records, recordID)
|
|
||||||
c.insertionOrder.Remove(node.insertionOrderPtr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get gets a record from the collection.
|
|
||||||
func (c *RecordCollection) Get(recordID string) *databroker.Record {
|
|
||||||
node, ok := c.records[recordID]
|
|
||||||
if !ok {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return node.Record
|
|
||||||
}
|
|
||||||
|
|
||||||
// Len returns the length of the collection.
|
|
||||||
func (c *RecordCollection) Len() int {
|
|
||||||
return len(c.records)
|
|
||||||
}
|
|
||||||
|
|
||||||
// List lists all the records in the collection in insertion order.
|
|
||||||
func (c *RecordCollection) List() []*databroker.Record {
|
|
||||||
var all []*databroker.Record
|
|
||||||
for el := c.insertionOrder.Front(); el != nil; el = el.Next() {
|
|
||||||
all = append(all, c.records[el.Value.(string)].Record)
|
|
||||||
}
|
|
||||||
return all
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts a record in the collection.
|
|
||||||
func (c *RecordCollection) Put(record *databroker.Record) {
|
|
||||||
c.Delete(record.GetId())
|
|
||||||
|
|
||||||
el := c.insertionOrder.PushBack(record.GetId())
|
|
||||||
c.records[record.GetId()] = recordCollectionNode{
|
|
||||||
Record: record,
|
|
||||||
insertionOrderPtr: el,
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -2,6 +2,8 @@ package inmemory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -13,42 +15,30 @@ func newSyncLatestRecordStream(
|
||||||
recordType string,
|
recordType string,
|
||||||
expr storage.FilterExpression,
|
expr storage.FilterExpression,
|
||||||
) (storage.RecordStream, error) {
|
) (storage.RecordStream, error) {
|
||||||
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
backend.mu.RLock()
|
||||||
if err != nil {
|
defer backend.mu.RUnlock()
|
||||||
return nil, err
|
|
||||||
}
|
var recordTypes []string
|
||||||
if recordType != "" {
|
if recordType == "" {
|
||||||
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
recordTypes = slices.Sorted(maps.Keys(backend.lookup))
|
||||||
return record.GetType() == recordType
|
} else {
|
||||||
})
|
recordTypes = []string{recordType}
|
||||||
}
|
}
|
||||||
|
|
||||||
var ready []*databroker.Record
|
var records []*databroker.Record
|
||||||
generator := func(_ context.Context, _ bool) (*databroker.Record, error) {
|
for _, recordType := range recordTypes {
|
||||||
backend.mu.RLock()
|
co, ok := backend.lookup[recordType]
|
||||||
for _, co := range backend.lookup {
|
if !ok {
|
||||||
for _, record := range co.List() {
|
continue
|
||||||
if filter(record) {
|
|
||||||
ready = append(ready, record)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
backend.mu.RUnlock()
|
rs, err := co.List(expr)
|
||||||
return nil, storage.ErrStreamDone
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
records = append(records, rs...)
|
||||||
}
|
}
|
||||||
|
|
||||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
return storage.RecordListToStream(ctx, records), nil
|
||||||
generator,
|
|
||||||
func(_ context.Context, _ bool) (*databroker.Record, error) {
|
|
||||||
if len(ready) == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
record := ready[0]
|
|
||||||
ready = ready[1:]
|
|
||||||
return dup(record), nil
|
|
||||||
},
|
|
||||||
}, nil), nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSyncRecordStream(
|
func newSyncRecordStream(
|
||||||
|
|
|
@ -51,18 +51,23 @@ func WithQuerier(ctx context.Context, querier Querier) context.Context {
|
||||||
}
|
}
|
||||||
|
|
||||||
type staticQuerier struct {
|
type staticQuerier struct {
|
||||||
records []*databroker.Record
|
records map[string]RecordCollection
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
|
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
|
||||||
func NewStaticQuerier(msgs ...proto.Message) Querier {
|
func NewStaticQuerier(msgs ...proto.Message) Querier {
|
||||||
getter := &staticQuerier{}
|
getter := &staticQuerier{records: make(map[string]RecordCollection)}
|
||||||
for _, msg := range msgs {
|
for _, msg := range msgs {
|
||||||
record, ok := msg.(*databroker.Record)
|
record, ok := msg.(*databroker.Record)
|
||||||
if !ok {
|
if !ok {
|
||||||
record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg)
|
record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg)
|
||||||
}
|
}
|
||||||
getter.records = append(getter.records, record)
|
c, ok := getter.records[record.Type]
|
||||||
|
if !ok {
|
||||||
|
c = NewRecordCollection()
|
||||||
|
getter.records[record.Type] = c
|
||||||
|
}
|
||||||
|
c.Put(record)
|
||||||
}
|
}
|
||||||
return getter
|
return getter
|
||||||
}
|
}
|
||||||
|
@ -107,42 +112,8 @@ func NewStaticRecord(typeURL string, msg proto.Message) *databroker.Record {
|
||||||
func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
|
func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
|
||||||
|
|
||||||
// Query queries for records.
|
// Query queries for records.
|
||||||
func (q *staticQuerier) Query(_ context.Context, in *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||||
expr, err := FilterExpressionFromStruct(in.GetFilter())
|
return QueryRecordCollections(q.records, req)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
filter, err := RecordStreamFilterFromFilterExpression(expr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res := new(databroker.QueryResponse)
|
|
||||||
for _, record := range q.records {
|
|
||||||
if record.GetType() != in.GetType() {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if !filter(record) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
res.Records = append(res.Records, record)
|
|
||||||
}
|
|
||||||
|
|
||||||
var total int
|
|
||||||
res.Records, total = databroker.ApplyOffsetAndLimit(
|
|
||||||
res.Records,
|
|
||||||
int(in.GetOffset()),
|
|
||||||
int(in.GetLimit()),
|
|
||||||
)
|
|
||||||
res.TotalCount = int64(total)
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type clientQuerier struct {
|
type clientQuerier struct {
|
||||||
|
|
353
pkg/storage/record_collection.go
Normal file
353
pkg/storage/record_collection.go
Normal file
|
@ -0,0 +1,353 @@
|
||||||
|
package storage
|
||||||
|
|
||||||
|
import (
|
||||||
|
"container/list"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/gaissmai/bart"
|
||||||
|
set "github.com/hashicorp/go-set/v3"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
)
|
||||||
|
|
||||||
|
// A RecordCollection stores records. It supports id and ip addr indexing and ordering of
|
||||||
|
// records in insertion order. It is not thread-safe.
|
||||||
|
type RecordCollection interface {
|
||||||
|
// All returns all of the databroker records as a slice. The slice is in insertion order.
|
||||||
|
All() []*databroker.Record
|
||||||
|
// Clear removes all the records from the collection.
|
||||||
|
Clear()
|
||||||
|
// Get returns a record based on the record id.
|
||||||
|
Get(recordID string) (*databroker.Record, bool)
|
||||||
|
// Len returns the number of records stored in the collection.
|
||||||
|
Len() int
|
||||||
|
// List returns all of the databroker records that match the given expression.
|
||||||
|
List(filter FilterExpression) ([]*databroker.Record, error)
|
||||||
|
// Newest returns the newest databroker record in the collection.
|
||||||
|
Newest() (*databroker.Record, bool)
|
||||||
|
// Oldest returns the oldest databroker record in the collection.
|
||||||
|
Oldest() (*databroker.Record, bool)
|
||||||
|
// Put puts a record into the collection. If the record's deleted at field is not nil, the record will
|
||||||
|
// be removed from the collection.
|
||||||
|
Put(record *databroker.Record)
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordCollectionNode struct {
|
||||||
|
*databroker.Record
|
||||||
|
insertionOrderPtr *list.Element
|
||||||
|
}
|
||||||
|
|
||||||
|
type recordCollection struct {
|
||||||
|
cidrIndex bart.Table[[]string]
|
||||||
|
records map[string]recordCollectionNode
|
||||||
|
insertionOrder *list.List
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRecordCollection creates a new RecordCollection.
|
||||||
|
func NewRecordCollection() RecordCollection {
|
||||||
|
return &recordCollection{
|
||||||
|
records: make(map[string]recordCollectionNode),
|
||||||
|
insertionOrder: list.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) All() []*databroker.Record {
|
||||||
|
l := make([]*databroker.Record, 0, len(c.records))
|
||||||
|
for e := c.insertionOrder.Front(); e != nil; e = e.Next() {
|
||||||
|
r, ok := c.records[e.Value.(string)]
|
||||||
|
if ok {
|
||||||
|
l = append(l, dup(r.Record))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Clear() {
|
||||||
|
c.cidrIndex = bart.Table[[]string]{}
|
||||||
|
clear(c.records)
|
||||||
|
c.insertionOrder = list.New()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) {
|
||||||
|
node, ok := c.records[recordID]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return dup(node.Record), true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Len() int {
|
||||||
|
return len(c.records)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) List(filter FilterExpression) ([]*databroker.Record, error) {
|
||||||
|
if filter == nil {
|
||||||
|
return c.All(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch expr := filter.(type) {
|
||||||
|
case AndFilterExpression:
|
||||||
|
var rss [][]*databroker.Record
|
||||||
|
for _, e := range expr {
|
||||||
|
rs, err := c.List(e)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rss = append(rss, rs)
|
||||||
|
}
|
||||||
|
return intersection(rss), nil
|
||||||
|
case OrFilterExpression:
|
||||||
|
var rss [][]*databroker.Record
|
||||||
|
for _, e := range expr {
|
||||||
|
rs, err := c.List(e)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rss = append(rss, rs)
|
||||||
|
}
|
||||||
|
return union(rss), nil
|
||||||
|
case EqualsFilterExpression:
|
||||||
|
switch strings.Join(expr.Fields, ".") {
|
||||||
|
case "id":
|
||||||
|
l := make([]*databroker.Record, 0, 1)
|
||||||
|
if node, ok := c.records[expr.Value]; ok {
|
||||||
|
l = append(l, node.Record)
|
||||||
|
}
|
||||||
|
return l, nil
|
||||||
|
case "$index":
|
||||||
|
l := make([]*databroker.Record, 0, 1)
|
||||||
|
if prefix, err := netip.ParsePrefix(expr.Value); err == nil {
|
||||||
|
l = append(l, c.lookupPrefix(prefix)...)
|
||||||
|
} else if addr, err := netip.ParseAddr(expr.Value); err == nil {
|
||||||
|
l = append(l, c.lookupAddr(addr)...)
|
||||||
|
}
|
||||||
|
return l, nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown field: %s", strings.Join(expr.Fields, "."))
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unknown expression type: %T", expr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Put(record *databroker.Record) {
|
||||||
|
record = dup(record)
|
||||||
|
|
||||||
|
// first delete the record
|
||||||
|
c.delete(record.GetId())
|
||||||
|
if record.DeletedAt != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// add it
|
||||||
|
el := c.insertionOrder.PushBack(record.GetId())
|
||||||
|
c.records[record.GetId()] = recordCollectionNode{
|
||||||
|
Record: record,
|
||||||
|
insertionOrderPtr: el,
|
||||||
|
}
|
||||||
|
if prefix := GetRecordIndexCIDR(record.GetData()); prefix != nil {
|
||||||
|
c.addIndex(*prefix, record.GetId())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Newest() (*databroker.Record, bool) {
|
||||||
|
e := c.insertionOrder.Back()
|
||||||
|
if e == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
node, ok := c.records[e.Value.(string)]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.Record, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) Oldest() (*databroker.Record, bool) {
|
||||||
|
e := c.insertionOrder.Front()
|
||||||
|
if e == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
node, ok := c.records[e.Value.(string)]
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
return node.Record, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) addIndex(prefix netip.Prefix, recordID string) {
|
||||||
|
c.cidrIndex.Update(prefix, func(ids []string, _ bool) []string {
|
||||||
|
// remove the id from the slice so it's not duplicated and gets moved to the end
|
||||||
|
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
|
||||||
|
return append(ids, recordID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) delete(recordID string) {
|
||||||
|
node, ok := c.records[recordID]
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// delete the record from the index if it's the current value stored there
|
||||||
|
if prefix := GetRecordIndexCIDR(node.GetData()); prefix != nil {
|
||||||
|
c.deleteIndex(*prefix, recordID)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(c.records, recordID)
|
||||||
|
c.insertionOrder.Remove(node.insertionOrderPtr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) deleteIndex(prefix netip.Prefix, recordID string) {
|
||||||
|
ids, ok := c.cidrIndex.Get(prefix)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !slices.Contains(ids, recordID) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
|
||||||
|
|
||||||
|
// last match, so delete the whole prefix
|
||||||
|
if len(ids) == 0 {
|
||||||
|
c.cidrIndex.Delete(prefix)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// update the prefix with the id removed
|
||||||
|
c.cidrIndex.Update(prefix, func(_ []string, _ bool) []string {
|
||||||
|
return ids
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) lookupPrefix(prefix netip.Prefix) []*databroker.Record {
|
||||||
|
recordIDs, ok := c.cidrIndex.LookupPrefix(prefix)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l := make([]*databroker.Record, 0, len(recordIDs))
|
||||||
|
for _, recordID := range slices.Backward(recordIDs) {
|
||||||
|
node, ok := c.records[recordID]
|
||||||
|
if ok {
|
||||||
|
l = append(l, dup(node.Record))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *recordCollection) lookupAddr(addr netip.Addr) []*databroker.Record {
|
||||||
|
recordIDs, ok := c.cidrIndex.Lookup(addr)
|
||||||
|
if !ok {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
l := make([]*databroker.Record, 0, len(recordIDs))
|
||||||
|
for _, recordID := range slices.Backward(recordIDs) {
|
||||||
|
node, ok := c.records[recordID]
|
||||||
|
if ok {
|
||||||
|
l = append(l, dup(node.Record))
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
func dup[T proto.Message](msg T) T {
|
||||||
|
return proto.Clone(msg).(T)
|
||||||
|
}
|
||||||
|
|
||||||
|
func intersection[T comparable](xs [][]T) []T {
|
||||||
|
var final []T
|
||||||
|
lookup := map[T]int{}
|
||||||
|
for _, x := range xs {
|
||||||
|
for _, e := range x {
|
||||||
|
lookup[e]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
seen := set.New[T](0)
|
||||||
|
for _, x := range xs {
|
||||||
|
for _, e := range x {
|
||||||
|
if lookup[e] == len(xs) {
|
||||||
|
if !seen.Contains(e) {
|
||||||
|
final = append(final, e)
|
||||||
|
seen.Insert(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return final
|
||||||
|
}
|
||||||
|
|
||||||
|
func union[T comparable](xs [][]T) []T {
|
||||||
|
var final []T
|
||||||
|
seen := set.New[T](0)
|
||||||
|
for _, x := range xs {
|
||||||
|
for _, e := range x {
|
||||||
|
if !seen.Contains(e) {
|
||||||
|
final = append(final, e)
|
||||||
|
seen.Insert(e)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return final
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRecordCollections queries a map of record collections.
|
||||||
|
func QueryRecordCollections(
|
||||||
|
recordCollections map[string]RecordCollection,
|
||||||
|
req *databroker.QueryRequest,
|
||||||
|
) (*databroker.QueryResponse, error) {
|
||||||
|
filter, err := FilterExpressionFromStruct(req.GetFilter())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
var cs []RecordCollection
|
||||||
|
if req.Type == "" {
|
||||||
|
for _, recordType := range slices.Sorted(maps.Keys(recordCollections)) {
|
||||||
|
cs = append(cs, recordCollections[recordType])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
c, ok := recordCollections[req.Type]
|
||||||
|
if ok {
|
||||||
|
cs = append(cs, c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res := new(databroker.QueryResponse)
|
||||||
|
for _, c := range cs {
|
||||||
|
records, err := c.List(filter)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, record := range records {
|
||||||
|
if req.GetQuery() != "" && !MatchAny(record.GetData(), req.GetQuery()) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
res.Records = append(res.Records, record)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var total int
|
||||||
|
res.Records, total = databroker.ApplyOffsetAndLimit(
|
||||||
|
res.Records,
|
||||||
|
int(req.GetOffset()),
|
||||||
|
int(req.GetLimit()),
|
||||||
|
)
|
||||||
|
res.TotalCount = int64(total)
|
||||||
|
return res, nil
|
||||||
|
}
|
131
pkg/storage/record_collection_test.go
Normal file
131
pkg/storage/record_collection_test.go
Normal file
|
@ -0,0 +1,131 @@
|
||||||
|
package storage_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/testing/protocmp"
|
||||||
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
"google.golang.org/protobuf/types/known/structpb"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRecordCollection(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
r1 := &databroker.Record{
|
||||||
|
Id: "r1",
|
||||||
|
Data: newStructAny(t, map[string]any{
|
||||||
|
"$index": map[string]any{
|
||||||
|
"cidr": "10.0.0.0/24",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
r2 := &databroker.Record{
|
||||||
|
Id: "r2",
|
||||||
|
Data: newStructAny(t, map[string]any{
|
||||||
|
"$index": map[string]any{
|
||||||
|
"cidr": "192.168.0.0/24",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
r3 := &databroker.Record{
|
||||||
|
Id: "r3",
|
||||||
|
Data: newStructAny(t, map[string]any{
|
||||||
|
"$index": map[string]any{
|
||||||
|
"cidr": "10.0.0.0/16",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
r4 := &databroker.Record{
|
||||||
|
Id: "r4",
|
||||||
|
Data: newStructAny(t, map[string]any{
|
||||||
|
"$index": map[string]any{
|
||||||
|
"cidr": "10.0.0.0/24",
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
c := storage.NewRecordCollection()
|
||||||
|
c.Put(r4)
|
||||||
|
c.Put(r3)
|
||||||
|
c.Put(r2)
|
||||||
|
c.Put(r1)
|
||||||
|
|
||||||
|
assert.Equal(t, 4, c.Len())
|
||||||
|
|
||||||
|
r, ok := c.Get("r1")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
|
||||||
|
"should return r1")
|
||||||
|
r, ok = c.Get("r2")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r2, r, protocmp.Transform()),
|
||||||
|
"should return r2")
|
||||||
|
r, ok = c.Get("r3")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r3, r, protocmp.Transform()),
|
||||||
|
"should return r3")
|
||||||
|
r, ok = c.Get("r4")
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
|
||||||
|
"should return r4")
|
||||||
|
|
||||||
|
r, ok = c.Oldest()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
|
||||||
|
"should return the first added record")
|
||||||
|
|
||||||
|
r, ok = c.Newest()
|
||||||
|
assert.True(t, ok)
|
||||||
|
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
|
||||||
|
"should return the last added record")
|
||||||
|
|
||||||
|
rs := c.All()
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
|
||||||
|
"should return all records")
|
||||||
|
|
||||||
|
rs, err := c.List(nil)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
|
||||||
|
"should return all records for a nil filter")
|
||||||
|
|
||||||
|
rs, err = c.List(storage.OrFilterExpression{
|
||||||
|
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r3"},
|
||||||
|
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r1"},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r3, r1}, rs, protocmp.Transform()),
|
||||||
|
"should return two records for or")
|
||||||
|
|
||||||
|
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r1}, rs, protocmp.Transform()))
|
||||||
|
|
||||||
|
r1.DeletedAt = timestamppb.Now()
|
||||||
|
c.Put(r1)
|
||||||
|
|
||||||
|
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r4}, rs, protocmp.Transform()))
|
||||||
|
|
||||||
|
r4.DeletedAt = timestamppb.Now()
|
||||||
|
c.Put(r4)
|
||||||
|
|
||||||
|
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform()))
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStructAny(t *testing.T, m map[string]any) *anypb.Any {
|
||||||
|
t.Helper()
|
||||||
|
s, err := structpb.NewStruct(m)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return protoutil.NewAny(s)
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue