mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
storage: support ip address indexing for the in-memory store (#5568)
This commit is contained in:
parent
c7ffb95483
commit
cd731789be
8 changed files with 534 additions and 167 deletions
3
go.mod
3
go.mod
|
@ -1,6 +1,6 @@
|
|||
module github.com/pomerium/pomerium
|
||||
|
||||
go 1.23.6
|
||||
go 1.23.8
|
||||
|
||||
require (
|
||||
cloud.google.com/go/storage v1.51.0
|
||||
|
@ -20,6 +20,7 @@ require (
|
|||
github.com/envoyproxy/protoc-gen-validate v1.2.1
|
||||
github.com/exaring/otelpgx v0.9.0
|
||||
github.com/fsnotify/fsnotify v1.8.0
|
||||
github.com/gaissmai/bart v0.20.3
|
||||
github.com/go-chi/chi/v5 v5.2.1
|
||||
github.com/go-jose/go-jose/v3 v3.0.4
|
||||
github.com/go-viper/mapstructure/v2 v2.2.1
|
||||
|
|
2
go.sum
2
go.sum
|
@ -222,6 +222,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
|
|||
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
|
||||
github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
|
||||
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
|
||||
github.com/gaissmai/bart v0.20.3 h1:hZxPDasx5f2rmNsKwvNu5JMV0+SUs/uovkUNKN/807U=
|
||||
github.com/gaissmai/bart v0.20.3/go.mod h1:HRCXF6EPBV4dcRPUTZtjVx384e3RYVHJ5H22ApAqltA=
|
||||
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
|
||||
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
|
||||
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=
|
||||
|
|
|
@ -53,7 +53,7 @@ type Backend struct {
|
|||
closed chan struct{}
|
||||
|
||||
mu sync.RWMutex
|
||||
lookup map[string]*RecordCollection
|
||||
lookup map[string]storage.RecordCollection
|
||||
capacity map[string]*uint64
|
||||
changes *btree.BTree
|
||||
leases map[string]*lease
|
||||
|
@ -67,7 +67,7 @@ func New(options ...Option) *Backend {
|
|||
onChange: signal.New(),
|
||||
serverVersion: cryptutil.NewRandomUInt64(),
|
||||
closed: make(chan struct{}),
|
||||
lookup: make(map[string]*RecordCollection),
|
||||
lookup: make(map[string]storage.RecordCollection),
|
||||
capacity: map[string]*uint64{},
|
||||
changes: btree.New(cfg.degree),
|
||||
leases: make(map[string]*lease),
|
||||
|
@ -124,7 +124,7 @@ func (backend *Backend) Close() error {
|
|||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
|
||||
backend.lookup = map[string]*RecordCollection{}
|
||||
backend.lookup = map[string]storage.RecordCollection{}
|
||||
backend.capacity = map[string]*uint64{}
|
||||
backend.changes = btree.New(backend.cfg.degree)
|
||||
})
|
||||
|
@ -148,8 +148,8 @@ func (backend *Backend) get(recordType, id string) *databroker.Record {
|
|||
return nil
|
||||
}
|
||||
|
||||
record := records.Get(id)
|
||||
if record == nil {
|
||||
record, ok := records.Get(id)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -244,15 +244,11 @@ func (backend *Backend) update(record *databroker.Record) {
|
|||
|
||||
c, ok := backend.lookup[record.GetType()]
|
||||
if !ok {
|
||||
c = NewRecordCollection()
|
||||
c = storage.NewRecordCollection()
|
||||
backend.lookup[record.GetType()] = c
|
||||
}
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
c.Delete(record.GetId())
|
||||
} else {
|
||||
c.Put(dup(record))
|
||||
}
|
||||
c.Put(record)
|
||||
}
|
||||
|
||||
// Patch updates the specified fields of existing record(s).
|
||||
|
@ -360,20 +356,14 @@ func (backend *Backend) enforceCapacity(recordType string) {
|
|||
}
|
||||
capacity := *ptr
|
||||
|
||||
if collection.Len() <= int(capacity) {
|
||||
return
|
||||
}
|
||||
|
||||
records := collection.List()
|
||||
for len(records) > int(capacity) {
|
||||
// delete the record
|
||||
record := dup(records[0])
|
||||
record.DeletedAt = timestamppb.Now()
|
||||
backend.recordChange(record)
|
||||
collection.Delete(record.GetId())
|
||||
|
||||
// move forward
|
||||
records = records[1:]
|
||||
for collection.Len() > int(capacity) {
|
||||
r, ok := collection.Oldest()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
r.DeletedAt = timestamppb.Now()
|
||||
backend.recordChange(r)
|
||||
collection.Put(r)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,71 +0,0 @@
|
|||
package inmemory
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type recordCollectionNode struct {
|
||||
*databroker.Record
|
||||
insertionOrderPtr *list.Element
|
||||
}
|
||||
|
||||
// A RecordCollection is a collection of records which supports lookup by (record id) as well as enforcing capacity
|
||||
// by insertion order. The collection is *not* thread safe.
|
||||
type RecordCollection struct {
|
||||
records map[string]recordCollectionNode
|
||||
insertionOrder *list.List
|
||||
}
|
||||
|
||||
// NewRecordCollection creates a new RecordCollection.
|
||||
func NewRecordCollection() *RecordCollection {
|
||||
return &RecordCollection{
|
||||
records: map[string]recordCollectionNode{},
|
||||
insertionOrder: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// Delete deletes a record from the collection.
|
||||
func (c *RecordCollection) Delete(recordID string) {
|
||||
node, ok := c.records[recordID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
delete(c.records, recordID)
|
||||
c.insertionOrder.Remove(node.insertionOrderPtr)
|
||||
}
|
||||
|
||||
// Get gets a record from the collection.
|
||||
func (c *RecordCollection) Get(recordID string) *databroker.Record {
|
||||
node, ok := c.records[recordID]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return node.Record
|
||||
}
|
||||
|
||||
// Len returns the length of the collection.
|
||||
func (c *RecordCollection) Len() int {
|
||||
return len(c.records)
|
||||
}
|
||||
|
||||
// List lists all the records in the collection in insertion order.
|
||||
func (c *RecordCollection) List() []*databroker.Record {
|
||||
var all []*databroker.Record
|
||||
for el := c.insertionOrder.Front(); el != nil; el = el.Next() {
|
||||
all = append(all, c.records[el.Value.(string)].Record)
|
||||
}
|
||||
return all
|
||||
}
|
||||
|
||||
// Put puts a record in the collection.
|
||||
func (c *RecordCollection) Put(record *databroker.Record) {
|
||||
c.Delete(record.GetId())
|
||||
|
||||
el := c.insertionOrder.PushBack(record.GetId())
|
||||
c.records[record.GetId()] = recordCollectionNode{
|
||||
Record: record,
|
||||
insertionOrderPtr: el,
|
||||
}
|
||||
}
|
|
@ -2,6 +2,8 @@ package inmemory
|
|||
|
||||
import (
|
||||
"context"
|
||||
"maps"
|
||||
"slices"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -13,42 +15,30 @@ func newSyncLatestRecordStream(
|
|||
recordType string,
|
||||
expr storage.FilterExpression,
|
||||
) (storage.RecordStream, error) {
|
||||
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if recordType != "" {
|
||||
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
||||
return record.GetType() == recordType
|
||||
})
|
||||
backend.mu.RLock()
|
||||
defer backend.mu.RUnlock()
|
||||
|
||||
var recordTypes []string
|
||||
if recordType == "" {
|
||||
recordTypes = slices.Sorted(maps.Keys(backend.lookup))
|
||||
} else {
|
||||
recordTypes = []string{recordType}
|
||||
}
|
||||
|
||||
var ready []*databroker.Record
|
||||
generator := func(_ context.Context, _ bool) (*databroker.Record, error) {
|
||||
backend.mu.RLock()
|
||||
for _, co := range backend.lookup {
|
||||
for _, record := range co.List() {
|
||||
if filter(record) {
|
||||
ready = append(ready, record)
|
||||
}
|
||||
}
|
||||
var records []*databroker.Record
|
||||
for _, recordType := range recordTypes {
|
||||
co, ok := backend.lookup[recordType]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
backend.mu.RUnlock()
|
||||
return nil, storage.ErrStreamDone
|
||||
rs, err := co.List(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
records = append(records, rs...)
|
||||
}
|
||||
|
||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
||||
generator,
|
||||
func(_ context.Context, _ bool) (*databroker.Record, error) {
|
||||
if len(ready) == 0 {
|
||||
return nil, storage.ErrStreamDone
|
||||
}
|
||||
|
||||
record := ready[0]
|
||||
ready = ready[1:]
|
||||
return dup(record), nil
|
||||
},
|
||||
}, nil), nil
|
||||
return storage.RecordListToStream(ctx, records), nil
|
||||
}
|
||||
|
||||
func newSyncRecordStream(
|
||||
|
|
|
@ -51,18 +51,23 @@ func WithQuerier(ctx context.Context, querier Querier) context.Context {
|
|||
}
|
||||
|
||||
type staticQuerier struct {
|
||||
records []*databroker.Record
|
||||
records map[string]RecordCollection
|
||||
}
|
||||
|
||||
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
|
||||
func NewStaticQuerier(msgs ...proto.Message) Querier {
|
||||
getter := &staticQuerier{}
|
||||
getter := &staticQuerier{records: make(map[string]RecordCollection)}
|
||||
for _, msg := range msgs {
|
||||
record, ok := msg.(*databroker.Record)
|
||||
if !ok {
|
||||
record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg)
|
||||
}
|
||||
getter.records = append(getter.records, record)
|
||||
c, ok := getter.records[record.Type]
|
||||
if !ok {
|
||||
c = NewRecordCollection()
|
||||
getter.records[record.Type] = c
|
||||
}
|
||||
c.Put(record)
|
||||
}
|
||||
return getter
|
||||
}
|
||||
|
@ -107,42 +112,8 @@ func NewStaticRecord(typeURL string, msg proto.Message) *databroker.Record {
|
|||
func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
|
||||
|
||||
// Query queries for records.
|
||||
func (q *staticQuerier) Query(_ context.Context, in *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
expr, err := FilterExpressionFromStruct(in.GetFilter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filter, err := RecordStreamFilterFromFilterExpression(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
res := new(databroker.QueryResponse)
|
||||
for _, record := range q.records {
|
||||
if record.GetType() != in.GetType() {
|
||||
continue
|
||||
}
|
||||
|
||||
if !filter(record) {
|
||||
continue
|
||||
}
|
||||
|
||||
if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) {
|
||||
continue
|
||||
}
|
||||
|
||||
res.Records = append(res.Records, record)
|
||||
}
|
||||
|
||||
var total int
|
||||
res.Records, total = databroker.ApplyOffsetAndLimit(
|
||||
res.Records,
|
||||
int(in.GetOffset()),
|
||||
int(in.GetLimit()),
|
||||
)
|
||||
res.TotalCount = int64(total)
|
||||
return res, nil
|
||||
func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
|
||||
return QueryRecordCollections(q.records, req)
|
||||
}
|
||||
|
||||
type clientQuerier struct {
|
||||
|
|
353
pkg/storage/record_collection.go
Normal file
353
pkg/storage/record_collection.go
Normal file
|
@ -0,0 +1,353 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"fmt"
|
||||
"maps"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/gaissmai/bart"
|
||||
set "github.com/hashicorp/go-set/v3"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// A RecordCollection stores records. It supports id and ip addr indexing and ordering of
|
||||
// records in insertion order. It is not thread-safe.
|
||||
type RecordCollection interface {
|
||||
// All returns all of the databroker records as a slice. The slice is in insertion order.
|
||||
All() []*databroker.Record
|
||||
// Clear removes all the records from the collection.
|
||||
Clear()
|
||||
// Get returns a record based on the record id.
|
||||
Get(recordID string) (*databroker.Record, bool)
|
||||
// Len returns the number of records stored in the collection.
|
||||
Len() int
|
||||
// List returns all of the databroker records that match the given expression.
|
||||
List(filter FilterExpression) ([]*databroker.Record, error)
|
||||
// Newest returns the newest databroker record in the collection.
|
||||
Newest() (*databroker.Record, bool)
|
||||
// Oldest returns the oldest databroker record in the collection.
|
||||
Oldest() (*databroker.Record, bool)
|
||||
// Put puts a record into the collection. If the record's deleted at field is not nil, the record will
|
||||
// be removed from the collection.
|
||||
Put(record *databroker.Record)
|
||||
}
|
||||
|
||||
type recordCollectionNode struct {
|
||||
*databroker.Record
|
||||
insertionOrderPtr *list.Element
|
||||
}
|
||||
|
||||
type recordCollection struct {
|
||||
cidrIndex bart.Table[[]string]
|
||||
records map[string]recordCollectionNode
|
||||
insertionOrder *list.List
|
||||
}
|
||||
|
||||
// NewRecordCollection creates a new RecordCollection.
|
||||
func NewRecordCollection() RecordCollection {
|
||||
return &recordCollection{
|
||||
records: make(map[string]recordCollectionNode),
|
||||
insertionOrder: list.New(),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *recordCollection) All() []*databroker.Record {
|
||||
l := make([]*databroker.Record, 0, len(c.records))
|
||||
for e := c.insertionOrder.Front(); e != nil; e = e.Next() {
|
||||
r, ok := c.records[e.Value.(string)]
|
||||
if ok {
|
||||
l = append(l, dup(r.Record))
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (c *recordCollection) Clear() {
|
||||
c.cidrIndex = bart.Table[[]string]{}
|
||||
clear(c.records)
|
||||
c.insertionOrder = list.New()
|
||||
}
|
||||
|
||||
func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) {
|
||||
node, ok := c.records[recordID]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return dup(node.Record), true
|
||||
}
|
||||
|
||||
func (c *recordCollection) Len() int {
|
||||
return len(c.records)
|
||||
}
|
||||
|
||||
func (c *recordCollection) List(filter FilterExpression) ([]*databroker.Record, error) {
|
||||
if filter == nil {
|
||||
return c.All(), nil
|
||||
}
|
||||
|
||||
switch expr := filter.(type) {
|
||||
case AndFilterExpression:
|
||||
var rss [][]*databroker.Record
|
||||
for _, e := range expr {
|
||||
rs, err := c.List(e)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rss = append(rss, rs)
|
||||
}
|
||||
return intersection(rss), nil
|
||||
case OrFilterExpression:
|
||||
var rss [][]*databroker.Record
|
||||
for _, e := range expr {
|
||||
rs, err := c.List(e)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
rss = append(rss, rs)
|
||||
}
|
||||
return union(rss), nil
|
||||
case EqualsFilterExpression:
|
||||
switch strings.Join(expr.Fields, ".") {
|
||||
case "id":
|
||||
l := make([]*databroker.Record, 0, 1)
|
||||
if node, ok := c.records[expr.Value]; ok {
|
||||
l = append(l, node.Record)
|
||||
}
|
||||
return l, nil
|
||||
case "$index":
|
||||
l := make([]*databroker.Record, 0, 1)
|
||||
if prefix, err := netip.ParsePrefix(expr.Value); err == nil {
|
||||
l = append(l, c.lookupPrefix(prefix)...)
|
||||
} else if addr, err := netip.ParseAddr(expr.Value); err == nil {
|
||||
l = append(l, c.lookupAddr(addr)...)
|
||||
}
|
||||
return l, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown field: %s", strings.Join(expr.Fields, "."))
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown expression type: %T", expr)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *recordCollection) Put(record *databroker.Record) {
|
||||
record = dup(record)
|
||||
|
||||
// first delete the record
|
||||
c.delete(record.GetId())
|
||||
if record.DeletedAt != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// add it
|
||||
el := c.insertionOrder.PushBack(record.GetId())
|
||||
c.records[record.GetId()] = recordCollectionNode{
|
||||
Record: record,
|
||||
insertionOrderPtr: el,
|
||||
}
|
||||
if prefix := GetRecordIndexCIDR(record.GetData()); prefix != nil {
|
||||
c.addIndex(*prefix, record.GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func (c *recordCollection) Newest() (*databroker.Record, bool) {
|
||||
e := c.insertionOrder.Back()
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
node, ok := c.records[e.Value.(string)]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return node.Record, true
|
||||
}
|
||||
|
||||
func (c *recordCollection) Oldest() (*databroker.Record, bool) {
|
||||
e := c.insertionOrder.Front()
|
||||
if e == nil {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
node, ok := c.records[e.Value.(string)]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
return node.Record, true
|
||||
}
|
||||
|
||||
func (c *recordCollection) addIndex(prefix netip.Prefix, recordID string) {
|
||||
c.cidrIndex.Update(prefix, func(ids []string, _ bool) []string {
|
||||
// remove the id from the slice so it's not duplicated and gets moved to the end
|
||||
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
|
||||
return append(ids, recordID)
|
||||
})
|
||||
}
|
||||
|
||||
func (c *recordCollection) delete(recordID string) {
|
||||
node, ok := c.records[recordID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
// delete the record from the index if it's the current value stored there
|
||||
if prefix := GetRecordIndexCIDR(node.GetData()); prefix != nil {
|
||||
c.deleteIndex(*prefix, recordID)
|
||||
}
|
||||
|
||||
delete(c.records, recordID)
|
||||
c.insertionOrder.Remove(node.insertionOrderPtr)
|
||||
}
|
||||
|
||||
func (c *recordCollection) deleteIndex(prefix netip.Prefix, recordID string) {
|
||||
ids, ok := c.cidrIndex.Get(prefix)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
if !slices.Contains(ids, recordID) {
|
||||
return
|
||||
}
|
||||
|
||||
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
|
||||
|
||||
// last match, so delete the whole prefix
|
||||
if len(ids) == 0 {
|
||||
c.cidrIndex.Delete(prefix)
|
||||
return
|
||||
}
|
||||
|
||||
// update the prefix with the id removed
|
||||
c.cidrIndex.Update(prefix, func(_ []string, _ bool) []string {
|
||||
return ids
|
||||
})
|
||||
}
|
||||
|
||||
func (c *recordCollection) lookupPrefix(prefix netip.Prefix) []*databroker.Record {
|
||||
recordIDs, ok := c.cidrIndex.LookupPrefix(prefix)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := make([]*databroker.Record, 0, len(recordIDs))
|
||||
for _, recordID := range slices.Backward(recordIDs) {
|
||||
node, ok := c.records[recordID]
|
||||
if ok {
|
||||
l = append(l, dup(node.Record))
|
||||
break
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func (c *recordCollection) lookupAddr(addr netip.Addr) []*databroker.Record {
|
||||
recordIDs, ok := c.cidrIndex.Lookup(addr)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
l := make([]*databroker.Record, 0, len(recordIDs))
|
||||
for _, recordID := range slices.Backward(recordIDs) {
|
||||
node, ok := c.records[recordID]
|
||||
if ok {
|
||||
l = append(l, dup(node.Record))
|
||||
break
|
||||
}
|
||||
}
|
||||
return l
|
||||
}
|
||||
|
||||
func dup[T proto.Message](msg T) T {
|
||||
return proto.Clone(msg).(T)
|
||||
}
|
||||
|
||||
func intersection[T comparable](xs [][]T) []T {
|
||||
var final []T
|
||||
lookup := map[T]int{}
|
||||
for _, x := range xs {
|
||||
for _, e := range x {
|
||||
lookup[e]++
|
||||
}
|
||||
}
|
||||
seen := set.New[T](0)
|
||||
for _, x := range xs {
|
||||
for _, e := range x {
|
||||
if lookup[e] == len(xs) {
|
||||
if !seen.Contains(e) {
|
||||
final = append(final, e)
|
||||
seen.Insert(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return final
|
||||
}
|
||||
|
||||
func union[T comparable](xs [][]T) []T {
|
||||
var final []T
|
||||
seen := set.New[T](0)
|
||||
for _, x := range xs {
|
||||
for _, e := range x {
|
||||
if !seen.Contains(e) {
|
||||
final = append(final, e)
|
||||
seen.Insert(e)
|
||||
}
|
||||
}
|
||||
}
|
||||
return final
|
||||
}
|
||||
|
||||
// QueryRecordCollections queries a map of record collections.
|
||||
func QueryRecordCollections(
|
||||
recordCollections map[string]RecordCollection,
|
||||
req *databroker.QueryRequest,
|
||||
) (*databroker.QueryResponse, error) {
|
||||
filter, err := FilterExpressionFromStruct(req.GetFilter())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var cs []RecordCollection
|
||||
if req.Type == "" {
|
||||
for _, recordType := range slices.Sorted(maps.Keys(recordCollections)) {
|
||||
cs = append(cs, recordCollections[recordType])
|
||||
}
|
||||
} else {
|
||||
c, ok := recordCollections[req.Type]
|
||||
if ok {
|
||||
cs = append(cs, c)
|
||||
}
|
||||
}
|
||||
|
||||
res := new(databroker.QueryResponse)
|
||||
for _, c := range cs {
|
||||
records, err := c.List(filter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, record := range records {
|
||||
if req.GetQuery() != "" && !MatchAny(record.GetData(), req.GetQuery()) {
|
||||
continue
|
||||
}
|
||||
|
||||
res.Records = append(res.Records, record)
|
||||
}
|
||||
}
|
||||
|
||||
var total int
|
||||
res.Records, total = databroker.ApplyOffsetAndLimit(
|
||||
res.Records,
|
||||
int(req.GetOffset()),
|
||||
int(req.GetLimit()),
|
||||
)
|
||||
res.TotalCount = int64(total)
|
||||
return res, nil
|
||||
}
|
131
pkg/storage/record_collection_test.go
Normal file
131
pkg/storage/record_collection_test.go
Normal file
|
@ -0,0 +1,131 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/testing/protocmp"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestRecordCollection(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
r1 := &databroker.Record{
|
||||
Id: "r1",
|
||||
Data: newStructAny(t, map[string]any{
|
||||
"$index": map[string]any{
|
||||
"cidr": "10.0.0.0/24",
|
||||
},
|
||||
}),
|
||||
}
|
||||
r2 := &databroker.Record{
|
||||
Id: "r2",
|
||||
Data: newStructAny(t, map[string]any{
|
||||
"$index": map[string]any{
|
||||
"cidr": "192.168.0.0/24",
|
||||
},
|
||||
}),
|
||||
}
|
||||
r3 := &databroker.Record{
|
||||
Id: "r3",
|
||||
Data: newStructAny(t, map[string]any{
|
||||
"$index": map[string]any{
|
||||
"cidr": "10.0.0.0/16",
|
||||
},
|
||||
}),
|
||||
}
|
||||
r4 := &databroker.Record{
|
||||
Id: "r4",
|
||||
Data: newStructAny(t, map[string]any{
|
||||
"$index": map[string]any{
|
||||
"cidr": "10.0.0.0/24",
|
||||
},
|
||||
}),
|
||||
}
|
||||
|
||||
c := storage.NewRecordCollection()
|
||||
c.Put(r4)
|
||||
c.Put(r3)
|
||||
c.Put(r2)
|
||||
c.Put(r1)
|
||||
|
||||
assert.Equal(t, 4, c.Len())
|
||||
|
||||
r, ok := c.Get("r1")
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
|
||||
"should return r1")
|
||||
r, ok = c.Get("r2")
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r2, r, protocmp.Transform()),
|
||||
"should return r2")
|
||||
r, ok = c.Get("r3")
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r3, r, protocmp.Transform()),
|
||||
"should return r3")
|
||||
r, ok = c.Get("r4")
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
|
||||
"should return r4")
|
||||
|
||||
r, ok = c.Oldest()
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
|
||||
"should return the first added record")
|
||||
|
||||
r, ok = c.Newest()
|
||||
assert.True(t, ok)
|
||||
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
|
||||
"should return the last added record")
|
||||
|
||||
rs := c.All()
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
|
||||
"should return all records")
|
||||
|
||||
rs, err := c.List(nil)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
|
||||
"should return all records for a nil filter")
|
||||
|
||||
rs, err = c.List(storage.OrFilterExpression{
|
||||
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r3"},
|
||||
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r1"},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r3, r1}, rs, protocmp.Transform()),
|
||||
"should return two records for or")
|
||||
|
||||
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r1}, rs, protocmp.Transform()))
|
||||
|
||||
r1.DeletedAt = timestamppb.Now()
|
||||
c.Put(r1)
|
||||
|
||||
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r4}, rs, protocmp.Transform()))
|
||||
|
||||
r4.DeletedAt = timestamppb.Now()
|
||||
c.Put(r4)
|
||||
|
||||
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform()))
|
||||
}
|
||||
|
||||
func newStructAny(t *testing.T, m map[string]any) *anypb.Any {
|
||||
t.Helper()
|
||||
s, err := structpb.NewStruct(m)
|
||||
require.NoError(t, err)
|
||||
return protoutil.NewAny(s)
|
||||
}
|
Loading…
Add table
Reference in a new issue