add test, fix bug

This commit is contained in:
Caleb Doxsey 2024-09-11 14:16:30 -06:00
parent edaf99b800
commit 71e306f76c
3 changed files with 154 additions and 18 deletions

View file

@ -18,7 +18,6 @@ import (
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/internal/log"
sdk "github.com/pomerium/pomerium/internal/zero/api"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
@ -26,6 +25,11 @@ import (
"github.com/pomerium/pomerium/pkg/zero/cluster"
)
// API is the part of the Zero Cluster API used to report usage.
type API interface {
ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error
}
type usageReporterRecord struct {
userID string
userEmail string
@ -34,8 +38,9 @@ type usageReporterRecord struct {
// A UsageReporter reports usage to the zero api.
type UsageReporter struct {
api *sdk.API
api API
organizationID string
reportInterval time.Duration
mu sync.Mutex
byUserID map[string]usageReporterRecord
@ -43,12 +48,14 @@ type UsageReporter struct {
}
// New creates a new UsageReporter.
func New(api *sdk.API, organizationID string) *UsageReporter {
func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter {
return &UsageReporter{
api: api,
organizationID: organizationID,
byUserID: make(map[string]usageReporterRecord),
updates: set.New[string](0),
reportInterval: reportInterval,
byUserID: make(map[string]usageReporterRecord),
updates: set.New[string](0),
}
}
@ -57,13 +64,13 @@ func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerSe
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
// first initialize the user collection
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client)
if err != nil {
return err
}
// run the continuous sync calls and periodically report usage
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion)
}
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
@ -80,27 +87,34 @@ func (ur *UsageReporter) report(ctx context.Context, records []usageReporterReco
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
}
func (ur *UsageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
_, _, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
func (ur *UsageReporter) runInit(
ctx context.Context,
client databroker.DataBrokerServiceClient,
) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) {
_, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
if err != nil {
return 0, 0, err
return 0, 0, 0, err
}
serverVersion, latestRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
if err != nil {
return 0, 0, err
return 0, 0, 0, err
}
return serverVersion, latestRecordVersion, nil
return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil
}
func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
func (ur *UsageReporter) runSync(
ctx context.Context,
client databroker.DataBrokerServiceClient,
serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64,
) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession)
})
eg.Go(func() error {
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser)
})
eg.Go(func() error {
return ur.runReporter(ctx)
@ -110,7 +124,7 @@ func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrok
func (ur *UsageReporter) runReporter(ctx context.Context) error {
// every minute collect any updates and submit them to the API
timer := time.NewTicker(time.Minute)
timer := time.NewTicker(ur.reportInterval)
defer timer.Stop()
for {