storage: add filter expressions, upgrade go to 1.18.1 (#3365)

* storage: add filter expressions

* upgrade go
This commit is contained in:
Caleb Doxsey 2022-05-17 02:09:50 +00:00 committed by GitHub
parent 51e716ef54
commit 70f5d8b173
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 517 additions and 63 deletions

143
pkg/storage/filter.go Normal file
View file

@ -0,0 +1,143 @@
package storage
import (
"fmt"
"sort"
"google.golang.org/protobuf/types/known/structpb"
)
// A FilterExpression describes an AST for record stream filters.
type FilterExpression interface {
isFilterExpression()
}
// FilterExpressionFromStruct creates a FilterExpression from a protobuf struct.
func FilterExpressionFromStruct(s *structpb.Struct) (FilterExpression, error) {
if s == nil {
return nil, nil
}
return filterExpressionFromStruct(nil, s)
}
func filterExpressionFromValue(path []string, v *structpb.Value) (FilterExpression, error) {
switch vv := v.GetKind().(type) {
case *structpb.Value_ListValue:
var or OrFilterExpression
for _, vvv := range vv.ListValue.Values {
e, err := filterExpressionFromValue(path, vvv)
if err != nil {
return nil, err
}
or = append(or, e)
}
return or, nil
case *structpb.Value_StructValue:
return filterExpressionFromStruct(path, vv.StructValue)
}
return filterExpressionFromEq(path, v)
}
func filterExpressionFromStruct(path []string, s *structpb.Struct) (FilterExpression, error) {
var and AndFilterExpression
var fs []string
for f := range s.GetFields() {
fs = append(fs, f)
}
sort.Strings(fs)
for _, f := range fs {
v := s.GetFields()[f]
switch f {
case "$and":
expr, err := filterExpressionFromValue(path, v)
if err != nil {
return nil, err
}
or, ok := expr.(OrFilterExpression)
if !ok {
return nil, fmt.Errorf("$and must be an array")
}
if len(or) == 1 {
and = append(and, or[0])
} else {
and = append(and, AndFilterExpression(or))
}
case "$or":
expr, err := filterExpressionFromValue(path, v)
if err != nil {
return nil, err
}
or, ok := expr.(OrFilterExpression)
if !ok {
return nil, fmt.Errorf("$or must be an array")
}
if len(or) == 1 {
and = append(and, or[0])
} else {
and = append(and, or)
}
case "$eq":
expr, err := filterExpressionFromEq(path, v)
if err != nil {
return nil, err
}
and = append(and, expr)
default:
expr, err := filterExpressionFromValue(append(path, f), v)
if err != nil {
return nil, err
}
and = append(and, expr)
}
}
if len(and) == 1 {
return and[0], nil
}
return and, nil
}
func filterExpressionFromEq(path []string, v *structpb.Value) (FilterExpression, error) {
switch vv := v.GetKind().(type) {
case *structpb.Value_BoolValue:
return EqualsFilterExpression{
Fields: path,
Value: fmt.Sprintf("%v", vv.BoolValue),
}, nil
case *structpb.Value_NullValue:
return EqualsFilterExpression{
Fields: path,
Value: fmt.Sprintf("%v", vv.NullValue),
}, nil
case *structpb.Value_NumberValue:
return EqualsFilterExpression{
Fields: path,
Value: fmt.Sprintf("%v", vv.NumberValue),
}, nil
case *structpb.Value_StringValue:
return EqualsFilterExpression{
Fields: path,
Value: vv.StringValue,
}, nil
}
return nil, fmt.Errorf("unsupported struct value type for eq: %T", v.GetKind())
}
// An OrFilterExpression represents a logical-or comparison operator.
type OrFilterExpression []FilterExpression
func (OrFilterExpression) isFilterExpression() {}
// An AndFilterExpression represents a logical-and comparison operator.
type AndFilterExpression []FilterExpression
func (AndFilterExpression) isFilterExpression() {}
// An EqualsFilterExpression represents a field comparison operator.
type EqualsFilterExpression struct {
Fields []string
Value string
}
func (EqualsFilterExpression) isFilterExpression() {}

View file

@ -0,0 +1,73 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
)
func TestFilterExpressionFromStruct(t *testing.T) {
type M = map[string]interface{}
type A = []interface{}
s, err := structpb.NewStruct(M{
"$and": A{
M{"a": M{"b": "1"}},
},
"c": M{
"d": M{
"e": M{
"$eq": "2",
},
},
},
"f": A{
"3", "4", "5",
},
"$or": A{
M{"g": "6"},
M{"h": "7"},
},
})
require.NoError(t, err)
expr, err := FilterExpressionFromStruct(s)
assert.NoError(t, err)
assert.Equal(t,
AndFilterExpression{
EqualsFilterExpression{
Fields: []string{"a", "b"},
Value: "1",
},
OrFilterExpression{
EqualsFilterExpression{
Fields: []string{"g"},
Value: "6",
},
EqualsFilterExpression{
Fields: []string{"h"},
Value: "7",
},
},
EqualsFilterExpression{
Fields: []string{"c", "d", "e"},
Value: "2",
},
OrFilterExpression{
EqualsFilterExpression{
Fields: []string{"f"},
Value: "3",
},
EqualsFilterExpression{
Fields: []string{"f"},
Value: "4",
},
EqualsFilterExpression{
Fields: []string{"f"},
Value: "5",
},
},
},
expr)
}

65
pkg/storage/index.go Normal file
View file

@ -0,0 +1,65 @@
package storage
import (
"net/netip"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
)
const (
indexField = "$index"
cidrField = "cidr"
)
// GetRecordIndex gets a record's index. If there is no index, nil is returned.
func GetRecordIndex(msg proto.Message) *structpb.Struct {
for {
any, ok := msg.(*anypb.Any)
if !ok {
break
}
msg, _ = any.UnmarshalNew()
}
var s *structpb.Struct
if sv, ok := msg.(*structpb.Value); ok {
s = sv.GetStructValue()
} else {
s, _ = msg.(*structpb.Struct)
}
if s == nil {
return nil
}
f, ok := s.Fields[indexField]
if !ok {
return nil
}
return f.GetStructValue()
}
// GetRecordIndexCIDR returns the $index.cidr for a record's data. If none is available nil is returned.
func GetRecordIndexCIDR(msg proto.Message) *netip.Prefix {
obj := GetRecordIndex(msg)
if obj == nil {
return nil
}
cf, ok := obj.Fields[cidrField]
if !ok {
return nil
}
c := cf.GetStringValue()
if c == "" {
return nil
}
prefix, err := netip.ParsePrefix(c)
if err != nil {
return nil
}
return &prefix
}

62
pkg/storage/index_test.go Normal file
View file

@ -0,0 +1,62 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestGetRecordIndex(t *testing.T) {
type M = map[string]interface{}
t.Run("missing", func(t *testing.T) {
v, err := structpb.NewStruct(M{
"notindex": "value",
})
require.NoError(t, err)
assert.Nil(t, GetRecordIndex(v))
})
t.Run("struct", func(t *testing.T) {
v, err := structpb.NewStruct(M{
"$index": M{
"cidr": "192.168.0.0/16",
},
})
require.NoError(t, err)
assert.Equal(t, &structpb.Struct{
Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
},
}, GetRecordIndex(v))
})
t.Run("value", func(t *testing.T) {
v, err := structpb.NewValue(M{
"$index": M{
"cidr": "192.168.0.0/16",
},
})
require.NoError(t, err)
assert.Equal(t, &structpb.Struct{
Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
},
}, GetRecordIndex(v))
})
t.Run("any", func(t *testing.T) {
v, err := structpb.NewValue(M{
"$index": M{
"cidr": "192.168.0.0/16",
},
})
require.NoError(t, err)
any := protoutil.NewAny(v)
assert.Equal(t, &structpb.Struct{
Fields: map[string]*structpb.Value{
"cidr": structpb.NewStringValue("192.168.0.0/16"),
},
}, GetRecordIndex(any))
})
}

View file

@ -0,0 +1,121 @@
package storage
import (
"context"
"fmt"
"net/netip"
"strings"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// A RecordStreamFilter filters a RecordStream.
type RecordStreamFilter func(record *databroker.Record) (keep bool)
// And creates a new RecordStreamFilter by applying both functions to a record.
func (filter RecordStreamFilter) And(
then RecordStreamFilter,
) RecordStreamFilter {
return func(record *databroker.Record) (keep bool) {
return filter(record) && then(record)
}
}
// FilteredRecordStreamGenerator creates a RecordStreamGenerator that only returns records that pass the filter.
func FilteredRecordStreamGenerator(
generator RecordStreamGenerator,
filter RecordStreamFilter,
) RecordStreamGenerator {
return func(ctx context.Context, block bool) (*databroker.Record, error) {
for {
record, err := generator(ctx, block)
if err != nil {
return nil, err
}
if !filter(record) {
continue
}
return record, nil
}
}
}
// RecordStreamFilterFromFilterExpression returns a RecordStreamFilter from a FilterExpression.
func RecordStreamFilterFromFilterExpression(
expr FilterExpression,
) (filter RecordStreamFilter, err error) {
if expr == nil {
return func(record *databroker.Record) (keep bool) { return true }, nil
}
switch expr := expr.(type) {
case AndFilterExpression:
if len(expr) == 0 {
return func(record *databroker.Record) (keep bool) { return true }, nil
}
fs := make([]RecordStreamFilter, len(expr))
for i, e := range expr {
fs[i], err = RecordStreamFilterFromFilterExpression(e)
if err != nil {
return nil, err
}
}
return func(record *databroker.Record) (keep bool) {
for _, f := range fs {
if !f(record) {
return false
}
}
return true
}, nil
case OrFilterExpression:
if len(expr) == 0 {
return func(record *databroker.Record) (keep bool) { return true }, nil
}
fs := make([]RecordStreamFilter, len(expr))
for i, e := range expr {
fs[i], err = RecordStreamFilterFromFilterExpression(e)
if err != nil {
return nil, err
}
}
return func(record *databroker.Record) (keep bool) {
for _, f := range fs {
if f(record) {
return true
}
}
return false
}, nil
case EqualsFilterExpression:
switch strings.Join(expr.Fields, ".") {
case "id":
id := expr.Value
return func(record *databroker.Record) (keep bool) {
return record.GetId() == id
}, nil
case "$index":
ip, _ := netip.ParseAddr(expr.Value)
return func(record *databroker.Record) (keep bool) {
// indexed via CIDR
if ip.IsValid() {
msg, _ := record.GetData().UnmarshalNew()
cidr := GetRecordIndexCIDR(msg)
if cidr != nil && cidr.Contains(ip) {
return true
}
}
return false
}, nil
default:
return nil, fmt.Errorf("only id or $index are supported for query filters")
}
default:
panic(fmt.Sprintf("unsupported filter expression type: %T", expr))
}
}

View file

@ -0,0 +1,42 @@
package storage
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestRecordStreamFilterFromFilterExpression(t *testing.T) {
type M = map[string]interface{}
s, err := structpb.NewStruct(M{
"$index": M{
"cidr": "192.168.0.0/16",
},
})
require.NoError(t, err)
f1, err := RecordStreamFilterFromFilterExpression(EqualsFilterExpression{
Fields: []string{"$index"},
Value: "192.168.0.1",
})
if assert.NoError(t, err) {
assert.True(t, f1(&databroker.Record{
Data: protoutil.NewAny(s),
}))
}
f2, err := RecordStreamFilterFromFilterExpression(EqualsFilterExpression{
Fields: []string{"$index"},
Value: "192.169.0.1",
})
if assert.NoError(t, err) {
assert.False(t, f2(&databroker.Record{
Data: protoutil.NewAny(s),
}))
}
}