add test, fix bug

This commit is contained in:
Caleb Doxsey 2024-09-11 14:16:30 -06:00
parent edaf99b800
commit 71e306f76c
3 changed files with 154 additions and 18 deletions

View file

@ -217,7 +217,7 @@ func (c *controller) runUsageReporter(ctx context.Context, client databroker.Dat
return fmt.Errorf("error waiting for bootstrap: %w", err)
}
ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID)
ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID, time.Minute)
return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error {
// start the usage reporter
return ur.Run(ctx, client)

View file

@ -18,7 +18,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/log"
sdk "github.com/pomerium/pomerium/internal/zero/api"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -26,6 +25,11 @@ import (
"github.com/pomerium/pomerium/pkg/zero/cluster"
)
// API is the part of the Zero Cluster API used to report usage.
type API interface {
ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error
}
type usageReporterRecord struct {
userID string
userEmail string
@ -34,8 +38,9 @@ type usageReporterRecord struct {
// A UsageReporter reports usage to the zero api.
type UsageReporter struct {
api *sdk.API
api API
organizationID string
reportInterval time.Duration
mu sync.Mutex
byUserID map[string]usageReporterRecord
@ -43,12 +48,14 @@ type UsageReporter struct {
}
// New creates a new UsageReporter.
func New(api *sdk.API, organizationID string) *UsageReporter {
func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter {
return &UsageReporter{
api: api,
organizationID: organizationID,
byUserID: make(map[string]usageReporterRecord),
updates: set.New[string](0),
reportInterval: reportInterval,
byUserID: make(map[string]usageReporterRecord),
updates: set.New[string](0),
}
}
@ -57,13 +64,13 @@ func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerSe
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
// first initialize the user collection
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client)
if err != nil {
return err
}
// run the continuous sync calls and periodically report usage
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion)
}
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
@ -80,27 +87,34 @@ func (ur *UsageReporter) report(ctx context.Context, records []usageReporterReco
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
}
func (ur *UsageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
_, _, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
func (ur *UsageReporter) runInit(
ctx context.Context,
client databroker.DataBrokerServiceClient,
) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) {
_, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
if err != nil {
return 0, 0, err
return 0, 0, 0, err
}
serverVersion, latestRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
if err != nil {
return 0, 0, err
return 0, 0, 0, err
}
return serverVersion, latestRecordVersion, nil
return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil
}
func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
func (ur *UsageReporter) runSync(
ctx context.Context,
client databroker.DataBrokerServiceClient,
serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64,
) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession)
})
eg.Go(func() error {
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser)
})
eg.Go(func() error {
return ur.runReporter(ctx)
@ -110,7 +124,7 @@ func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrok
func (ur *UsageReporter) runReporter(ctx context.Context) error {
// every minute collect any updates and submit them to the API
timer := time.NewTicker(time.Minute)
timer := time.NewTicker(ur.reportInterval)
defer timer.Stop()
for {

View file

@ -1,14 +1,125 @@
package usagereporter
import (
"context"
"errors"
"testing"
"time"
"github.com/stretchr/testify/assert"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/testutil"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/zero/cluster"
)
type mockAPI struct {
reportUsage func(ctx context.Context, req cluster.ReportUsageRequest) error
}
func (m mockAPI) ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error {
return m.reportUsage(ctx, req)
}
func TestUsageReporter(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
t.Cleanup(clearTimeout)
ctx, cancel := context.WithCancel(ctx)
t.Cleanup(cancel)
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
})
t.Cleanup(func() { cc.Close() })
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
requests := make(chan cluster.ReportUsageRequest, 1)
client := databrokerpb.NewDataBrokerServiceClient(cc)
ur := New(mockAPI{
reportUsage: func(ctx context.Context, req cluster.ReportUsageRequest) error {
select {
case <-ctx.Done():
return ctx.Err()
case requests <- req:
}
return nil
},
}, "bQjwPpxcwJRbvsSMFgbZFkXmxFJ", time.Millisecond*100)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return ur.Run(ctx, client)
})
eg.Go(func() error {
_, err := databrokerpb.Put(ctx, client,
&session.Session{
Id: "S1a",
UserId: "U1",
IssuedAt: timestamppb.New(tm1),
},
&session.Session{
Id: "S1b",
UserId: "U1",
IssuedAt: timestamppb.New(tm1),
})
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case req := <-requests:
assert.Equal(t, cluster.ReportUsageRequest{
Users: []cluster.ReportUsageUser{{
LastSignedInAt: tm1,
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
}},
}, req, "should send a single usage record")
}
_, err = databrokerpb.Put(ctx, client,
&user.User{
Id: "U1",
Email: "u1@example.com",
})
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case req := <-requests:
assert.Equal(t, cluster.ReportUsageRequest{
Users: []cluster.ReportUsageUser{{
LastSignedInAt: tm1,
PseudonymousEmail: "iq8/fj+uZaKitkWY12JIQgKJ5KIP+E0Cmy/HpxpdBXY=",
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
}},
}, req, "should send another usage record with the email set")
}
cancel()
return nil
})
err := eg.Wait()
if err != nil && !errors.Is(ctx.Err(), context.Canceled) {
assert.NoError(t, err)
}
}
func Test_coalesce(t *testing.T) {
t.Parallel()
@ -40,3 +151,14 @@ func Test_convertUsageReporterRecords(t *testing.T) {
lastSignedInAt: tm1,
}}), "should leave empty email")
}
func Test_latest(t *testing.T) {
t.Parallel()
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
tm2 := time.Date(2024, time.September, 12, 11, 56, 0, 0, time.UTC)
assert.Equal(t, tm2, latest(tm1, tm2))
assert.Equal(t, tm2, latest(tm2, tm1), "should ignore ordering")
assert.Equal(t, tm1, latest(tm1, time.Time{}), "should handle zero time")
}