refactor to use sync

This commit is contained in:
Caleb Doxsey 2024-08-21 11:57:05 -06:00
parent 08e9e826ce
commit a9d8deed49

View file

@ -2,14 +2,22 @@ package controller
import ( import (
"context" "context"
"errors"
"fmt"
"io"
"sync" "sync"
"time" "time"
"github.com/cenkalti/backoff/v4"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
sdk "github.com/pomerium/pomerium/internal/zero/api" sdk "github.com/pomerium/pomerium/internal/zero/api"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/protoutil"
"github.com/pomerium/pomerium/pkg/zero/cluster" "github.com/pomerium/pomerium/pkg/zero/cluster"
) )
@ -25,12 +33,14 @@ type usageReporter struct {
mu sync.Mutex mu sync.Mutex
byUserID map[string]usageReporterRecord byUserID map[string]usageReporterRecord
updates map[string]struct{}
} }
func newUsageReporter(api *sdk.API) *usageReporter { func newUsageReporter(api *sdk.API) *usageReporter {
return &usageReporter{ return &usageReporter{
api: api, api: api,
byUserID: make(map[string]usageReporterRecord), byUserID: make(map[string]usageReporterRecord),
updates: make(map[string]struct{}),
} }
} }
@ -44,25 +54,76 @@ func (ur *usageReporter) report(ctx context.Context, records []usageReporterReco
Id: record.userID, Id: record.userID,
}) })
} }
log.Info(ctx).Int("users", len(req.Users)).Msg("reporting usage")
// if there were no updates there's nothing to do
if len(req.Users) == 0 {
return nil
}
return ur.api.ReportUsage(ctx, req) return ur.api.ReportUsage(ctx, req)
} }
func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerServiceClient) error { func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerServiceClient) error {
timer := time.NewTicker(time.Hour) // first initialize the user collection
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
if err != nil {
return err
}
// run the continuous sync calls and periodically report usage
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
}
func (ur *usageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
_, _, err = syncLatestRecords(ctx, client, ur.onUpdateSession)
if err != nil {
return 0, 0, err
}
serverVersion, latestRecordVersion, err = syncLatestRecords(ctx, client, ur.onUpdateUser)
if err != nil {
return 0, 0, err
}
return serverVersion, latestRecordVersion, nil
}
func (ur *usageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
return syncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
})
eg.Go(func() error {
return syncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
})
eg.Go(func() error {
return ur.runReporter(ctx)
})
return eg.Wait()
}
func (ur *usageReporter) runReporter(ctx context.Context) error {
// every minute collect any updates and submit them to the API
timer := time.NewTicker(time.Minute)
defer timer.Stop() defer timer.Stop()
for { for {
err := ur.runOnce(ctx, client) // collect the updated records since last run
if err != nil { ur.mu.Lock()
log.Error(ctx).Err(err).Msg("failed to report usage") records := make([]usageReporterRecord, 0, len(ur.updates))
for userID := range ur.updates {
records = append(records, ur.byUserID[userID])
}
clear(ur.updates)
ur.mu.Unlock()
// report the records with a backoff in case the API is temporarily unavailable
if len(records) > 0 {
log.Info(ctx).Int("updated-users", len(records)).Msg("reporting usage")
err := backoff.Retry(func() error {
err := ur.report(ctx, records)
if err != nil {
log.Error(ctx).Err(err).Msg("error reporting usage")
}
return err
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
if err != nil {
return err
}
} }
select { select {
@ -73,85 +134,45 @@ func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerSe
} }
} }
func (ur *usageReporter) runOnce(ctx context.Context, client databroker.DataBrokerServiceClient) error { func (ur *usageReporter) onUpdateSession(s *session.Session) {
updated, err := ur.update(ctx, client) userID := s.GetUserId()
if err != nil { if userID == "" {
return err // ignore sessions without a user id
return
} }
err = ur.report(ctx, updated)
if err != nil {
return err
}
return nil
}
func (ur *usageReporter) update(ctx context.Context, client databroker.DataBrokerServiceClient) ([]usageReporterRecord, error) {
updatedUserIDs := map[string]struct{}{}
ur.mu.Lock() ur.mu.Lock()
defer ur.mu.Unlock() defer ur.mu.Unlock()
// delete old records r := ur.byUserID[userID]
now := time.Now() nr := r
for userID, r := range ur.byUserID { nr.accessedAt = latest(nr.accessedAt, s.GetIssuedAt().AsTime())
if r.accessedAt.Add(24 * time.Hour).Before(now) { nr.userID = userID
delete(ur.byUserID, userID) if nr != r {
} ur.byUserID[userID] = nr
ur.updates[userID] = struct{}{}
}
}
func (ur *usageReporter) onUpdateUser(u *user.User) {
userID := u.GetId()
if userID == "" {
// ignore users without a user id
return
} }
// create records for all the sessions ur.mu.Lock()
for s, err := range databroker.IterateAll[session.Session](ctx, client) { defer ur.mu.Unlock()
if err != nil {
return nil, err
}
userID := s.Object.GetUserId() r := ur.byUserID[userID]
if userID == "" { nr := r
continue nr.userID = userID
} nr.userDisplayName = u.GetName()
nr.userEmail = u.GetEmail()
r := ur.byUserID[userID] if nr != r {
nr := r ur.byUserID[userID] = nr
nr.accessedAt = latest(nr.accessedAt, s.Object.GetIssuedAt().AsTime()) ur.updates[userID] = struct{}{}
nr.userID = userID
if r != nr {
updatedUserIDs[userID] = struct{}{}
ur.byUserID[userID] = nr
}
} }
// fill in user names and emails
for u, err := range databroker.IterateAll[user.User](ctx, client) {
if err != nil {
return nil, err
}
userID := u.GetId()
if userID == "" {
continue
}
r, ok := ur.byUserID[userID]
if !ok {
// ignore sessionless users
continue
}
nr := r
nr.userDisplayName = u.Object.GetName()
nr.userEmail = u.Object.GetEmail()
if r != nr {
updatedUserIDs[userID] = struct{}{}
ur.byUserID[userID] = nr
}
}
var updated []usageReporterRecord
for key := range updatedUserIDs {
updated = append(updated, ur.byUserID[key])
}
return updated, nil
} }
func latest(t1, t2 time.Time) time.Time { func latest(t1, t2 time.Time) time.Time {
@ -160,3 +181,98 @@ func latest(t1, t2 time.Time) time.Time {
} }
return t1 return t1
} }
func syncRecords[T any, TMessage interface {
*T
proto.Message
}](
ctx context.Context,
client databroker.DataBrokerServiceClient,
serverVersion, latestRecordVersion uint64,
fn func(TMessage),
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var msg TMessage = new(T)
stream, err := client.Sync(ctx, &databroker.SyncRequest{
Type: protoutil.GetTypeURL(msg),
ServerVersion: serverVersion,
RecordVersion: latestRecordVersion,
})
if err != nil {
return fmt.Errorf("error syncing %T: %w", msg, err)
}
for {
res, err := stream.Recv()
switch {
case errors.Is(err, io.EOF):
return nil
case err != nil:
return fmt.Errorf("error receiving record for %T: %w", msg, err)
}
msg = new(T)
err = res.GetRecord().GetData().UnmarshalTo(msg)
if err != nil {
log.Error(ctx).Err(err).
Str("record-type", res.Record.Type).
Str("record-id", res.Record.GetId()).
Msgf("unexpected data in %T stream", msg)
continue
}
fn(msg)
}
}
func syncLatestRecords[T any, TMessage interface {
*T
proto.Message
}](
ctx context.Context,
client databroker.DataBrokerServiceClient,
fn func(TMessage),
) (serverVersion, latestRecordVersion uint64, err error) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var msg TMessage = new(T)
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{
Type: protoutil.GetTypeURL(msg),
})
if err != nil {
return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err)
}
for {
res, err := stream.Recv()
switch {
case errors.Is(err, io.EOF):
return serverVersion, latestRecordVersion, nil
case err != nil:
return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err)
}
switch res := res.GetResponse().(type) {
case *databroker.SyncLatestResponse_Versions:
serverVersion = res.Versions.GetServerVersion()
latestRecordVersion = res.Versions.GetLatestRecordVersion()
case *databroker.SyncLatestResponse_Record:
msg = new(T)
err = res.Record.GetData().UnmarshalTo(msg)
if err != nil {
log.Error(ctx).Err(err).
Str("record-type", res.Record.Type).
Str("record-id", res.Record.GetId()).
Msgf("unexpected data in latest %T stream", msg)
continue
}
fn(msg)
default:
panic(fmt.Sprintf("unexpected response: %T", res))
}
}
}