mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-03 00:40:25 +02:00
add test, fix bug
This commit is contained in:
parent
edaf99b800
commit
71e306f76c
3 changed files with 154 additions and 18 deletions
|
@ -18,7 +18,6 @@ import (
|
|||
"golang.org/x/sync/errgroup"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
sdk "github.com/pomerium/pomerium/internal/zero/api"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
|
@ -26,6 +25,11 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||
)
|
||||
|
||||
// API is the part of the Zero Cluster API used to report usage.
|
||||
type API interface {
|
||||
ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error
|
||||
}
|
||||
|
||||
type usageReporterRecord struct {
|
||||
userID string
|
||||
userEmail string
|
||||
|
@ -34,8 +38,9 @@ type usageReporterRecord struct {
|
|||
|
||||
// A UsageReporter reports usage to the zero api.
|
||||
type UsageReporter struct {
|
||||
api *sdk.API
|
||||
api API
|
||||
organizationID string
|
||||
reportInterval time.Duration
|
||||
|
||||
mu sync.Mutex
|
||||
byUserID map[string]usageReporterRecord
|
||||
|
@ -43,12 +48,14 @@ type UsageReporter struct {
|
|||
}
|
||||
|
||||
// New creates a new UsageReporter.
|
||||
func New(api *sdk.API, organizationID string) *UsageReporter {
|
||||
func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter {
|
||||
return &UsageReporter{
|
||||
api: api,
|
||||
organizationID: organizationID,
|
||||
byUserID: make(map[string]usageReporterRecord),
|
||||
updates: set.New[string](0),
|
||||
reportInterval: reportInterval,
|
||||
|
||||
byUserID: make(map[string]usageReporterRecord),
|
||||
updates: set.New[string](0),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -57,13 +64,13 @@ func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerSe
|
|||
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
|
||||
|
||||
// first initialize the user collection
|
||||
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
|
||||
serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// run the continuous sync calls and periodically report usage
|
||||
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
|
||||
return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion)
|
||||
}
|
||||
|
||||
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
|
||||
|
@ -80,27 +87,34 @@ func (ur *UsageReporter) report(ctx context.Context, records []usageReporterReco
|
|||
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
||||
}
|
||||
|
||||
func (ur *UsageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
|
||||
_, _, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
|
||||
func (ur *UsageReporter) runInit(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) {
|
||||
_, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
|
||||
serverVersion, latestRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
|
||||
serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
return 0, 0, 0, err
|
||||
}
|
||||
|
||||
return serverVersion, latestRecordVersion, nil
|
||||
return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil
|
||||
}
|
||||
|
||||
func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
|
||||
func (ur *UsageReporter) runSync(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64,
|
||||
) error {
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
|
||||
return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
|
||||
return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return ur.runReporter(ctx)
|
||||
|
@ -110,7 +124,7 @@ func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrok
|
|||
|
||||
func (ur *UsageReporter) runReporter(ctx context.Context) error {
|
||||
// every minute collect any updates and submit them to the API
|
||||
timer := time.NewTicker(time.Minute)
|
||||
timer := time.NewTicker(ur.reportInterval)
|
||||
defer timer.Stop()
|
||||
|
||||
for {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue