From d45a7e999661f10361a2e2c7bf5a38f274c6230f Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Fri, 16 Jul 2021 10:26:29 -0600 Subject: [PATCH] databroker: tests (#2367) * databroker: tests * fix lint --- internal/databroker/config_source_test.go | 16 +- internal/databroker/server_test.go | 277 ++++++++++++++++++++++ 2 files changed, 291 insertions(+), 2 deletions(-) diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index f0520f938..d2bac2d00 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -3,6 +3,7 @@ package databroker import ( "context" "net" + "net/url" "testing" "time" @@ -32,14 +33,21 @@ func TestConfigSource(t *testing.T) { cfgs := make(chan *config.Config, 10) + u, _ := url.Parse("https://to.example.com") base := config.NewDefaultOptions() base.DataBrokerURLString = "http://" + li.Addr().String() base.InsecureServer = true base.GRPCInsecure = true + base.Policies = append(base.Policies, config.Policy{ + From: "https://pomerium.io", To: config.WeightedURLs{ + {URL: *u}, + }, AllowedUsers: []string{"foo@bar.com"}, + }) - src := NewConfigSource(ctx, config.NewStaticSource(&config.Config{ + baseSource := config.NewStaticSource(&config.Config{ Options: base, - }), func(_ context.Context, cfg *config.Config) { + }) + src := NewConfigSource(ctx, baseSource, func(_ context.Context, cfg *config.Config) { cfgs <- cfg }) cfgs <- src.GetConfig() @@ -76,4 +84,8 @@ func TestConfigSource(t *testing.T) { case cfg := <-cfgs: assert.Len(t, cfg.Options.AdditionalPolicies, 1) } + + baseSource.SetConfig(ctx, &config.Config{ + Options: base, + }) } diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index 6f59fdcdb..287a8883a 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -2,18 +2,49 @@ package databroker import ( "context" + "errors" + "io" + "net" "testing" + "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/pkg/cryptutil" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" ) +type testSyncerHandler struct { + getDataBrokerServiceClient func() databroker.DataBrokerServiceClient + clearRecords func(ctx context.Context) + updateRecords func(ctx context.Context, serverVersion uint64, records []*databroker.Record) +} + +func (h testSyncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return h.getDataBrokerServiceClient() +} + +func (h testSyncerHandler) ClearRecords(ctx context.Context) { + h.clearRecords(ctx) +} + +func (h testSyncerHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*databroker.Record) { + h.updateRecords(ctx, serverVersion, records) +} + func newServer(cfg *serverConfig) *Server { return &Server{ cfg: cfg, @@ -54,3 +85,249 @@ func TestServer_Get(t *testing.T) { assert.Equal(t, codes.NotFound, status.Code(err)) }) } + +func TestServer_Options(t *testing.T) { + cfg := newServerConfig() + srv := newServer(cfg) + + s := new(session.Session) + s.Id = "1" + any, err := anypb.New(s) + assert.NoError(t, err) + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + assert.NoError(t, err) + _, err = srv.SetOptions(context.Background(), &databroker.SetOptionsRequest{ + Type: any.TypeUrl, + Options: &databroker.Options{ + Capacity: proto.Uint64(1), + }, + }) + assert.NoError(t, err) +} + +func TestServer_Lease(t *testing.T) { + cfg := newServerConfig() + srv := newServer(cfg) + + res, err := srv.AcquireLease(context.Background(), &databroker.AcquireLeaseRequest{ + Name: "TEST", + Duration: durationpb.New(time.Second * 10), + }) + assert.NoError(t, err) + assert.NotEmpty(t, res.GetId()) + + _, err = srv.RenewLease(context.Background(), &databroker.RenewLeaseRequest{ + Name: "TEST", + Id: res.GetId(), + Duration: durationpb.New(time.Second * 10), + }) + assert.NoError(t, err) + + _, err = srv.ReleaseLease(context.Background(), &databroker.ReleaseLeaseRequest{ + Name: "TEST", + Id: res.GetId(), + }) + assert.NoError(t, err) +} + +func TestServer_Query(t *testing.T) { + cfg := newServerConfig() + srv := newServer(cfg) + + s := new(session.Session) + s.Id = "1" + any, err := anypb.New(s) + assert.NoError(t, err) + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + assert.NoError(t, err) + _, err = srv.Query(context.Background(), &databroker.QueryRequest{ + Type: any.TypeUrl, + }) + assert.NoError(t, err) +} + +func TestServer_Sync(t *testing.T) { + cfg := newServerConfig() + srv := newServer(cfg) + + s := new(session.Session) + s.Id = "1" + any, err := anypb.New(s) + assert.NoError(t, err) + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + assert.NoError(t, err) + + gs := grpc.NewServer() + databroker.RegisterDataBrokerServiceServer(gs, srv) + li, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer li.Close() + + eg, ctx := errgroup.WithContext(context.Background()) + eg.Go(func() error { + return gs.Serve(li) + }) + eg.Go(func() error { + defer gs.Stop() + + cc, err := grpc.DialContext(ctx, li.Addr().String(), grpc.WithInsecure()) + if err != nil { + return err + } + defer cc.Close() + + clearRecords := make(chan struct{}, 10) + updateRecords := make(chan uint64, 10) + + client := databroker.NewDataBrokerServiceClient(cc) + syncer := databroker.NewSyncer("TEST", testSyncerHandler{ + getDataBrokerServiceClient: func() databroker.DataBrokerServiceClient { + return client + }, + clearRecords: func(_ context.Context) { + clearRecords <- struct{}{} + }, + updateRecords: func(_ context.Context, recordVersion uint64, _ []*databroker.Record) { + updateRecords <- recordVersion + }, + }) + go syncer.Run(ctx) + select { + case <-clearRecords: + case <-ctx.Done(): + return ctx.Err() + } + select { + case <-updateRecords: + case <-ctx.Done(): + return ctx.Err() + + } + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + assert.NoError(t, err) + + select { + case <-updateRecords: + case <-ctx.Done(): + return ctx.Err() + + } + return nil + }) + assert.NoError(t, eg.Wait()) +} + +func TestServerInvalidStorage(t *testing.T) { + srv := newServer(&serverConfig{ + storageType: "", + }) + + s := new(session.Session) + s.Id = "1" + any, err := anypb.New(s) + assert.NoError(t, err) + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + _ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type") +} + +func TestServerRedis(t *testing.T) { + testutil.WithTestRedis(false, func(rawURL string) error { + srv := newServer(&serverConfig{ + storageType: "redis", + storageConnectionString: rawURL, + secret: cryptutil.NewKey(), + }) + + s := new(session.Session) + s.Id = "1" + any, err := anypb.New(s) + assert.NoError(t, err) + + _, err = srv.Put(context.Background(), &databroker.PutRequest{ + Record: &databroker.Record{ + Type: any.TypeUrl, + Id: s.Id, + Data: any, + }, + }) + assert.NoError(t, err) + + gs := grpc.NewServer() + databroker.RegisterDataBrokerServiceServer(gs, srv) + li, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + defer li.Close() + + eg, ctx := errgroup.WithContext(context.Background()) + eg.Go(func() error { + return gs.Serve(li) + }) + eg.Go(func() error { + defer gs.Stop() + + cc, err := grpc.DialContext(ctx, li.Addr().String(), grpc.WithInsecure()) + if err != nil { + return err + } + defer cc.Close() + + client := databroker.NewDataBrokerServiceClient(cc) + stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{ + Type: any.TypeUrl, + }) + if err != nil { + return err + } + + for { + res, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } else if err != nil { + return err + } + + assert.NotNil(t, res) + } + + return nil + }) + assert.NoError(t, eg.Wait()) + return nil + }) +}