databroker: add support for query filtering (#3369)

* wip

* storage: add filtering to SyncLatest

* don't increment the record version, so intermediate changes are requested

* databroker: add support for query filtering

* fill server and record version

* add test checks

* add explanation to query filter error
This commit is contained in:
Caleb Doxsey 2022-05-19 15:07:32 +00:00 committed by GitHub
parent 1669b601ea
commit 994faba0c8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 285 additions and 206 deletions

View file

@ -136,6 +136,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
Str("query", req.GetQuery()).
Int64("offset", req.GetOffset()).
Int64("limit", req.GetLimit()).
Interface("filter", req.GetFilter()).
Msg("query")
query := strings.ToLower(req.GetQuery())
@ -145,7 +146,12 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
return nil, err
}
_, _, stream, err := db.SyncLatest(ctx, req.GetType(), nil)
expr, err := storage.FilterExpressionFromStruct(req.GetFilter())
if err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid query filter: %v", err)
}
serverVersion, recordVersion, stream, err := db.SyncLatest(ctx, req.GetType(), expr)
if err != nil {
return nil, err
}
@ -155,10 +161,6 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
for stream.Next(false) {
record := stream.Record()
if record.GetType() != req.GetType() {
continue
}
if query != "" && !storage.MatchAny(record.GetData(), query) {
continue
}
@ -171,8 +173,10 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit()))
return &databroker.QueryResponse{
Records: records,
TotalCount: int64(totalCount),
Records: records,
TotalCount: int64(totalCount),
ServerVersion: serverVersion,
RecordVersion: recordVersion,
}, nil
}

View file

@ -3,8 +3,10 @@ package databroker
import (
"context"
"errors"
"fmt"
"io"
"net"
"sort"
"testing"
"time"
@ -16,6 +18,7 @@ import (
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
@ -135,21 +138,52 @@ func TestServer_Query(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
s := new(session.Session)
s.Id = "1"
any := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: any.TypeUrl,
Id: s.Id,
Data: any,
}},
})
assert.NoError(t, err)
_, err = srv.Query(context.Background(), &databroker.QueryRequest{
Type: any.TypeUrl,
for i := 0; i < 10; i++ {
s := new(session.Session)
s.Id = fmt.Sprint(i)
any := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: any.TypeUrl,
Id: s.Id,
Data: any,
}},
})
assert.NoError(t, err)
}
res, err := srv.Query(context.Background(), &databroker.QueryRequest{
Type: protoutil.GetTypeURL(new(session.Session)),
Filter: &structpb.Struct{
Fields: map[string]*structpb.Value{
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("1"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("3"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("5"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("7"),
}}),
}}),
},
},
Limit: 10,
})
assert.NoError(t, err)
if assert.Len(t, res.Records, 4) {
sort.Slice(res.Records, func(i, j int) bool {
return res.Records[i].GetId() < res.Records[j].GetId()
})
assert.Equal(t, "1", res.Records[0].GetId())
assert.Equal(t, "3", res.Records[1].GetId())
assert.Equal(t, "5", res.Records[2].GetId())
assert.Equal(t, "7", res.Records[3].GetId())
}
}
func TestServer_Sync(t *testing.T) {