mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-02 00:10:45 +02:00
storage: add filter expressions, upgrade go to 1.18.1 (#3365)
* storage: add filter expressions * upgrade go
This commit is contained in:
parent
51e716ef54
commit
70f5d8b173
12 changed files with 517 additions and 63 deletions
143
pkg/storage/filter.go
Normal file
143
pkg/storage/filter.go
Normal file
|
@ -0,0 +1,143 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
// A FilterExpression describes an AST for record stream filters.
|
||||
type FilterExpression interface {
|
||||
isFilterExpression()
|
||||
}
|
||||
|
||||
// FilterExpressionFromStruct creates a FilterExpression from a protobuf struct.
|
||||
func FilterExpressionFromStruct(s *structpb.Struct) (FilterExpression, error) {
|
||||
if s == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return filterExpressionFromStruct(nil, s)
|
||||
}
|
||||
|
||||
func filterExpressionFromValue(path []string, v *structpb.Value) (FilterExpression, error) {
|
||||
switch vv := v.GetKind().(type) {
|
||||
case *structpb.Value_ListValue:
|
||||
var or OrFilterExpression
|
||||
for _, vvv := range vv.ListValue.Values {
|
||||
e, err := filterExpressionFromValue(path, vvv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
or = append(or, e)
|
||||
}
|
||||
return or, nil
|
||||
case *structpb.Value_StructValue:
|
||||
return filterExpressionFromStruct(path, vv.StructValue)
|
||||
}
|
||||
return filterExpressionFromEq(path, v)
|
||||
}
|
||||
|
||||
func filterExpressionFromStruct(path []string, s *structpb.Struct) (FilterExpression, error) {
|
||||
var and AndFilterExpression
|
||||
var fs []string
|
||||
for f := range s.GetFields() {
|
||||
fs = append(fs, f)
|
||||
}
|
||||
sort.Strings(fs)
|
||||
|
||||
for _, f := range fs {
|
||||
v := s.GetFields()[f]
|
||||
switch f {
|
||||
case "$and":
|
||||
expr, err := filterExpressionFromValue(path, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
or, ok := expr.(OrFilterExpression)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("$and must be an array")
|
||||
}
|
||||
if len(or) == 1 {
|
||||
and = append(and, or[0])
|
||||
} else {
|
||||
and = append(and, AndFilterExpression(or))
|
||||
}
|
||||
case "$or":
|
||||
expr, err := filterExpressionFromValue(path, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
or, ok := expr.(OrFilterExpression)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("$or must be an array")
|
||||
}
|
||||
if len(or) == 1 {
|
||||
and = append(and, or[0])
|
||||
} else {
|
||||
and = append(and, or)
|
||||
}
|
||||
case "$eq":
|
||||
expr, err := filterExpressionFromEq(path, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
and = append(and, expr)
|
||||
default:
|
||||
expr, err := filterExpressionFromValue(append(path, f), v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
and = append(and, expr)
|
||||
}
|
||||
}
|
||||
|
||||
if len(and) == 1 {
|
||||
return and[0], nil
|
||||
}
|
||||
return and, nil
|
||||
}
|
||||
|
||||
func filterExpressionFromEq(path []string, v *structpb.Value) (FilterExpression, error) {
|
||||
switch vv := v.GetKind().(type) {
|
||||
case *structpb.Value_BoolValue:
|
||||
return EqualsFilterExpression{
|
||||
Fields: path,
|
||||
Value: fmt.Sprintf("%v", vv.BoolValue),
|
||||
}, nil
|
||||
case *structpb.Value_NullValue:
|
||||
return EqualsFilterExpression{
|
||||
Fields: path,
|
||||
Value: fmt.Sprintf("%v", vv.NullValue),
|
||||
}, nil
|
||||
case *structpb.Value_NumberValue:
|
||||
return EqualsFilterExpression{
|
||||
Fields: path,
|
||||
Value: fmt.Sprintf("%v", vv.NumberValue),
|
||||
}, nil
|
||||
case *structpb.Value_StringValue:
|
||||
return EqualsFilterExpression{
|
||||
Fields: path,
|
||||
Value: vv.StringValue,
|
||||
}, nil
|
||||
}
|
||||
return nil, fmt.Errorf("unsupported struct value type for eq: %T", v.GetKind())
|
||||
}
|
||||
|
||||
// An OrFilterExpression represents a logical-or comparison operator.
|
||||
type OrFilterExpression []FilterExpression
|
||||
|
||||
func (OrFilterExpression) isFilterExpression() {}
|
||||
|
||||
// An AndFilterExpression represents a logical-and comparison operator.
|
||||
type AndFilterExpression []FilterExpression
|
||||
|
||||
func (AndFilterExpression) isFilterExpression() {}
|
||||
|
||||
// An EqualsFilterExpression represents a field comparison operator.
|
||||
type EqualsFilterExpression struct {
|
||||
Fields []string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (EqualsFilterExpression) isFilterExpression() {}
|
73
pkg/storage/filter_test.go
Normal file
73
pkg/storage/filter_test.go
Normal file
|
@ -0,0 +1,73 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
func TestFilterExpressionFromStruct(t *testing.T) {
|
||||
type M = map[string]interface{}
|
||||
type A = []interface{}
|
||||
|
||||
s, err := structpb.NewStruct(M{
|
||||
"$and": A{
|
||||
M{"a": M{"b": "1"}},
|
||||
},
|
||||
"c": M{
|
||||
"d": M{
|
||||
"e": M{
|
||||
"$eq": "2",
|
||||
},
|
||||
},
|
||||
},
|
||||
"f": A{
|
||||
"3", "4", "5",
|
||||
},
|
||||
"$or": A{
|
||||
M{"g": "6"},
|
||||
M{"h": "7"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
expr, err := FilterExpressionFromStruct(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t,
|
||||
AndFilterExpression{
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"a", "b"},
|
||||
Value: "1",
|
||||
},
|
||||
OrFilterExpression{
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"g"},
|
||||
Value: "6",
|
||||
},
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"h"},
|
||||
Value: "7",
|
||||
},
|
||||
},
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"c", "d", "e"},
|
||||
Value: "2",
|
||||
},
|
||||
OrFilterExpression{
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"f"},
|
||||
Value: "3",
|
||||
},
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"f"},
|
||||
Value: "4",
|
||||
},
|
||||
EqualsFilterExpression{
|
||||
Fields: []string{"f"},
|
||||
Value: "5",
|
||||
},
|
||||
},
|
||||
},
|
||||
expr)
|
||||
}
|
65
pkg/storage/index.go
Normal file
65
pkg/storage/index.go
Normal file
|
@ -0,0 +1,65 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
)
|
||||
|
||||
const (
|
||||
indexField = "$index"
|
||||
cidrField = "cidr"
|
||||
)
|
||||
|
||||
// GetRecordIndex gets a record's index. If there is no index, nil is returned.
|
||||
func GetRecordIndex(msg proto.Message) *structpb.Struct {
|
||||
for {
|
||||
any, ok := msg.(*anypb.Any)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
msg, _ = any.UnmarshalNew()
|
||||
}
|
||||
|
||||
var s *structpb.Struct
|
||||
if sv, ok := msg.(*structpb.Value); ok {
|
||||
s = sv.GetStructValue()
|
||||
} else {
|
||||
s, _ = msg.(*structpb.Struct)
|
||||
}
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
f, ok := s.Fields[indexField]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return f.GetStructValue()
|
||||
}
|
||||
|
||||
// GetRecordIndexCIDR returns the $index.cidr for a record's data. If none is available nil is returned.
|
||||
func GetRecordIndexCIDR(msg proto.Message) *netip.Prefix {
|
||||
obj := GetRecordIndex(msg)
|
||||
if obj == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
cf, ok := obj.Fields[cidrField]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := cf.GetStringValue()
|
||||
if c == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
prefix, err := netip.ParsePrefix(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return &prefix
|
||||
}
|
62
pkg/storage/index_test.go
Normal file
62
pkg/storage/index_test.go
Normal file
|
@ -0,0 +1,62 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestGetRecordIndex(t *testing.T) {
|
||||
type M = map[string]interface{}
|
||||
t.Run("missing", func(t *testing.T) {
|
||||
v, err := structpb.NewStruct(M{
|
||||
"notindex": "value",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, GetRecordIndex(v))
|
||||
})
|
||||
t.Run("struct", func(t *testing.T) {
|
||||
v, err := structpb.NewStruct(M{
|
||||
"$index": M{
|
||||
"cidr": "192.168.0.0/16",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
||||
},
|
||||
}, GetRecordIndex(v))
|
||||
})
|
||||
t.Run("value", func(t *testing.T) {
|
||||
v, err := structpb.NewValue(M{
|
||||
"$index": M{
|
||||
"cidr": "192.168.0.0/16",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, &structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
||||
},
|
||||
}, GetRecordIndex(v))
|
||||
})
|
||||
t.Run("any", func(t *testing.T) {
|
||||
v, err := structpb.NewValue(M{
|
||||
"$index": M{
|
||||
"cidr": "192.168.0.0/16",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
any := protoutil.NewAny(v)
|
||||
assert.Equal(t, &structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"cidr": structpb.NewStringValue("192.168.0.0/16"),
|
||||
},
|
||||
}, GetRecordIndex(any))
|
||||
})
|
||||
}
|
121
pkg/storage/stream_filter.go
Normal file
121
pkg/storage/stream_filter.go
Normal file
|
@ -0,0 +1,121 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// A RecordStreamFilter filters a RecordStream.
|
||||
type RecordStreamFilter func(record *databroker.Record) (keep bool)
|
||||
|
||||
// And creates a new RecordStreamFilter by applying both functions to a record.
|
||||
func (filter RecordStreamFilter) And(
|
||||
then RecordStreamFilter,
|
||||
) RecordStreamFilter {
|
||||
return func(record *databroker.Record) (keep bool) {
|
||||
return filter(record) && then(record)
|
||||
}
|
||||
}
|
||||
|
||||
// FilteredRecordStreamGenerator creates a RecordStreamGenerator that only returns records that pass the filter.
|
||||
func FilteredRecordStreamGenerator(
|
||||
generator RecordStreamGenerator,
|
||||
filter RecordStreamFilter,
|
||||
) RecordStreamGenerator {
|
||||
return func(ctx context.Context, block bool) (*databroker.Record, error) {
|
||||
for {
|
||||
record, err := generator(ctx, block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !filter(record) {
|
||||
continue
|
||||
}
|
||||
|
||||
return record, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// RecordStreamFilterFromFilterExpression returns a RecordStreamFilter from a FilterExpression.
|
||||
func RecordStreamFilterFromFilterExpression(
|
||||
expr FilterExpression,
|
||||
) (filter RecordStreamFilter, err error) {
|
||||
if expr == nil {
|
||||
return func(record *databroker.Record) (keep bool) { return true }, nil
|
||||
}
|
||||
|
||||
switch expr := expr.(type) {
|
||||
case AndFilterExpression:
|
||||
if len(expr) == 0 {
|
||||
return func(record *databroker.Record) (keep bool) { return true }, nil
|
||||
}
|
||||
|
||||
fs := make([]RecordStreamFilter, len(expr))
|
||||
for i, e := range expr {
|
||||
fs[i], err = RecordStreamFilterFromFilterExpression(e)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return func(record *databroker.Record) (keep bool) {
|
||||
for _, f := range fs {
|
||||
if !f(record) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}, nil
|
||||
case OrFilterExpression:
|
||||
if len(expr) == 0 {
|
||||
return func(record *databroker.Record) (keep bool) { return true }, nil
|
||||
}
|
||||
|
||||
fs := make([]RecordStreamFilter, len(expr))
|
||||
for i, e := range expr {
|
||||
fs[i], err = RecordStreamFilterFromFilterExpression(e)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return func(record *databroker.Record) (keep bool) {
|
||||
for _, f := range fs {
|
||||
if f(record) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, nil
|
||||
case EqualsFilterExpression:
|
||||
switch strings.Join(expr.Fields, ".") {
|
||||
case "id":
|
||||
id := expr.Value
|
||||
return func(record *databroker.Record) (keep bool) {
|
||||
return record.GetId() == id
|
||||
}, nil
|
||||
case "$index":
|
||||
ip, _ := netip.ParseAddr(expr.Value)
|
||||
return func(record *databroker.Record) (keep bool) {
|
||||
// indexed via CIDR
|
||||
if ip.IsValid() {
|
||||
msg, _ := record.GetData().UnmarshalNew()
|
||||
cidr := GetRecordIndexCIDR(msg)
|
||||
if cidr != nil && cidr.Contains(ip) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only id or $index are supported for query filters")
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported filter expression type: %T", expr))
|
||||
}
|
||||
}
|
42
pkg/storage/stream_filter_test.go
Normal file
42
pkg/storage/stream_filter_test.go
Normal file
|
@ -0,0 +1,42 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestRecordStreamFilterFromFilterExpression(t *testing.T) {
|
||||
type M = map[string]interface{}
|
||||
|
||||
s, err := structpb.NewStruct(M{
|
||||
"$index": M{
|
||||
"cidr": "192.168.0.0/16",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
f1, err := RecordStreamFilterFromFilterExpression(EqualsFilterExpression{
|
||||
Fields: []string{"$index"},
|
||||
Value: "192.168.0.1",
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.True(t, f1(&databroker.Record{
|
||||
Data: protoutil.NewAny(s),
|
||||
}))
|
||||
}
|
||||
|
||||
f2, err := RecordStreamFilterFromFilterExpression(EqualsFilterExpression{
|
||||
Fields: []string{"$index"},
|
||||
Value: "192.169.0.1",
|
||||
})
|
||||
if assert.NoError(t, err) {
|
||||
assert.False(t, f2(&databroker.Record{
|
||||
Data: protoutil.NewAny(s),
|
||||
}))
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue