From 71e306f76cdbd3afabfd2b8fda735f44ce12e2c8 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Wed, 11 Sep 2024 14:16:30 -0600 Subject: [PATCH] add test, fix bug --- internal/zero/controller/controller.go | 2 +- .../controller/usagereporter/usagereporter.go | 48 ++++--- .../usagereporter/usagereporter_test.go | 122 ++++++++++++++++++ 3 files changed, 154 insertions(+), 18 deletions(-) diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 2c2364409..75cd228f0 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -217,7 +217,7 @@ func (c *controller) runUsageReporter(ctx context.Context, client databroker.Dat return fmt.Errorf("error waiting for bootstrap: %w", err) } - ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID) + ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID, time.Minute) return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error { // start the usage reporter return ur.Run(ctx, client) diff --git a/internal/zero/controller/usagereporter/usagereporter.go b/internal/zero/controller/usagereporter/usagereporter.go index 09729a5f2..288a25bd2 100644 --- a/internal/zero/controller/usagereporter/usagereporter.go +++ b/internal/zero/controller/usagereporter/usagereporter.go @@ -18,7 +18,6 @@ import ( "golang.org/x/sync/errgroup" "github.com/pomerium/pomerium/internal/log" - sdk "github.com/pomerium/pomerium/internal/zero/api" "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" @@ -26,6 +25,11 @@ import ( "github.com/pomerium/pomerium/pkg/zero/cluster" ) +// API is the part of the Zero Cluster API used to report usage. +type API interface { + ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error +} + type usageReporterRecord struct { userID string userEmail string @@ -34,8 +38,9 @@ type usageReporterRecord struct { // A UsageReporter reports usage to the zero api. type UsageReporter struct { - api *sdk.API + api API organizationID string + reportInterval time.Duration mu sync.Mutex byUserID map[string]usageReporterRecord @@ -43,12 +48,14 @@ type UsageReporter struct { } // New creates a new UsageReporter. -func New(api *sdk.API, organizationID string) *UsageReporter { +func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter { return &UsageReporter{ api: api, organizationID: organizationID, - byUserID: make(map[string]usageReporterRecord), - updates: set.New[string](0), + reportInterval: reportInterval, + + byUserID: make(map[string]usageReporterRecord), + updates: set.New[string](0), } } @@ -57,13 +64,13 @@ func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerSe ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx) // first initialize the user collection - serverVersion, latestRecordVersion, err := ur.runInit(ctx, client) + serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client) if err != nil { return err } // run the continuous sync calls and periodically report usage - return ur.runSync(ctx, client, serverVersion, latestRecordVersion) + return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion) } func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error { @@ -80,27 +87,34 @@ func (ur *UsageReporter) report(ctx context.Context, records []usageReporterReco }, backoff.WithContext(backoff.NewExponentialBackOff(), ctx)) } -func (ur *UsageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) { - _, _, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession) +func (ur *UsageReporter) runInit( + ctx context.Context, + client databroker.DataBrokerServiceClient, +) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) { + _, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession) if err != nil { - return 0, 0, err + return 0, 0, 0, err } - serverVersion, latestRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser) + serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser) if err != nil { - return 0, 0, err + return 0, 0, 0, err } - return serverVersion, latestRecordVersion, nil + return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil } -func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error { +func (ur *UsageReporter) runSync( + ctx context.Context, + client databroker.DataBrokerServiceClient, + serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, +) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { - return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession) + return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession) }) eg.Go(func() error { - return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser) + return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser) }) eg.Go(func() error { return ur.runReporter(ctx) @@ -110,7 +124,7 @@ func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrok func (ur *UsageReporter) runReporter(ctx context.Context) error { // every minute collect any updates and submit them to the API - timer := time.NewTicker(time.Minute) + timer := time.NewTicker(ur.reportInterval) defer timer.Stop() for { diff --git a/internal/zero/controller/usagereporter/usagereporter_test.go b/internal/zero/controller/usagereporter/usagereporter_test.go index 49398052a..93140d268 100644 --- a/internal/zero/controller/usagereporter/usagereporter_test.go +++ b/internal/zero/controller/usagereporter/usagereporter_test.go @@ -1,14 +1,125 @@ package usagereporter import ( + "context" + "errors" "testing" "time" "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/zero/cluster" ) +type mockAPI struct { + reportUsage func(ctx context.Context, req cluster.ReportUsageRequest) error +} + +func (m mockAPI) ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error { + return m.reportUsage(ctx, req) +} + +func TestUsageReporter(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + }) + t.Cleanup(func() { cc.Close() }) + + tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC) + + requests := make(chan cluster.ReportUsageRequest, 1) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + ur := New(mockAPI{ + reportUsage: func(ctx context.Context, req cluster.ReportUsageRequest) error { + select { + case <-ctx.Done(): + return ctx.Err() + case requests <- req: + } + return nil + }, + }, "bQjwPpxcwJRbvsSMFgbZFkXmxFJ", time.Millisecond*100) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return ur.Run(ctx, client) + }) + eg.Go(func() error { + _, err := databrokerpb.Put(ctx, client, + &session.Session{ + Id: "S1a", + UserId: "U1", + IssuedAt: timestamppb.New(tm1), + }, + &session.Session{ + Id: "S1b", + UserId: "U1", + IssuedAt: timestamppb.New(tm1), + }) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requests: + assert.Equal(t, cluster.ReportUsageRequest{ + Users: []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=", + }}, + }, req, "should send a single usage record") + } + + _, err = databrokerpb.Put(ctx, client, + &user.User{ + Id: "U1", + Email: "u1@example.com", + }) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requests: + assert.Equal(t, cluster.ReportUsageRequest{ + Users: []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousEmail: "iq8/fj+uZaKitkWY12JIQgKJ5KIP+E0Cmy/HpxpdBXY=", + PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=", + }}, + }, req, "should send another usage record with the email set") + } + + cancel() + return nil + }) + err := eg.Wait() + if err != nil && !errors.Is(ctx.Err(), context.Canceled) { + assert.NoError(t, err) + } +} + func Test_coalesce(t *testing.T) { t.Parallel() @@ -40,3 +151,14 @@ func Test_convertUsageReporterRecords(t *testing.T) { lastSignedInAt: tm1, }}), "should leave empty email") } + +func Test_latest(t *testing.T) { + t.Parallel() + + tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC) + tm2 := time.Date(2024, time.September, 12, 11, 56, 0, 0, time.UTC) + + assert.Equal(t, tm2, latest(tm1, tm2)) + assert.Equal(t, tm2, latest(tm2, tm1), "should ignore ordering") + assert.Equal(t, tm1, latest(tm1, time.Time{}), "should handle zero time") +}