From c4dd965f2d63a7cb4f9bf58757e51966be90dc9d Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Mon, 11 Dec 2023 13:37:01 -0500 Subject: [PATCH] zero/telemetry: calculate DAU and MAU (#4810) --- internal/sets/hash.go | 9 ++ internal/zero/analytics/activeusers.go | 80 ++++++++++++ internal/zero/analytics/activeusers_test.go | 37 ++++++ internal/zero/analytics/collector.go | 131 ++++++++++++++++++++ internal/zero/analytics/sessions.go | 53 ++++++++ internal/zero/analytics/storage.go | 126 +++++++++++++++++++ internal/zero/analytics/storage_test.go | 29 +++++ internal/zero/controller/controller.go | 13 ++ pkg/grpc/databroker/databroker.go | 7 ++ 9 files changed, 485 insertions(+) create mode 100644 internal/zero/analytics/activeusers.go create mode 100644 internal/zero/analytics/activeusers_test.go create mode 100644 internal/zero/analytics/collector.go create mode 100644 internal/zero/analytics/sessions.go create mode 100644 internal/zero/analytics/storage.go create mode 100644 internal/zero/analytics/storage_test.go diff --git a/internal/sets/hash.go b/internal/sets/hash.go index 1529a9c44..e18083852 100644 --- a/internal/sets/hash.go +++ b/internal/sets/hash.go @@ -31,3 +31,12 @@ func (s *Hash[T]) Has(element T) bool { func (s *Hash[T]) Size() int { return len(s.m) } + +// Items returns the set's elements as a slice. +func (s *Hash[T]) Items() []T { + items := make([]T, 0, len(s.m)) + for item := range s.m { + items = append(items, item) + } + return items +} diff --git a/internal/zero/analytics/activeusers.go b/internal/zero/analytics/activeusers.go new file mode 100644 index 000000000..b5dffeacf --- /dev/null +++ b/internal/zero/analytics/activeusers.go @@ -0,0 +1,80 @@ +package analytics + +import ( + "time" + + "github.com/pomerium/pomerium/pkg/counter" +) + +const ( + // activeUsersCap is the number of active users we would be able to track. + // for counter to work within the 1% error limit, actual number should be 80% of the cap. + activeUsersCap = 10_000 +) + +// IntervalResetFunc is a function that determines if a counter should be reset +type IntervalResetFunc func(lastReset time.Time, now time.Time) bool + +// ResetMonthlyUTC resets the counter on a monthly interval +func ResetMonthlyUTC(lastReset time.Time, now time.Time) bool { + lastResetUTC := lastReset.UTC() + nowUTC := now.UTC() + return lastResetUTC.Year() != nowUTC.Year() || + lastResetUTC.Month() != nowUTC.Month() +} + +// ResetDailyUTC resets the counter on a daily interval +func ResetDailyUTC(lastReset time.Time, now time.Time) bool { + lastResetUTC := lastReset.UTC() + nowUTC := now.UTC() + return lastResetUTC.Year() != nowUTC.Year() || + lastResetUTC.Month() != nowUTC.Month() || + lastResetUTC.Day() != nowUTC.Day() +} + +// ActiveUsersCounter is a counter that resets on a given interval +type ActiveUsersCounter struct { + *counter.Counter + lastReset time.Time + needsReset IntervalResetFunc +} + +// NewActiveUsersCounter creates a new active users counter +func NewActiveUsersCounter(needsReset IntervalResetFunc, now time.Time) *ActiveUsersCounter { + return &ActiveUsersCounter{ + Counter: counter.New(activeUsersCap), + lastReset: now, + needsReset: needsReset, + } +} + +// LoadActiveUsersCounter loads an active users counter from a binary state +func LoadActiveUsersCounter(state []byte, lastReset time.Time, resetFn IntervalResetFunc) (*ActiveUsersCounter, error) { + c, err := counter.FromBinary(state) + if err != nil { + return nil, err + } + return &ActiveUsersCounter{ + Counter: c, + lastReset: lastReset, + needsReset: resetFn, + }, nil +} + +// Update updates the counter with the current users +func (c *ActiveUsersCounter) Update(users []string, now time.Time) (current uint, wasReset bool) { + if c.needsReset(c.lastReset, now) { + c.Counter.Reset() + c.lastReset = now + wasReset = true + } + for _, user := range users { + c.Mark(user) + } + return c.Count(), wasReset +} + +// GetLastReset returns the last time the counter was reset +func (c *ActiveUsersCounter) GetLastReset() time.Time { + return c.lastReset +} diff --git a/internal/zero/analytics/activeusers_test.go b/internal/zero/analytics/activeusers_test.go new file mode 100644 index 000000000..226365159 --- /dev/null +++ b/internal/zero/analytics/activeusers_test.go @@ -0,0 +1,37 @@ +package analytics_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/zero/analytics" +) + +func TestActiveUsers(t *testing.T) { + t.Parallel() + + startTime := time.Now().UTC() + + // Create a new counter that resets on a daily interval + c := analytics.NewActiveUsersCounter(analytics.ResetDailyUTC, startTime) + + count, wasReset := c.Update([]string{"user1", "user2"}, startTime.Add(time.Minute)) + assert.False(t, wasReset) + assert.EqualValues(t, 2, count) + + count, wasReset = c.Update([]string{"user1", "user2", "user3"}, startTime.Add(time.Minute*2)) + assert.False(t, wasReset) + assert.EqualValues(t, 3, count) + + // Update the counter with a new user after lapse + count, wasReset = c.Update([]string{"user1", "user2", "user3", "user4"}, startTime.Add(time.Hour*25)) + assert.True(t, wasReset) + assert.EqualValues(t, 4, count) + + // Update the counter with a new user after lapse + count, wasReset = c.Update([]string{"user4"}, startTime.Add(time.Hour*25*2)) + assert.True(t, wasReset) + assert.EqualValues(t, 1, count) +} diff --git a/internal/zero/analytics/collector.go b/internal/zero/analytics/collector.go new file mode 100644 index 000000000..051985a3d --- /dev/null +++ b/internal/zero/analytics/collector.go @@ -0,0 +1,131 @@ +// Package analytics collects active user metrics and reports them to the cloud dashboard +package analytics + +import ( + "context" + "fmt" + "time" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +// Collect collects metrics and reports them to the cloud +func Collect( + ctx context.Context, + client databroker.DataBrokerServiceClient, + updateInterval time.Duration, +) error { + c := &collector{ + client: client, + counters: make(map[string]*ActiveUsersCounter), + updateInterval: updateInterval, + } + + leaser := databroker.NewLeaser("pomerium-zero-analytics", c.leaseTTL(), c) + return leaser.Run(ctx) +} + +type collector struct { + client databroker.DataBrokerServiceClient + counters map[string]*ActiveUsersCounter + updateInterval time.Duration +} + +func (c *collector) RunLeased(ctx context.Context) error { + err := c.loadCounters(ctx) + if err != nil { + return fmt.Errorf("failed to load counters: %w", err) + } + + err = c.runPeriodicUpdate(ctx) + if err != nil { + return fmt.Errorf("failed to run periodic update: %w", err) + } + + return nil +} + +func (c *collector) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient { + return c.client +} + +func (c *collector) loadCounters(ctx context.Context) error { + now := time.Now() + for key, resetFn := range map[string]IntervalResetFunc{ + "mau": ResetMonthlyUTC, + "dau": ResetDailyUTC, + } { + state, err := LoadMetricState(ctx, c.client, key) + if err != nil && !databroker.IsNotFound(err) { + return err + } + if state == nil { + c.counters[key] = NewActiveUsersCounter(resetFn, now) + continue + } + + counter, err := LoadActiveUsersCounter(state.Data, state.LastReset, resetFn) + if err != nil { + log.Ctx(ctx).Error().Err(err).Str("metric", key).Msg("failed to load metric state, resetting") + counter = NewActiveUsersCounter(resetFn, now) + } + c.counters[key] = counter + } + + return nil +} + +func (c *collector) runPeriodicUpdate(ctx context.Context) error { + ticker := time.NewTicker(c.updateInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + if err := c.update(ctx); err != nil { + return err + } + } + } +} + +func (c *collector) update(ctx context.Context) error { + users, err := CurrentUsers(ctx, c.client) + if err != nil { + return fmt.Errorf("failed to get current users: %w", err) + } + + now := time.Now() + for key, counter := range c.counters { + before := counter.Count() + after, _ := counter.Update(users, now) + if before == after { + log.Ctx(ctx).Debug().Msgf("metric %s not changed: %d", key, counter.Count()) + continue + } + log.Ctx(ctx).Debug().Msgf("metric %s updated: %d", key, counter.Count()) + + data, err := counter.ToBinary() + if err != nil { + return fmt.Errorf("failed to marshal metric %s: %w", key, err) + } + + err = SaveMetricState(ctx, c.client, key, data, after, counter.GetLastReset()) + if err != nil { + return fmt.Errorf("failed to save metric %s: %w", key, err) + } + } + + return nil +} + +func (c *collector) leaseTTL() time.Duration { + const defaultTTL = time.Minute * 5 + if defaultTTL < c.updateInterval { + return defaultTTL + } + return c.updateInterval +} diff --git a/internal/zero/analytics/sessions.go b/internal/zero/analytics/sessions.go new file mode 100644 index 000000000..d7bb7ff53 --- /dev/null +++ b/internal/zero/analytics/sessions.go @@ -0,0 +1,53 @@ +package analytics + +import ( + "context" + "fmt" + "time" + + "github.com/pomerium/pomerium/internal/sets" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +var ( + sessionTypeURL = protoutil.GetTypeURL(new(session.Session)) +) + +// CurrentUsers returns a list of users active within the current UTC day +func CurrentUsers( + ctx context.Context, + client databroker.DataBrokerServiceClient, +) ([]string, error) { + records, _, _, err := databroker.InitialSync(ctx, client, &databroker.SyncLatestRequest{ + Type: sessionTypeURL, + }) + if err != nil { + return nil, fmt.Errorf("fetching sessions: %w", err) + } + + users := sets.NewHash[string]() + utcNow := time.Now().UTC() + threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC) + + for _, record := range records { + var s session.Session + err := record.GetData().UnmarshalTo(&s) + if err != nil { + return nil, fmt.Errorf("unmarshaling session: %w", err) + } + if s.UserId == "" { // session creation is in progress + continue + } + if s.AccessedAt == nil { + continue + } + if s.AccessedAt.AsTime().Before(threshold) { + continue + } + users.Add(s.UserId) + } + + return users.Items(), nil +} diff --git a/internal/zero/analytics/storage.go b/internal/zero/analytics/storage.go new file mode 100644 index 000000000..617df6426 --- /dev/null +++ b/internal/zero/analytics/storage.go @@ -0,0 +1,126 @@ +package analytics + +import ( + "context" + "encoding/base64" + "fmt" + "time" + + "google.golang.org/protobuf/types/known/anypb" + "google.golang.org/protobuf/types/known/structpb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +const ( + metricStateTypeURL = "pomerium.io/ActiveUsersMetricState" +) + +// SaveMetricState saves the state of a metric to the databroker +func SaveMetricState( + ctx context.Context, + client databroker.DataBrokerServiceClient, + id string, + data []byte, + value uint, + lastReset time.Time, +) error { + _, err := client.Put(ctx, &databroker.PutRequest{ + Records: []*databroker.Record{{ + Type: metricStateTypeURL, + Id: id, + Data: (&MetricState{ + Data: data, + LastReset: lastReset, + Count: value, + }).ToAny(), + }}, + }) + return err +} + +// LoadMetricState loads the state of a metric from the databroker +func LoadMetricState( + ctx context.Context, client databroker.DataBrokerServiceClient, id string, +) (*MetricState, error) { + resp, err := client.Get(ctx, &databroker.GetRequest{ + Type: metricStateTypeURL, + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("load metric state: %w", err) + } + + var state MetricState + err = state.FromAny(resp.GetRecord().GetData()) + if err != nil { + return nil, fmt.Errorf("load metric state: %w", err) + } + + return &state, nil +} + +// MetricState is the persistent state of a metric +type MetricState struct { + Data []byte + LastReset time.Time + Count uint +} + +const ( + countKey = "count" + dataKey = "data" + lastResetKey = "last_reset" +) + +// ToAny marshals a MetricState into an anypb.Any +func (r *MetricState) ToAny() *anypb.Any { + return protoutil.NewAny(&structpb.Struct{ + Fields: map[string]*structpb.Value{ + countKey: structpb.NewNumberValue(float64(r.Count)), + dataKey: structpb.NewStringValue(base64.StdEncoding.EncodeToString(r.Data)), + lastResetKey: structpb.NewStringValue(r.LastReset.Format(time.RFC3339)), + }, + }) +} + +// FromAny unmarshals an anypb.Any into a MetricState +func (r *MetricState) FromAny(any *anypb.Any) error { + var s structpb.Struct + err := any.UnmarshalTo(&s) + if err != nil { + return fmt.Errorf("unmarshal struct: %w", err) + } + + vData, ok := s.GetFields()[dataKey] + if !ok { + return fmt.Errorf("missing %s field", dataKey) + } + data, err := base64.StdEncoding.DecodeString(vData.GetStringValue()) + if err != nil { + return fmt.Errorf("decode state: %w", err) + } + if len(data) == 0 { + return fmt.Errorf("empty data") + } + + vLastReset, ok := s.GetFields()[lastResetKey] + if !ok { + return fmt.Errorf("missing %s field", lastResetKey) + } + lastReset, err := time.Parse(time.RFC3339, vLastReset.GetStringValue()) + if err != nil { + return fmt.Errorf("parse last reset: %w", err) + } + vCount, ok := s.GetFields()[countKey] + if !ok { + return fmt.Errorf("missing %s field", countKey) + } + + r.Data = data + r.LastReset = lastReset + r.Count = uint(vCount.GetNumberValue()) + + return nil +} diff --git a/internal/zero/analytics/storage_test.go b/internal/zero/analytics/storage_test.go new file mode 100644 index 000000000..53868403f --- /dev/null +++ b/internal/zero/analytics/storage_test.go @@ -0,0 +1,29 @@ +package analytics_test + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/pomerium/pomerium/internal/zero/analytics" +) + +func TestStorage(t *testing.T) { + t.Parallel() + + now := time.Date(2020, 1, 2, 3, 4, 5, 6, time.UTC) + state := &analytics.MetricState{ + Data: []byte("data"), + LastReset: now, + } + + pbany := state.ToAny() + assert.NotNil(t, pbany) + + var newState analytics.MetricState + err := newState.FromAny(pbany) + assert.NoError(t, err) + assert.EqualValues(t, state.Data, newState.Data) + assert.EqualValues(t, state.LastReset.Truncate(time.Second), newState.LastReset.Truncate(time.Second)) +} diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index 8e6501f62..dc4a9fde1 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -5,11 +5,13 @@ import ( "context" "errors" "fmt" + "time" "github.com/rs/zerolog" "golang.org/x/sync/errgroup" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/zero/analytics" "github.com/pomerium/pomerium/internal/zero/bootstrap" "github.com/pomerium/pomerium/internal/zero/reconciler" "github.com/pomerium/pomerium/pkg/cmd/pomerium" @@ -43,6 +45,7 @@ func Run(ctx context.Context, opts ...Option) error { eg.Go(func() error { return run(ctx, "pomerium-core", c.runPomeriumCore, src.WaitReady) }) eg.Go(func() error { return run(ctx, "zero-reconciler", c.runReconciler, src.WaitReady) }) eg.Go(func() error { return run(ctx, "connect-log", c.RunConnectLog, nil) }) + eg.Go(func() error { return run(ctx, "zero-analytics", c.runAnalytics, src.WaitReady) }) return eg.Wait() } @@ -117,3 +120,13 @@ func (c *controller) runReconciler(ctx context.Context) error { reconciler.WithDataBrokerClient(c.GetDataBrokerServiceClient()), ) } + +func (c *controller) runAnalytics(ctx context.Context) error { + err := analytics.Collect(ctx, c.GetDataBrokerServiceClient(), time.Second*30) + if err != nil && ctx.Err() == nil { + log.Ctx(ctx).Error().Err(err).Msg("error collecting analytics, disabling") + return nil + } + + return err +} diff --git a/pkg/grpc/databroker/databroker.go b/pkg/grpc/databroker/databroker.go index 08d7a82d3..214164093 100644 --- a/pkg/grpc/databroker/databroker.go +++ b/pkg/grpc/databroker/databroker.go @@ -8,6 +8,8 @@ import ( "fmt" "io" + "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" structpb "google.golang.org/protobuf/types/known/structpb" @@ -33,6 +35,11 @@ func NewRecord(object recordObject) *Record { } } +// IsNotFound returns true if the error is a not found error. +func IsNotFound(err error) bool { + return status.Code(err) == codes.NotFound +} + // Get gets a record from the databroker and unmarshals it into the object. func Get(ctx context.Context, client DataBrokerServiceClient, object recordObject) error { res, err := client.Get(ctx, &GetRequest{