mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-25 14:08:09 +02:00
add test, fix bug
This commit is contained in:
parent
edaf99b800
commit
71e306f76c
3 changed files with 154 additions and 18 deletions
|
@ -217,7 +217,7 @@ func (c *controller) runUsageReporter(ctx context.Context, client databroker.Dat
|
||||||
return fmt.Errorf("error waiting for bootstrap: %w", err)
|
return fmt.Errorf("error waiting for bootstrap: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID)
|
ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID, time.Minute)
|
||||||
return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error {
|
return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error {
|
||||||
// start the usage reporter
|
// start the usage reporter
|
||||||
return ur.Run(ctx, client)
|
return ur.Run(ctx, client)
|
||||||
|
|
|
@ -18,7 +18,6 @@ import (
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
sdk "github.com/pomerium/pomerium/internal/zero/api"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
@ -26,6 +25,11 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// API is the part of the Zero Cluster API used to report usage.
|
||||||
|
type API interface {
|
||||||
|
ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
type usageReporterRecord struct {
|
type usageReporterRecord struct {
|
||||||
userID string
|
userID string
|
||||||
userEmail string
|
userEmail string
|
||||||
|
@ -34,8 +38,9 @@ type usageReporterRecord struct {
|
||||||
|
|
||||||
// A UsageReporter reports usage to the zero api.
|
// A UsageReporter reports usage to the zero api.
|
||||||
type UsageReporter struct {
|
type UsageReporter struct {
|
||||||
api *sdk.API
|
api API
|
||||||
organizationID string
|
organizationID string
|
||||||
|
reportInterval time.Duration
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
byUserID map[string]usageReporterRecord
|
byUserID map[string]usageReporterRecord
|
||||||
|
@ -43,12 +48,14 @@ type UsageReporter struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new UsageReporter.
|
// New creates a new UsageReporter.
|
||||||
func New(api *sdk.API, organizationID string) *UsageReporter {
|
func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter {
|
||||||
return &UsageReporter{
|
return &UsageReporter{
|
||||||
api: api,
|
api: api,
|
||||||
organizationID: organizationID,
|
organizationID: organizationID,
|
||||||
byUserID: make(map[string]usageReporterRecord),
|
reportInterval: reportInterval,
|
||||||
updates: set.New[string](0),
|
|
||||||
|
byUserID: make(map[string]usageReporterRecord),
|
||||||
|
updates: set.New[string](0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -57,13 +64,13 @@ func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerSe
|
||||||
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
|
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
|
||||||
|
|
||||||
// first initialize the user collection
|
// first initialize the user collection
|
||||||
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
|
serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// run the continuous sync calls and periodically report usage
|
// run the continuous sync calls and periodically report usage
|
||||||
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
|
return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
|
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
|
||||||
|
@ -80,27 +87,34 @@ func (ur *UsageReporter) report(ctx context.Context, records []usageReporterReco
|
||||||
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UsageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
|
func (ur *UsageReporter) runInit(
|
||||||
_, _, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) {
|
||||||
|
_, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
serverVersion, latestRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
|
serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, 0, err
|
return 0, 0, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return serverVersion, latestRecordVersion, nil
|
return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
|
func (ur *UsageReporter) runSync(
|
||||||
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64,
|
||||||
|
) error {
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
|
return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession)
|
||||||
})
|
})
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return databroker.SyncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
|
return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser)
|
||||||
})
|
})
|
||||||
eg.Go(func() error {
|
eg.Go(func() error {
|
||||||
return ur.runReporter(ctx)
|
return ur.runReporter(ctx)
|
||||||
|
@ -110,7 +124,7 @@ func (ur *UsageReporter) runSync(ctx context.Context, client databroker.DataBrok
|
||||||
|
|
||||||
func (ur *UsageReporter) runReporter(ctx context.Context) error {
|
func (ur *UsageReporter) runReporter(ctx context.Context) error {
|
||||||
// every minute collect any updates and submit them to the API
|
// every minute collect any updates and submit them to the API
|
||||||
timer := time.NewTicker(time.Minute)
|
timer := time.NewTicker(ur.reportInterval)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
|
|
|
@ -1,14 +1,125 @@
|
||||||
package usagereporter
|
package usagereporter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type mockAPI struct {
|
||||||
|
reportUsage func(ctx context.Context, req cluster.ReportUsageRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockAPI) ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error {
|
||||||
|
return m.reportUsage(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageReporter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
|
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
requests := make(chan cluster.ReportUsageRequest, 1)
|
||||||
|
|
||||||
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
ur := New(mockAPI{
|
||||||
|
reportUsage: func(ctx context.Context, req cluster.ReportUsageRequest) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case requests <- req:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, "bQjwPpxcwJRbvsSMFgbZFkXmxFJ", time.Millisecond*100)
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return ur.Run(ctx, client)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
_, err := databrokerpb.Put(ctx, client,
|
||||||
|
&session.Session{
|
||||||
|
Id: "S1a",
|
||||||
|
UserId: "U1",
|
||||||
|
IssuedAt: timestamppb.New(tm1),
|
||||||
|
},
|
||||||
|
&session.Session{
|
||||||
|
Id: "S1b",
|
||||||
|
UserId: "U1",
|
||||||
|
IssuedAt: timestamppb.New(tm1),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case req := <-requests:
|
||||||
|
assert.Equal(t, cluster.ReportUsageRequest{
|
||||||
|
Users: []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
|
||||||
|
}},
|
||||||
|
}, req, "should send a single usage record")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = databrokerpb.Put(ctx, client,
|
||||||
|
&user.User{
|
||||||
|
Id: "U1",
|
||||||
|
Email: "u1@example.com",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case req := <-requests:
|
||||||
|
assert.Equal(t, cluster.ReportUsageRequest{
|
||||||
|
Users: []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousEmail: "iq8/fj+uZaKitkWY12JIQgKJ5KIP+E0Cmy/HpxpdBXY=",
|
||||||
|
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
|
||||||
|
}},
|
||||||
|
}, req, "should send another usage record with the email set")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err := eg.Wait()
|
||||||
|
if err != nil && !errors.Is(ctx.Err(), context.Canceled) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func Test_coalesce(t *testing.T) {
|
func Test_coalesce(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
|
@ -40,3 +151,14 @@ func Test_convertUsageReporterRecords(t *testing.T) {
|
||||||
lastSignedInAt: tm1,
|
lastSignedInAt: tm1,
|
||||||
}}), "should leave empty email")
|
}}), "should leave empty email")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_latest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
|
||||||
|
tm2 := time.Date(2024, time.September, 12, 11, 56, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
assert.Equal(t, tm2, latest(tm1, tm2))
|
||||||
|
assert.Equal(t, tm2, latest(tm2, tm1), "should ignore ordering")
|
||||||
|
assert.Equal(t, tm1, latest(tm1, time.Time{}), "should handle zero time")
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue