diff --git a/cache/databroker_test.go b/cache/databroker_test.go new file mode 100644 index 000000000..6fef9ad23 --- /dev/null +++ b/cache/databroker_test.go @@ -0,0 +1,124 @@ +package cache + +import ( + "context" + "net" + "strconv" + "testing" + + "github.com/golang/protobuf/ptypes" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/test/bufconn" + + internal_databroker "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +const bufSize = 1024 * 1024 + +var lis *bufconn.Listener + +func init() { + lis = bufconn.Listen(bufSize) + s := grpc.NewServer() + internalSrv := internal_databroker.New() + srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv} + databroker.RegisterDataBrokerServiceServer(s, srv) + + go func() { + if err := s.Serve(lis); err != nil { + log.Fatal().Err(err).Msg("Server exited with error") + } + }() +} + +func bufDialer(context.Context, string) (net.Conn, error) { + return lis.Dial() +} + +func TestServerSync(t *testing.T) { + ctx := context.Background() + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + require.NoError(t, err) + defer conn.Close() + c := databroker.NewDataBrokerServiceClient(conn) + any, _ := ptypes.MarshalAny(new(user.User)) + numRecords := 200 + + for i := 0; i < numRecords; i++ { + c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any}) + } + + t.Run("Sync ok", func(t *testing.T) { + client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()}) + count := 0 + for { + res, err := client.Recv() + if err != nil { + break + } + count += len(res.Records) + if count == numRecords { + break + } + } + }) + t.Run("Error occurred while syncing", func(t *testing.T) { + ctx, cancelFunc := context.WithCancel(ctx) + defer cancelFunc() + + client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()}) + count := 0 + numRecordsWanted := 100 + cancelFuncCalled := false + for { + res, err := client.Recv() + if err != nil { + assert.True(t, cancelFuncCalled) + break + } + count += len(res.Records) + if count == numRecordsWanted { + cancelFunc() + cancelFuncCalled = true + } + } + }) +} + +func BenchmarkSync(b *testing.B) { + ctx := context.Background() + conn, err := grpc.DialContext(ctx, "bufnet", grpc.WithContextDialer(bufDialer), grpc.WithInsecure()) + if err != nil { + b.Fatalf("Failed to dial bufnet: %v", err) + } + defer conn.Close() + c := databroker.NewDataBrokerServiceClient(conn) + any, _ := ptypes.MarshalAny(new(session.Session)) + numRecords := 10000 + + for i := 0; i < numRecords; i++ { + c.Set(ctx, &databroker.SetRequest{Type: any.TypeUrl, Id: strconv.Itoa(i), Data: any}) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + client, _ := c.Sync(ctx, &databroker.SyncRequest{Type: any.GetTypeUrl()}) + count := 0 + for { + res, err := client.Recv() + if err != nil { + break + } + count += len(res.Records) + if count == numRecords { + break + } + } + } +} diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 8fee509f5..8e41d653e 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -30,6 +30,7 @@ import ( const ( recordTypeServerVersion = "server_version" serverVersionKey = "version" + syncBatchSize = 100 ) // Server implements the databroker service using an in memory database. @@ -225,10 +226,19 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage return updated[i].Version < updated[j].Version }) *recordVersion = updated[len(updated)-1].Version - return stream.Send(&databroker.SyncResponse{ - ServerVersion: srv.version, - Records: updated, - }) + for i := 0; i < len(updated); i += syncBatchSize { + j := i + syncBatchSize + if j > len(updated) { + j = len(updated) + } + if err := stream.Send(&databroker.SyncResponse{ + ServerVersion: srv.version, + Records: updated[i:j], + }); err != nil { + return err + } + } + return nil } // Sync streams updates for the given record type.