mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
databroker: add support for query filtering (#3369)
* wip * storage: add filtering to SyncLatest * don't increment the record version, so intermediate changes are requested * databroker: add support for query filtering * fill server and record version * add test checks * add explanation to query filter error
This commit is contained in:
parent
1669b601ea
commit
994faba0c8
4 changed files with 285 additions and 206 deletions
|
@ -136,6 +136,7 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
Str("query", req.GetQuery()).
|
||||
Int64("offset", req.GetOffset()).
|
||||
Int64("limit", req.GetLimit()).
|
||||
Interface("filter", req.GetFilter()).
|
||||
Msg("query")
|
||||
|
||||
query := strings.ToLower(req.GetQuery())
|
||||
|
@ -145,7 +146,12 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
return nil, err
|
||||
}
|
||||
|
||||
_, _, stream, err := db.SyncLatest(ctx, req.GetType(), nil)
|
||||
expr, err := storage.FilterExpressionFromStruct(req.GetFilter())
|
||||
if err != nil {
|
||||
return nil, status.Errorf(codes.InvalidArgument, "invalid query filter: %v", err)
|
||||
}
|
||||
|
||||
serverVersion, recordVersion, stream, err := db.SyncLatest(ctx, req.GetType(), expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -155,10 +161,6 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
for stream.Next(false) {
|
||||
record := stream.Record()
|
||||
|
||||
if record.GetType() != req.GetType() {
|
||||
continue
|
||||
}
|
||||
|
||||
if query != "" && !storage.MatchAny(record.GetData(), query) {
|
||||
continue
|
||||
}
|
||||
|
@ -171,8 +173,10 @@ func (srv *Server) Query(ctx context.Context, req *databroker.QueryRequest) (*da
|
|||
|
||||
records, totalCount := databroker.ApplyOffsetAndLimit(filtered, int(req.GetOffset()), int(req.GetLimit()))
|
||||
return &databroker.QueryResponse{
|
||||
Records: records,
|
||||
TotalCount: int64(totalCount),
|
||||
Records: records,
|
||||
TotalCount: int64(totalCount),
|
||||
ServerVersion: serverVersion,
|
||||
RecordVersion: recordVersion,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -3,8 +3,10 @@ package databroker
|
|||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"sort"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
|
@ -16,6 +18,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
|
@ -135,21 +138,52 @@ func TestServer_Query(t *testing.T) {
|
|||
cfg := newServerConfig()
|
||||
srv := newServer(cfg)
|
||||
|
||||
s := new(session.Session)
|
||||
s.Id = "1"
|
||||
any := protoutil.NewAny(s)
|
||||
_, err := srv.Put(context.Background(), &databroker.PutRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
}},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
_, err = srv.Query(context.Background(), &databroker.QueryRequest{
|
||||
Type: any.TypeUrl,
|
||||
for i := 0; i < 10; i++ {
|
||||
s := new(session.Session)
|
||||
s.Id = fmt.Sprint(i)
|
||||
any := protoutil.NewAny(s)
|
||||
_, err := srv.Put(context.Background(), &databroker.PutRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: any.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: any,
|
||||
}},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
res, err := srv.Query(context.Background(), &databroker.QueryRequest{
|
||||
Type: protoutil.GetTypeURL(new(session.Session)),
|
||||
Filter: &structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
|
||||
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"id": structpb.NewStringValue("1"),
|
||||
}}),
|
||||
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"id": structpb.NewStringValue("3"),
|
||||
}}),
|
||||
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"id": structpb.NewStringValue("5"),
|
||||
}}),
|
||||
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
"id": structpb.NewStringValue("7"),
|
||||
}}),
|
||||
}}),
|
||||
},
|
||||
},
|
||||
Limit: 10,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
|
||||
if assert.Len(t, res.Records, 4) {
|
||||
sort.Slice(res.Records, func(i, j int) bool {
|
||||
return res.Records[i].GetId() < res.Records[j].GetId()
|
||||
})
|
||||
assert.Equal(t, "1", res.Records[0].GetId())
|
||||
assert.Equal(t, "3", res.Records[1].GetId())
|
||||
assert.Equal(t, "5", res.Records[2].GetId())
|
||||
assert.Equal(t, "7", res.Records[3].GetId())
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_Sync(t *testing.T) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue