pomerium/internal/databroker/server_test.go
Kenneth Jenkins d5da872157
databroker: add patch method (#4704)
Add a Patch() method to the databroker gRPC service.

Update the storage.Backend interface to include the Patch() method now
that all the storage.Backend implementations include it.

Add a test to exercise the patch method under concurrent usage.
2023-11-02 15:07:37 -07:00

404 lines
9.8 KiB
Go

package databroker
import (
"context"
"errors"
"fmt"
"io"
"net"
"sort"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"google.golang.org/protobuf/types/known/structpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
)
type testSyncerHandler struct {
getDataBrokerServiceClient func() databroker.DataBrokerServiceClient
clearRecords func(ctx context.Context)
updateRecords func(ctx context.Context, serverVersion uint64, records []*databroker.Record)
}
func (h testSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
return h.getDataBrokerServiceClient()
}
func (h testSyncerHandler) ClearRecords(ctx context.Context) {
h.clearRecords(ctx)
}
func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) {
h.updateRecords(ctx, serverVersion, records)
}
func newServer(cfg *serverConfig) *Server {
return &Server{
cfg: cfg,
}
}
func TestServer_Get(t *testing.T) {
cfg := newServerConfig()
t.Run("ignore deleted", func(t *testing.T) {
srv := newServer(cfg)
s := new(session.Session)
s.Id = "1"
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
_, err = srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
DeletedAt: timestamppb.Now(),
}},
})
assert.NoError(t, err)
_, err = srv.Get(context.Background(), &databroker.GetRequest{
Type: data.TypeUrl,
Id: s.Id,
})
assert.Error(t, err)
assert.Equal(t, codes.NotFound, status.Code(err))
})
}
func TestServer_Patch(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
s := &session.Session{
Id: "1",
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
}
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
require.NoError(t, err)
fm, err := fieldmaskpb.New(s, "accessed_at")
require.NoError(t, err)
now := timestamppb.Now()
s.AccessedAt = now
s.OauthToken.AccessToken = "access-token-field-ignored"
data = protoutil.NewAny(s)
patchResponse, err := srv.Patch(context.Background(), &databroker.PatchRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
FieldMask: fm,
})
require.NoError(t, err)
testutil.AssertProtoEqual(t, protoutil.NewAny(&session.Session{
Id: "1",
AccessedAt: now,
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
}), patchResponse.GetRecord().GetData())
getResponse, err := srv.Get(context.Background(), &databroker.GetRequest{
Type: data.TypeUrl,
Id: s.Id,
})
require.NoError(t, err)
testutil.AssertProtoEqual(t, protoutil.NewAny(&session.Session{
Id: "1",
AccessedAt: now,
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
}), getResponse.GetRecord().GetData())
}
func TestServer_Options(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
s := new(session.Session)
s.Id = "1"
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
_, err = srv.SetOptions(context.Background(), &databroker.SetOptionsRequest{
Type: data.TypeUrl,
Options: &databroker.Options{
Capacity: proto.Uint64(1),
},
})
assert.NoError(t, err)
}
func TestServer_Lease(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
res, err := srv.AcquireLease(context.Background(), &databroker.AcquireLeaseRequest{
Name: "TEST",
Duration: durationpb.New(time.Second * 10),
})
assert.NoError(t, err)
assert.NotEmpty(t, res.GetId())
_, err = srv.RenewLease(context.Background(), &databroker.RenewLeaseRequest{
Name: "TEST",
Id: res.GetId(),
Duration: durationpb.New(time.Second * 10),
})
assert.NoError(t, err)
_, err = srv.ReleaseLease(context.Background(), &databroker.ReleaseLeaseRequest{
Name: "TEST",
Id: res.GetId(),
})
assert.NoError(t, err)
}
func TestServer_Query(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
for i := 0; i < 10; i++ {
s := new(session.Session)
s.Id = fmt.Sprint(i)
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
}
res, err := srv.Query(context.Background(), &databroker.QueryRequest{
Type: protoutil.GetTypeURL(new(session.Session)),
Filter: &structpb.Struct{
Fields: map[string]*structpb.Value{
"$or": structpb.NewListValue(&structpb.ListValue{Values: []*structpb.Value{
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("1"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("3"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("5"),
}}),
structpb.NewStructValue(&structpb.Struct{Fields: map[string]*structpb.Value{
"id": structpb.NewStringValue("7"),
}}),
}}),
},
},
Limit: 10,
})
assert.NoError(t, err)
if assert.Len(t, res.Records, 4) {
sort.Slice(res.Records, func(i, j int) bool {
return res.Records[i].GetId() < res.Records[j].GetId()
})
assert.Equal(t, "1", res.Records[0].GetId())
assert.Equal(t, "3", res.Records[1].GetId())
assert.Equal(t, "5", res.Records[2].GetId())
assert.Equal(t, "7", res.Records[3].GetId())
}
}
func TestServer_Sync(t *testing.T) {
cfg := newServerConfig()
srv := newServer(cfg)
s := new(session.Session)
s.Id = "1"
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
gs := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(gs, srv)
li, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer li.Close()
eg, ctx := errgroup.WithContext(context.Background())
eg.Go(func() error {
return gs.Serve(li)
})
eg.Go(func() error {
defer gs.Stop()
cc, err := grpc.DialContext(ctx, li.Addr().String(), grpc.WithInsecure())
if err != nil {
return err
}
defer cc.Close()
clearRecords := make(chan struct{}, 10)
updateRecords := make(chan uint64, 10)
client := databroker.NewDataBrokerServiceClient(cc)
syncer := databroker.NewSyncer("TEST", testSyncerHandler{
getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient {
return client
},
clearRecords: func(_ context.Context) {
clearRecords <- struct{}{}
},
updateRecords: func(_ context.Context, recordVersion uint64, _ []*databroker.Record) {
updateRecords <- recordVersion
},
})
go syncer.Run(ctx)
select {
case <-clearRecords:
case <-ctx.Done():
return ctx.Err()
}
select {
case <-updateRecords:
case <-ctx.Done():
return ctx.Err()
}
_, err = srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
select {
case <-updateRecords:
case <-ctx.Done():
return ctx.Err()
}
return nil
})
assert.NoError(t, eg.Wait())
}
func TestServerInvalidStorage(t *testing.T) {
srv := newServer(&serverConfig{
storageType: "<INVALID>",
})
s := new(session.Session)
s.Id = "1"
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
}
func TestServerPostgres(t *testing.T) {
testutil.WithTestPostgres(func(dsn string) error {
srv := newServer(&serverConfig{
storageType: "postgres",
storageConnectionString: dsn,
})
s := new(session.Session)
s.Id = "1"
data := protoutil.NewAny(s)
_, err := srv.Put(context.Background(), &databroker.PutRequest{
Records: []*databroker.Record{{
Type: data.TypeUrl,
Id: s.Id,
Data: data,
}},
})
assert.NoError(t, err)
gs := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(gs, srv)
li, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer li.Close()
eg, ctx := errgroup.WithContext(context.Background())
eg.Go(func() error {
return gs.Serve(li)
})
eg.Go(func() error {
defer gs.Stop()
cc, err := grpc.DialContext(ctx, li.Addr().String(), grpc.WithInsecure())
if err != nil {
return err
}
defer cc.Close()
client := databroker.NewDataBrokerServiceClient(cc)
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{
Type: data.TypeUrl,
})
if err != nil {
return err
}
for {
res, err := stream.Recv()
if errors.Is(err, io.EOF) {
break
} else if err != nil {
return err
}
assert.NotNil(t, res)
}
return nil
})
assert.NoError(t, eg.Wait())
return nil
})
}