mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 02:09:15 +02:00
refactor to use sync
This commit is contained in:
parent
08e9e826ce
commit
a9d8deed49
1 changed files with 198 additions and 82 deletions
|
@ -2,14 +2,22 @@ package controller
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
sdk "github.com/pomerium/pomerium/internal/zero/api"
|
sdk "github.com/pomerium/pomerium/internal/zero/api"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -25,12 +33,14 @@ type usageReporter struct {
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
byUserID map[string]usageReporterRecord
|
byUserID map[string]usageReporterRecord
|
||||||
|
updates map[string]struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUsageReporter(api *sdk.API) *usageReporter {
|
func newUsageReporter(api *sdk.API) *usageReporter {
|
||||||
return &usageReporter{
|
return &usageReporter{
|
||||||
api: api,
|
api: api,
|
||||||
byUserID: make(map[string]usageReporterRecord),
|
byUserID: make(map[string]usageReporterRecord),
|
||||||
|
updates: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -44,25 +54,76 @@ func (ur *usageReporter) report(ctx context.Context, records []usageReporterReco
|
||||||
Id: record.userID,
|
Id: record.userID,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info(ctx).Int("users", len(req.Users)).Msg("reporting usage")
|
|
||||||
|
|
||||||
// if there were no updates there's nothing to do
|
|
||||||
if len(req.Users) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return ur.api.ReportUsage(ctx, req)
|
return ur.api.ReportUsage(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
||||||
timer := time.NewTicker(time.Hour)
|
// first initialize the user collection
|
||||||
|
serverVersion, latestRecordVersion, err := ur.runInit(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// run the continuous sync calls and periodically report usage
|
||||||
|
return ur.runSync(ctx, client, serverVersion, latestRecordVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *usageReporter) runInit(ctx context.Context, client databroker.DataBrokerServiceClient) (serverVersion, latestRecordVersion uint64, err error) {
|
||||||
|
_, _, err = syncLatestRecords(ctx, client, ur.onUpdateSession)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverVersion, latestRecordVersion, err = syncLatestRecords(ctx, client, ur.onUpdateUser)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverVersion, latestRecordVersion, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *usageReporter) runSync(ctx context.Context, client databroker.DataBrokerServiceClient, serverVersion, latestRecordVersion uint64) error {
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return syncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateSession)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
return syncRecords(ctx, client, serverVersion, latestRecordVersion, ur.onUpdateUser)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
return ur.runReporter(ctx)
|
||||||
|
})
|
||||||
|
return eg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *usageReporter) runReporter(ctx context.Context) error {
|
||||||
|
// every minute collect any updates and submit them to the API
|
||||||
|
timer := time.NewTicker(time.Minute)
|
||||||
defer timer.Stop()
|
defer timer.Stop()
|
||||||
|
|
||||||
for {
|
for {
|
||||||
err := ur.runOnce(ctx, client)
|
// collect the updated records since last run
|
||||||
if err != nil {
|
ur.mu.Lock()
|
||||||
log.Error(ctx).Err(err).Msg("failed to report usage")
|
records := make([]usageReporterRecord, 0, len(ur.updates))
|
||||||
|
for userID := range ur.updates {
|
||||||
|
records = append(records, ur.byUserID[userID])
|
||||||
|
}
|
||||||
|
clear(ur.updates)
|
||||||
|
ur.mu.Unlock()
|
||||||
|
|
||||||
|
// report the records with a backoff in case the API is temporarily unavailable
|
||||||
|
if len(records) > 0 {
|
||||||
|
log.Info(ctx).Int("updated-users", len(records)).Msg("reporting usage")
|
||||||
|
err := backoff.Retry(func() error {
|
||||||
|
err := ur.report(ctx, records)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("error reporting usage")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -73,85 +134,45 @@ func (ur *usageReporter) run(ctx context.Context, client databroker.DataBrokerSe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ur *usageReporter) runOnce(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
func (ur *usageReporter) onUpdateSession(s *session.Session) {
|
||||||
updated, err := ur.update(ctx, client)
|
userID := s.GetUserId()
|
||||||
if err != nil {
|
if userID == "" {
|
||||||
return err
|
// ignore sessions without a user id
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = ur.report(ctx, updated)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ur *usageReporter) update(ctx context.Context, client databroker.DataBrokerServiceClient) ([]usageReporterRecord, error) {
|
|
||||||
updatedUserIDs := map[string]struct{}{}
|
|
||||||
|
|
||||||
ur.mu.Lock()
|
ur.mu.Lock()
|
||||||
defer ur.mu.Unlock()
|
defer ur.mu.Unlock()
|
||||||
|
|
||||||
// delete old records
|
r := ur.byUserID[userID]
|
||||||
now := time.Now()
|
nr := r
|
||||||
for userID, r := range ur.byUserID {
|
nr.accessedAt = latest(nr.accessedAt, s.GetIssuedAt().AsTime())
|
||||||
if r.accessedAt.Add(24 * time.Hour).Before(now) {
|
nr.userID = userID
|
||||||
delete(ur.byUserID, userID)
|
if nr != r {
|
||||||
}
|
ur.byUserID[userID] = nr
|
||||||
|
ur.updates[userID] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *usageReporter) onUpdateUser(u *user.User) {
|
||||||
|
userID := u.GetId()
|
||||||
|
if userID == "" {
|
||||||
|
// ignore users without a user id
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// create records for all the sessions
|
ur.mu.Lock()
|
||||||
for s, err := range databroker.IterateAll[session.Session](ctx, client) {
|
defer ur.mu.Unlock()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := s.Object.GetUserId()
|
r := ur.byUserID[userID]
|
||||||
if userID == "" {
|
nr := r
|
||||||
continue
|
nr.userID = userID
|
||||||
}
|
nr.userDisplayName = u.GetName()
|
||||||
|
nr.userEmail = u.GetEmail()
|
||||||
r := ur.byUserID[userID]
|
if nr != r {
|
||||||
nr := r
|
ur.byUserID[userID] = nr
|
||||||
nr.accessedAt = latest(nr.accessedAt, s.Object.GetIssuedAt().AsTime())
|
ur.updates[userID] = struct{}{}
|
||||||
nr.userID = userID
|
|
||||||
if r != nr {
|
|
||||||
updatedUserIDs[userID] = struct{}{}
|
|
||||||
ur.byUserID[userID] = nr
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// fill in user names and emails
|
|
||||||
for u, err := range databroker.IterateAll[user.User](ctx, client) {
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
userID := u.GetId()
|
|
||||||
if userID == "" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
r, ok := ur.byUserID[userID]
|
|
||||||
if !ok {
|
|
||||||
// ignore sessionless users
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
nr := r
|
|
||||||
nr.userDisplayName = u.Object.GetName()
|
|
||||||
nr.userEmail = u.Object.GetEmail()
|
|
||||||
if r != nr {
|
|
||||||
updatedUserIDs[userID] = struct{}{}
|
|
||||||
ur.byUserID[userID] = nr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var updated []usageReporterRecord
|
|
||||||
for key := range updatedUserIDs {
|
|
||||||
updated = append(updated, ur.byUserID[key])
|
|
||||||
}
|
|
||||||
return updated, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func latest(t1, t2 time.Time) time.Time {
|
func latest(t1, t2 time.Time) time.Time {
|
||||||
|
@ -160,3 +181,98 @@ func latest(t1, t2 time.Time) time.Time {
|
||||||
}
|
}
|
||||||
return t1
|
return t1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func syncRecords[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
serverVersion, latestRecordVersion uint64,
|
||||||
|
fn func(TMessage),
|
||||||
|
) error {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var msg TMessage = new(T)
|
||||||
|
stream, err := client.Sync(ctx, &databroker.SyncRequest{
|
||||||
|
Type: protoutil.GetTypeURL(msg),
|
||||||
|
ServerVersion: serverVersion,
|
||||||
|
RecordVersion: latestRecordVersion,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error syncing %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
return nil
|
||||||
|
case err != nil:
|
||||||
|
return fmt.Errorf("error receiving record for %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = new(T)
|
||||||
|
err = res.GetRecord().GetData().UnmarshalTo(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).
|
||||||
|
Str("record-type", res.Record.Type).
|
||||||
|
Str("record-id", res.Record.GetId()).
|
||||||
|
Msgf("unexpected data in %T stream", msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func syncLatestRecords[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
fn func(TMessage),
|
||||||
|
) (serverVersion, latestRecordVersion uint64, err error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var msg TMessage = new(T)
|
||||||
|
stream, err := client.SyncLatest(ctx, &databroker.SyncLatestRequest{
|
||||||
|
Type: protoutil.GetTypeURL(msg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
return serverVersion, latestRecordVersion, nil
|
||||||
|
case err != nil:
|
||||||
|
return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch res := res.GetResponse().(type) {
|
||||||
|
case *databroker.SyncLatestResponse_Versions:
|
||||||
|
serverVersion = res.Versions.GetServerVersion()
|
||||||
|
latestRecordVersion = res.Versions.GetLatestRecordVersion()
|
||||||
|
case *databroker.SyncLatestResponse_Record:
|
||||||
|
msg = new(T)
|
||||||
|
err = res.Record.GetData().UnmarshalTo(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error(ctx).Err(err).
|
||||||
|
Str("record-type", res.Record.Type).
|
||||||
|
Str("record-id", res.Record.GetId()).
|
||||||
|
Msgf("unexpected data in latest %T stream", msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(msg)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unexpected response: %T", res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue