storage: support ip address indexing for the in-memory store (#5568)

Caleb Doxsey 2025-04-10 08:21:52 -06:00 committed by GitHub
parent c7ffb95483
commit cd731789be
8 changed files with 534 additions and 167 deletions
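In short: a record whose data carries a "$index" object with a "cidr" value is now indexed in a routing table, and an equality filter on "$index" matches a bare IP address (or a CIDR) against the covering prefix. A minimal sketch of the new behavior, assuming the record shape used by the tests at the bottom of this diff:

package main

import (
	"fmt"

	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/protoutil"
	"github.com/pomerium/pomerium/pkg/storage"
)

func main() {
	c := storage.NewRecordCollection()

	// A record is indexed when its data contains $index.cidr.
	data, _ := structpb.NewStruct(map[string]any{
		"$index": map[string]any{"cidr": "10.0.0.0/24"},
	})
	c.Put(&databroker.Record{Id: "r1", Data: protoutil.NewAny(data)})

	// An equality filter on "$index" accepts a bare IP (or a CIDR) and
	// returns the record whose indexed prefix covers it.
	records, err := c.List(storage.EqualsFilterExpression{
		Fields: []string{"$index"},
		Value:  "10.0.0.3",
	})
	fmt.Println(records, err) // one record (r1), <nil>
}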

go.mod (3 lines changed)
View file

@@ -1,6 +1,6 @@
module github.com/pomerium/pomerium
-go 1.23.6
+go 1.23.8
require (
cloud.google.com/go/storage v1.51.0
@@ -20,6 +20,7 @@ require (
github.com/envoyproxy/protoc-gen-validate v1.2.1
github.com/exaring/otelpgx v0.9.0
github.com/fsnotify/fsnotify v1.8.0
+github.com/gaissmai/bart v0.20.3
github.com/go-chi/chi/v5 v5.2.1
github.com/go-jose/go-jose/v3 v3.0.4
github.com/go-viper/mapstructure/v2 v2.2.1
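The new dependency, github.com/gaissmai/bart, provides the generic routing table (bart.Table) that backs the CIDR index introduced below; its longest-prefix-match lookup is what resolves an IP address to the best covering CIDR. A standalone sketch of the calls the collection relies on:

package main

import (
	"fmt"
	"net/netip"

	"github.com/gaissmai/bart"
)

func main() {
	var t bart.Table[[]string]

	// Update inserts or mutates the value stored at a prefix.
	t.Update(netip.MustParsePrefix("10.0.0.0/16"), func(ids []string, _ bool) []string {
		return append(ids, "r3")
	})
	t.Update(netip.MustParsePrefix("10.0.0.0/24"), func(ids []string, _ bool) []string {
		return append(ids, "r1")
	})

	// Lookup is longest-prefix-match: the /24 wins over the /16.
	ids, ok := t.Lookup(netip.MustParseAddr("10.0.0.3"))
	fmt.Println(ids, ok) // [r1] true
}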

go.sum (2 lines changed)
View file

@@ -222,6 +222,8 @@ github.com/fsnotify/fsnotify v1.8.0 h1:dAwr6QBTBZIkG8roQaJjGof0pp0EeF+tNV7YBP3F/
github.com/fsnotify/fsnotify v1.8.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0=
github.com/fxamacker/cbor/v2 v2.6.0 h1:sU6J2usfADwWlYDAFhZBQ6TnLFBHxgesMrQfQgk1tWA=
github.com/fxamacker/cbor/v2 v2.6.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ=
+github.com/gaissmai/bart v0.20.3 h1:hZxPDasx5f2rmNsKwvNu5JMV0+SUs/uovkUNKN/807U=
+github.com/gaissmai/bart v0.20.3/go.mod h1:HRCXF6EPBV4dcRPUTZtjVx384e3RYVHJ5H22ApAqltA=
github.com/go-chi/chi/v5 v5.2.1 h1:KOIHODQj58PmL80G2Eak4WdvUzjSJSm0vG72crDCqb8=
github.com/go-chi/chi/v5 v5.2.1/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops=
github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU=

View file

@@ -53,7 +53,7 @@ type Backend struct {
closed chan struct{}
mu sync.RWMutex
-lookup map[string]*RecordCollection
+lookup map[string]storage.RecordCollection
capacity map[string]*uint64
changes *btree.BTree
leases map[string]*lease
@@ -67,7 +67,7 @@ func New(options ...Option) *Backend {
onChange: signal.New(),
serverVersion: cryptutil.NewRandomUInt64(),
closed: make(chan struct{}),
-lookup: make(map[string]*RecordCollection),
+lookup: make(map[string]storage.RecordCollection),
capacity: map[string]*uint64{},
changes: btree.New(cfg.degree),
leases: make(map[string]*lease),
@@ -124,7 +124,7 @@ func (backend *Backend) Close() error {
backend.mu.Lock()
defer backend.mu.Unlock()
-backend.lookup = map[string]*RecordCollection{}
+backend.lookup = map[string]storage.RecordCollection{}
backend.capacity = map[string]*uint64{}
backend.changes = btree.New(backend.cfg.degree)
})
@@ -148,8 +148,8 @@ func (backend *Backend) get(recordType, id string) *databroker.Record {
return nil
}
-record := records.Get(id)
-if record == nil {
+record, ok := records.Get(id)
+if !ok {
return nil
}
@@ -244,15 +244,11 @@ func (backend *Backend) update(record *databroker.Record) {
c, ok := backend.lookup[record.GetType()]
if !ok {
-c = NewRecordCollection()
+c = storage.NewRecordCollection()
backend.lookup[record.GetType()] = c
}
-if record.GetDeletedAt() != nil {
-c.Delete(record.GetId())
-} else {
-c.Put(dup(record))
-}
+c.Put(record)
}
// Patch updates the specified fields of existing record(s).
@@ -360,20 +356,14 @@ func (backend *Backend) enforceCapacity(recordType string) {
}
capacity := *ptr
if collection.Len() <= int(capacity) {
return
}
-records := collection.List()
-for len(records) > int(capacity) {
-// delete the record
-record := dup(records[0])
-record.DeletedAt = timestamppb.Now()
-backend.recordChange(record)
-collection.Delete(record.GetId())
-// move forward
-records = records[1:]
+for collection.Len() > int(capacity) {
+r, ok := collection.Oldest()
+if !ok {
+break
+}
+r.DeletedAt = timestamppb.Now()
+backend.recordChange(r)
+collection.Put(r)
}
}
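Eviction is simpler because deletion is folded into Put: the backend takes the oldest live record, stamps DeletedAt, and Puts it back, which removes it from both the id map and the CIDR index. A small sketch of that contract:

package main

import (
	"fmt"

	"google.golang.org/protobuf/types/known/timestamppb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/storage"
)

func main() {
	c := storage.NewRecordCollection()
	c.Put(&databroker.Record{Id: "old"})
	c.Put(&databroker.Record{Id: "new"})

	// Evict the oldest record the way enforceCapacity now does.
	r, _ := c.Oldest()
	r.DeletedAt = timestamppb.Now()
	c.Put(r) // a Put with DeletedAt set deletes the record

	fmt.Println(c.Len()) // 1
}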

View file

@@ -1,71 +0,0 @@
package inmemory
import (
"container/list"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type recordCollectionNode struct {
*databroker.Record
insertionOrderPtr *list.Element
}
// A RecordCollection is a collection of records which supports lookup by (record id) as well as enforcing capacity
// by insertion order. The collection is *not* thread safe.
type RecordCollection struct {
records map[string]recordCollectionNode
insertionOrder *list.List
}
// NewRecordCollection creates a new RecordCollection.
func NewRecordCollection() *RecordCollection {
return &RecordCollection{
records: map[string]recordCollectionNode{},
insertionOrder: list.New(),
}
}
// Delete deletes a record from the collection.
func (c *RecordCollection) Delete(recordID string) {
node, ok := c.records[recordID]
if !ok {
return
}
delete(c.records, recordID)
c.insertionOrder.Remove(node.insertionOrderPtr)
}
// Get gets a record from the collection.
func (c *RecordCollection) Get(recordID string) *databroker.Record {
node, ok := c.records[recordID]
if !ok {
return nil
}
return node.Record
}
// Len returns the length of the collection.
func (c *RecordCollection) Len() int {
return len(c.records)
}
// List lists all the records in the collection in insertion order.
func (c *RecordCollection) List() []*databroker.Record {
var all []*databroker.Record
for el := c.insertionOrder.Front(); el != nil; el = el.Next() {
all = append(all, c.records[el.Value.(string)].Record)
}
return all
}
// Put puts a record in the collection.
func (c *RecordCollection) Put(record *databroker.Record) {
c.Delete(record.GetId())
el := c.insertionOrder.PushBack(record.GetId())
c.records[record.GetId()] = recordCollectionNode{
Record: record,
insertionOrderPtr: el,
}
}

View file

@@ -2,6 +2,8 @@ package inmemory
import (
"context"
"maps"
"slices"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
@@ -13,42 +15,30 @@ func newSyncLatestRecordStream(
recordType string,
expr storage.FilterExpression,
) (storage.RecordStream, error) {
-filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
-if err != nil {
-return nil, err
-}
-if recordType != "" {
-filter = filter.And(func(record *databroker.Record) (keep bool) {
-return record.GetType() == recordType
-})
+backend.mu.RLock()
+defer backend.mu.RUnlock()
+var recordTypes []string
+if recordType == "" {
+recordTypes = slices.Sorted(maps.Keys(backend.lookup))
+} else {
+recordTypes = []string{recordType}
}
-var ready []*databroker.Record
-generator := func(_ context.Context, _ bool) (*databroker.Record, error) {
-backend.mu.RLock()
-for _, co := range backend.lookup {
-for _, record := range co.List() {
-if filter(record) {
-ready = append(ready, record)
-}
-}
+var records []*databroker.Record
+for _, recordType := range recordTypes {
+co, ok := backend.lookup[recordType]
+if !ok {
+continue
}
-backend.mu.RUnlock()
-return nil, storage.ErrStreamDone
+rs, err := co.List(expr)
+if err != nil {
+return nil, err
+}
+records = append(records, rs...)
}
-return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
-generator,
-func(_ context.Context, _ bool) (*databroker.Record, error) {
-if len(ready) == 0 {
-return nil, storage.ErrStreamDone
-}
-record := ready[0]
-ready = ready[1:]
-return dup(record), nil
-},
-}, nil), nil
+return storage.RecordListToStream(ctx, records), nil
}
func newSyncRecordStream(
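The generator plumbing is gone: matching records are now collected eagerly under the read lock, in sorted record-type order, and handed to storage.RecordListToStream. Consumption is unchanged; a sketch, assuming the package's usual iterator-style RecordStream interface (Next/Record/Err/Close):

package main

import (
	"context"
	"fmt"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/storage"
)

func main() {
	records := []*databroker.Record{{Id: "r1"}, {Id: "r2"}}

	stream := storage.RecordListToStream(context.Background(), records)
	defer stream.Close()
	for stream.Next(false) {
		fmt.Println(stream.Record().GetId())
	}
	if err := stream.Err(); err != nil {
		fmt.Println("stream error:", err)
	}
}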

View file

@@ -51,18 +51,23 @@ func WithQuerier(ctx context.Context, querier Querier) context.Context {
}
type staticQuerier struct {
-records []*databroker.Record
+records map[string]RecordCollection
}
// NewStaticQuerier creates a Querier that returns statically defined protobuf records.
func NewStaticQuerier(msgs ...proto.Message) Querier {
-getter := &staticQuerier{}
+getter := &staticQuerier{records: make(map[string]RecordCollection)}
for _, msg := range msgs {
record, ok := msg.(*databroker.Record)
if !ok {
record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg)
}
-getter.records = append(getter.records, record)
+c, ok := getter.records[record.Type]
+if !ok {
+c = NewRecordCollection()
+getter.records[record.Type] = c
+}
+c.Put(record)
}
return getter
}
@@ -107,42 +112,8 @@ func NewStaticRecord(typeURL string, msg proto.Message) *databroker.Record {
func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {}
// Query queries for records.
-func (q *staticQuerier) Query(_ context.Context, in *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
-expr, err := FilterExpressionFromStruct(in.GetFilter())
-if err != nil {
-return nil, err
-}
-filter, err := RecordStreamFilterFromFilterExpression(expr)
-if err != nil {
-return nil, err
-}
-res := new(databroker.QueryResponse)
-for _, record := range q.records {
-if record.GetType() != in.GetType() {
-continue
-}
-if !filter(record) {
-continue
-}
-if in.GetQuery() != "" && !MatchAny(record.GetData(), in.GetQuery()) {
-continue
-}
-res.Records = append(res.Records, record)
-}
-var total int
-res.Records, total = databroker.ApplyOffsetAndLimit(
-res.Records,
-int(in.GetOffset()),
-int(in.GetLimit()),
-)
-res.TotalCount = int64(total)
-return res, nil
+func (q *staticQuerier) Query(_ context.Context, req *databroker.QueryRequest, _ ...grpc.CallOption) (*databroker.QueryResponse, error) {
+return QueryRecordCollections(q.records, req)
}
type clientQuerier struct {
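Since the static querier now buckets records into per-type RecordCollections, statically defined records answer the same filter expressions as the in-memory store, IP index included. A hypothetical use with a session record (the type URL is illustrative):

package main

import (
	"context"
	"fmt"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpc/session"
	"github.com/pomerium/pomerium/pkg/storage"
)

func main() {
	q := storage.NewStaticQuerier(&session.Session{Id: "session-1"})
	res, err := q.Query(context.Background(), &databroker.QueryRequest{
		Type:  "type.googleapis.com/session.Session",
		Limit: 1,
	})
	fmt.Println(res.GetRecords(), err)
}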

View file

@@ -0,0 +1,353 @@
package storage
import (
"container/list"
"fmt"
"maps"
"net/netip"
"slices"
"strings"
"github.com/gaissmai/bart"
set "github.com/hashicorp/go-set/v3"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// A RecordCollection stores records. It supports lookup by record id and by IP address, and it
// maintains records in insertion order. It is not thread-safe.
type RecordCollection interface {
// All returns all of the databroker records as a slice. The slice is in insertion order.
All() []*databroker.Record
// Clear removes all the records from the collection.
Clear()
// Get returns a record based on the record id.
Get(recordID string) (*databroker.Record, bool)
// Len returns the number of records stored in the collection.
Len() int
// List returns all of the databroker records that match the given expression.
List(filter FilterExpression) ([]*databroker.Record, error)
// Newest returns the newest databroker record in the collection.
Newest() (*databroker.Record, bool)
// Oldest returns the oldest databroker record in the collection.
Oldest() (*databroker.Record, bool)
// Put puts a record into the collection. If the record's deleted at field is not nil, the record will
// be removed from the collection.
Put(record *databroker.Record)
}
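List accepts nil (return everything), a single equality expression, or And/Or combinations, which the implementation resolves with the set-style intersection and union helpers further down. A sketch of a composite filter over some collection c (the ids are illustrative):

// Records with id r1 or r3 that also sit under an indexed CIDR
// covering 10.0.0.0/16.
rs, err := c.List(storage.AndFilterExpression{
	storage.OrFilterExpression{
		storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r1"},
		storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r3"},
	},
	storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.0/16"},
})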
type recordCollectionNode struct {
*databroker.Record
insertionOrderPtr *list.Element
}
type recordCollection struct {
cidrIndex bart.Table[[]string]
records map[string]recordCollectionNode
insertionOrder *list.List
}
// NewRecordCollection creates a new RecordCollection.
func NewRecordCollection() RecordCollection {
return &recordCollection{
records: make(map[string]recordCollectionNode),
insertionOrder: list.New(),
}
}
func (c *recordCollection) All() []*databroker.Record {
l := make([]*databroker.Record, 0, len(c.records))
for e := c.insertionOrder.Front(); e != nil; e = e.Next() {
r, ok := c.records[e.Value.(string)]
if ok {
l = append(l, dup(r.Record))
}
}
return l
}
func (c *recordCollection) Clear() {
c.cidrIndex = bart.Table[[]string]{}
clear(c.records)
c.insertionOrder = list.New()
}
func (c *recordCollection) Get(recordID string) (*databroker.Record, bool) {
node, ok := c.records[recordID]
if !ok {
return nil, false
}
return dup(node.Record), true
}
func (c *recordCollection) Len() int {
return len(c.records)
}
func (c *recordCollection) List(filter FilterExpression) ([]*databroker.Record, error) {
if filter == nil {
return c.All(), nil
}
switch expr := filter.(type) {
case AndFilterExpression:
var rss [][]*databroker.Record
for _, e := range expr {
rs, err := c.List(e)
if err != nil {
return nil, err
}
rss = append(rss, rs)
}
return intersection(rss), nil
case OrFilterExpression:
var rss [][]*databroker.Record
for _, e := range expr {
rs, err := c.List(e)
if err != nil {
return nil, err
}
rss = append(rss, rs)
}
return union(rss), nil
case EqualsFilterExpression:
switch strings.Join(expr.Fields, ".") {
case "id":
l := make([]*databroker.Record, 0, 1)
if node, ok := c.records[expr.Value]; ok {
l = append(l, node.Record)
}
return l, nil
case "$index":
l := make([]*databroker.Record, 0, 1)
if prefix, err := netip.ParsePrefix(expr.Value); err == nil {
l = append(l, c.lookupPrefix(prefix)...)
} else if addr, err := netip.ParseAddr(expr.Value); err == nil {
l = append(l, c.lookupAddr(addr)...)
}
return l, nil
default:
return nil, fmt.Errorf("unknown field: %s", strings.Join(expr.Fields, "."))
}
default:
return nil, fmt.Errorf("unknown expression type: %T", expr)
}
}
func (c *recordCollection) Put(record *databroker.Record) {
record = dup(record)
// first delete the record
c.delete(record.GetId())
if record.DeletedAt != nil {
return
}
// add it
el := c.insertionOrder.PushBack(record.GetId())
c.records[record.GetId()] = recordCollectionNode{
Record: record,
insertionOrderPtr: el,
}
if prefix := GetRecordIndexCIDR(record.GetData()); prefix != nil {
c.addIndex(*prefix, record.GetId())
}
}
func (c *recordCollection) Newest() (*databroker.Record, bool) {
e := c.insertionOrder.Back()
if e == nil {
return nil, false
}
node, ok := c.records[e.Value.(string)]
if !ok {
return nil, false
}
return node.Record, true
}
func (c *recordCollection) Oldest() (*databroker.Record, bool) {
e := c.insertionOrder.Front()
if e == nil {
return nil, false
}
node, ok := c.records[e.Value.(string)]
if !ok {
return nil, false
}
return node.Record, true
}
func (c *recordCollection) addIndex(prefix netip.Prefix, recordID string) {
c.cidrIndex.Update(prefix, func(ids []string, _ bool) []string {
// remove the id from the slice so it's not duplicated and gets moved to the end
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
return append(ids, recordID)
})
}
func (c *recordCollection) delete(recordID string) {
node, ok := c.records[recordID]
if !ok {
return
}
// delete the record from the index if it's the current value stored there
if prefix := GetRecordIndexCIDR(node.GetData()); prefix != nil {
c.deleteIndex(*prefix, recordID)
}
delete(c.records, recordID)
c.insertionOrder.Remove(node.insertionOrderPtr)
}
func (c *recordCollection) deleteIndex(prefix netip.Prefix, recordID string) {
ids, ok := c.cidrIndex.Get(prefix)
if !ok {
return
}
if !slices.Contains(ids, recordID) {
return
}
ids = slices.DeleteFunc(ids, func(id string) bool { return id == recordID })
// last match, so delete the whole prefix
if len(ids) == 0 {
c.cidrIndex.Delete(prefix)
return
}
// update the prefix with the id removed
c.cidrIndex.Update(prefix, func(_ []string, _ bool) []string {
return ids
})
}
func (c *recordCollection) lookupPrefix(prefix netip.Prefix) []*databroker.Record {
recordIDs, ok := c.cidrIndex.LookupPrefix(prefix)
if !ok {
return nil
}
l := make([]*databroker.Record, 0, len(recordIDs))
for _, recordID := range slices.Backward(recordIDs) {
node, ok := c.records[recordID]
if ok {
l = append(l, dup(node.Record))
break
}
}
return l
}
func (c *recordCollection) lookupAddr(addr netip.Addr) []*databroker.Record {
recordIDs, ok := c.cidrIndex.Lookup(addr)
if !ok {
return nil
}
l := make([]*databroker.Record, 0, len(recordIDs))
for _, recordID := range slices.Backward(recordIDs) {
node, ok := c.records[recordID]
if ok {
l = append(l, dup(node.Record))
break
}
}
return l
}
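Note that both lookups return at most one record: the id list under the best-matching prefix is scanned newest-first and the scan stops at the first live record. Observably, mirroring the test at the end of this diff:

// r1 and r4 are both indexed at 10.0.0.0/24 (r1 put last); r3 is at 10.0.0.0/16.
rs, _ := c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
// rs == [r1]: the newest record on the longest matching prefix.

r1.DeletedAt = timestamppb.Now()
c.Put(r1)
rs, _ = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
// rs == [r4]: the next-newest on that prefix; deleting r4 as well would
// fall back to r3 on the broader /16.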
func dup[T proto.Message](msg T) T {
return proto.Clone(msg).(T)
}
func intersection[T comparable](xs [][]T) []T {
var final []T
lookup := map[T]int{}
for _, x := range xs {
for _, e := range x {
lookup[e]++
}
}
seen := set.New[T](0)
for _, x := range xs {
for _, e := range x {
if lookup[e] == len(xs) {
if !seen.Contains(e) {
final = append(final, e)
seen.Insert(e)
}
}
}
}
return final
}
func union[T comparable](xs [][]T) []T {
var final []T
seen := set.New[T](0)
for _, x := range xs {
for _, e := range x {
if !seen.Contains(e) {
final = append(final, e)
seen.Insert(e)
}
}
}
return final
}
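Both helpers deduplicate and preserve first-seen order. intersection keeps an element when its total count across the inputs equals the number of inputs, which is equivalent to membership in every slice as long as each input is duplicate-free (the record slices used here are). For example:

intersection([][]string{{"a", "b", "c"}, {"b", "c", "d"}}) // [b c]
union([][]string{{"a", "b"}, {"b", "c"}})                  // [a b c]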
// QueryRecordCollections queries a map of record collections.
func QueryRecordCollections(
recordCollections map[string]RecordCollection,
req *databroker.QueryRequest,
) (*databroker.QueryResponse, error) {
filter, err := FilterExpressionFromStruct(req.GetFilter())
if err != nil {
return nil, err
}
var cs []RecordCollection
if req.Type == "" {
for _, recordType := range slices.Sorted(maps.Keys(recordCollections)) {
cs = append(cs, recordCollections[recordType])
}
} else {
c, ok := recordCollections[req.Type]
if ok {
cs = append(cs, c)
}
}
res := new(databroker.QueryResponse)
for _, c := range cs {
records, err := c.List(filter)
if err != nil {
return nil, err
}
for _, record := range records {
if req.GetQuery() != "" && !MatchAny(record.GetData(), req.GetQuery()) {
continue
}
res.Records = append(res.Records, record)
}
}
var total int
res.Records, total = databroker.ApplyOffsetAndLimit(
res.Records,
int(req.GetOffset()),
int(req.GetLimit()),
)
res.TotalCount = int64(total)
return res, nil
}
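QueryRecordCollections gives the static querier (and anything else holding a map of collections) a shared query path: resolve the filter against each collection, apply the free-text query, then offset and limit. A sketch of a call; the filter struct is assumed to parse to an "$index" equality, and the type name is illustrative:

filter, _ := structpb.NewStruct(map[string]any{
	"$index": "10.0.0.3",
})
res, err := storage.QueryRecordCollections(collections, &databroker.QueryRequest{
	Type:   "example.com/record",
	Filter: filter,
	Limit:  10,
})
// res.Records holds the page; res.TotalCount is the match count before
// offset and limit were applied.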

View file

@@ -0,0 +1,131 @@
package storage_test
import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/testing/protocmp"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestRecordCollection(t *testing.T) {
t.Parallel()
r1 := &databroker.Record{
Id: "r1",
Data: newStructAny(t, map[string]any{
"$index": map[string]any{
"cidr": "10.0.0.0/24",
},
}),
}
r2 := &databroker.Record{
Id: "r2",
Data: newStructAny(t, map[string]any{
"$index": map[string]any{
"cidr": "192.168.0.0/24",
},
}),
}
r3 := &databroker.Record{
Id: "r3",
Data: newStructAny(t, map[string]any{
"$index": map[string]any{
"cidr": "10.0.0.0/16",
},
}),
}
r4 := &databroker.Record{
Id: "r4",
Data: newStructAny(t, map[string]any{
"$index": map[string]any{
"cidr": "10.0.0.0/24",
},
}),
}
c := storage.NewRecordCollection()
c.Put(r4)
c.Put(r3)
c.Put(r2)
c.Put(r1)
assert.Equal(t, 4, c.Len())
r, ok := c.Get("r1")
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
"should return r1")
r, ok = c.Get("r2")
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r2, r, protocmp.Transform()),
"should return r2")
r, ok = c.Get("r3")
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r3, r, protocmp.Transform()),
"should return r3")
r, ok = c.Get("r4")
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
"should return r4")
r, ok = c.Oldest()
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r4, r, protocmp.Transform()),
"should return the first added record")
r, ok = c.Newest()
assert.True(t, ok)
assert.Empty(t, cmp.Diff(r1, r, protocmp.Transform()),
"should return the last added record")
rs := c.All()
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
"should return all records")
rs, err := c.List(nil)
assert.NoError(t, err)
assert.Empty(t, cmp.Diff([]*databroker.Record{r4, r3, r2, r1}, rs, protocmp.Transform()),
"should return all records for a nil filter")
rs, err = c.List(storage.OrFilterExpression{
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r3"},
storage.EqualsFilterExpression{Fields: []string{"id"}, Value: "r1"},
})
assert.NoError(t, err)
assert.Empty(t, cmp.Diff([]*databroker.Record{r3, r1}, rs, protocmp.Transform()),
"should return two records for or")
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
assert.NoError(t, err)
assert.Empty(t, cmp.Diff([]*databroker.Record{r1}, rs, protocmp.Transform()))
r1.DeletedAt = timestamppb.Now()
c.Put(r1)
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
assert.NoError(t, err)
assert.Empty(t, cmp.Diff([]*databroker.Record{r4}, rs, protocmp.Transform()))
r4.DeletedAt = timestamppb.Now()
c.Put(r4)
rs, err = c.List(storage.EqualsFilterExpression{Fields: []string{"$index"}, Value: "10.0.0.3"})
assert.NoError(t, err)
assert.Empty(t, cmp.Diff([]*databroker.Record{r3}, rs, protocmp.Transform()))
}
func newStructAny(t *testing.T, m map[string]any) *anypb.Any {
t.Helper()
s, err := structpb.NewStruct(m)
require.NoError(t, err)
return protoutil.NewAny(s)
}