zero: refactor telemetry and controller (#5135)

* zero: refactor controller

* refactor zero telemetry and controller

* wire with connect handler

* cr
Denis Mishin 2024-06-12 21:59:25 -04:00 committed by GitHub
parent cc636be707
commit 114f730dba
22 changed files with 612 additions and 342 deletions


@@ -0,0 +1,31 @@
package reporter
import (
"go.opentelemetry.io/otel/sdk/metric"
)
type config struct {
producers map[string]*metricsProducer
}
// Option configures the telemetry reporter
type Option func(*config)
// WithProducer adds a metric producer to the reporter
func WithProducer(name string, p metric.Producer) Option {
return func(c *config) {
if _, ok := c.producers[name]; ok {
panic("duplicate producer name " + name)
}
c.producers[name] = newProducer(name, p)
}
}
func getConfig(opts ...Option) config {
c := config{
producers: make(map[string]*metricsProducer),
}
for _, opt := range opts {
opt(&c)
}
return c
}
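As a brief, in-package sketch (the two producer arguments here are illustrative), the options compose like this:

func exampleOptions(sessionsProducer, envoyProducer metric.Producer) {
	// Each producer is wrapped in a *metricsProducer keyed by name;
	// registering the same name twice panics with "duplicate producer name".
	cfg := getConfig(
		WithProducer("sessions", sessionsProducer),
		WithProducer("envoy", envoyProducer),
	)
	_ = cfg.producers
}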


@@ -0,0 +1,92 @@
package reporter
import (
"context"
"errors"
"fmt"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/exporters/otlp/otlptrace"
export_grpc "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracegrpc"
"go.opentelemetry.io/otel/sdk/resource"
trace_sdk "go.opentelemetry.io/otel/sdk/trace"
"go.opentelemetry.io/otel/trace"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/health"
)
type healthCheckReporter struct {
resource *resource.Resource
exporter *otlptrace.Exporter
provider *trace_sdk.TracerProvider
tracer trace.Tracer
}
// newHealthCheckReporter creates a new unstarted health check reporter
func newHealthCheckReporter(
conn *grpc.ClientConn,
resource *resource.Resource,
) *healthCheckReporter {
exporter := export_grpc.NewUnstarted(export_grpc.WithGRPCConn(conn))
processor := trace_sdk.NewBatchSpanProcessor(exporter)
provider := trace_sdk.NewTracerProvider(
trace_sdk.WithResource(resource),
trace_sdk.WithSampler(trace_sdk.AlwaysSample()),
trace_sdk.WithSpanProcessor(processor),
)
tracer := provider.Tracer(serviceName)
return &healthCheckReporter{
resource: resource,
exporter: exporter,
tracer: tracer,
provider: provider,
}
}
func (r *healthCheckReporter) Run(ctx context.Context) error {
err := r.exporter.Start(ctx)
if err != nil {
// this should not happen for the gRPC exporter, as it is non-blocking
return fmt.Errorf("error starting health check exporter: %w", err)
}
<-ctx.Done()
return nil
}
func (r *healthCheckReporter) Shutdown(ctx context.Context) error {
return errors.Join(
r.provider.Shutdown(ctx),
r.exporter.Shutdown(ctx),
)
}
// ReportOK implements health.Provider interface
func (r *healthCheckReporter) ReportOK(check health.Check, attr ...health.Attr) {
ctx := context.Background()
log.Ctx(ctx).Debug().Str("check", string(check)).Msg("health check ok")
_, span := r.tracer.Start(ctx, string(check))
span.SetStatus(codes.Ok, "")
setAttributes(span, attr...)
span.End()
}
// ReportError implements health.Provider interface
func (r *healthCheckReporter) ReportError(check health.Check, err error, attr ...health.Attr) {
ctx := context.Background()
log.Ctx(ctx).Warn().Str("check", string(check)).Err(err).Msg("health check error")
_, span := r.tracer.Start(ctx, string(check))
span.SetStatus(codes.Error, err.Error())
setAttributes(span, attr...)
span.End()
}
func setAttributes(span trace.Span, attr ...health.Attr) {
for _, a := range attr {
span.SetAttributes(attribute.String(a.Key, a.Value))
}
}
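For reference, other components report status through the package-level helpers in pkg/health; once this reporter is registered as the active health.Provider, each report becomes a span as above. A hedged caller-side sketch (doExport is hypothetical):

if err := doExport(ctx); err != nil {
	health.ReportError(health.CollectAndSendTelemetry, err)
} else {
	health.ReportOK(health.CollectAndSendTelemetry)
}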


@@ -0,0 +1,47 @@
package reporter
import (
"context"
"sync/atomic"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"github.com/pomerium/pomerium/internal/log"
)
type metricsProducer struct {
enabled atomic.Bool
name string
metric.Producer
}
func newProducer(name string, p metric.Producer) *metricsProducer {
return &metricsProducer{
name: name,
Producer: p,
}
}
var _ metric.Producer = (*metricsProducer)(nil)
// Produce wraps the underlying producer's Produce method and logs any errors
// so that a failure does not block the export of other metrics.
// It also checks that the producer is enabled before producing metrics.
func (p *metricsProducer) Produce(ctx context.Context) ([]metricdata.ScopeMetrics, error) {
if enabled := p.enabled.Load(); !enabled {
return nil, nil
}
data, err := p.Producer.Produce(ctx)
if err != nil {
log.Error(ctx).Err(err).Str("producer", p.name).Msg("failed to produce metrics")
return nil, err
}
return data, nil
}
// SetEnabled sets the enabled state of the producer
func (p *metricsProducer) SetEnabled(v bool) {
p.enabled.Store(v)
}
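An in-package sketch of the gate (the wrapped producer is assumed to be supplied by the caller):

func exampleGate(ctx context.Context, inner metric.Producer) {
	p := newProducer("sessions", inner)
	sm, _ := p.Produce(ctx) // disabled by default: returns (nil, nil) without calling inner
	p.SetEnabled(true)
	sm, _ = p.Produce(ctx) // now delegates to inner, logging and returning any error
	_ = sm
}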


@@ -0,0 +1,99 @@
package reporter
import (
"context"
"errors"
"fmt"
export_grpc "go.opentelemetry.io/otel/exporters/otlp/otlpmetric/otlpmetricgrpc"
metric_sdk "go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"go.opentelemetry.io/otel/sdk/resource"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/health"
)
type metricsReporter struct {
exporter *export_grpc.Exporter
resource *resource.Resource
reader *metric_sdk.ManualReader
producers map[string]*metricsProducer
singleTask
}
func newMetricsReporter(
ctx context.Context,
conn *grpc.ClientConn,
resource *resource.Resource,
producers map[string]*metricsProducer,
) (*metricsReporter, error) {
exporter, err := export_grpc.New(ctx, export_grpc.WithGRPCConn(conn))
if err != nil {
return nil, fmt.Errorf("create exporter: %w", err)
}
readerOpts := make([]metric_sdk.ManualReaderOption, 0, len(producers))
for _, p := range producers {
readerOpts = append(readerOpts, metric_sdk.WithProducer(p))
}
reader := metric_sdk.NewManualReader(readerOpts...)
return &metricsReporter{
exporter: exporter,
resource: resource,
reader: reader,
producers: producers,
}, nil
}
func (r *metricsReporter) Run(ctx context.Context) error {
<-ctx.Done()
return nil
}
func (r *metricsReporter) Shutdown(ctx context.Context) error {
return errors.Join(
r.reader.Shutdown(ctx),
r.exporter.Shutdown(ctx),
)
}
func (r *metricsReporter) SetMetricProducerEnabled(name string, enabled bool) error {
p, ok := r.producers[name]
if !ok {
return fmt.Errorf("producer %q not found", name)
}
p.SetEnabled(enabled)
return nil
}
func (r *metricsReporter) CollectAndExportMetrics(ctx context.Context) {
r.singleTask.Run(ctx, func(ctx context.Context) {
err := r.collectAndExport(ctx)
if errors.Is(err, ErrAnotherExecutionRequested) {
log.Warn(ctx).Msg("telemetry metrics were not sent because another execution was requested")
return
}
if err != nil {
health.ReportError(health.CollectAndSendTelemetry, err)
} else {
health.ReportOK(health.CollectAndSendTelemetry)
}
})
}
func (r *metricsReporter) collectAndExport(ctx context.Context) error {
rm := &metricdata.ResourceMetrics{
Resource: r.resource,
}
err := withBackoff(ctx, "collect metrics", func(ctx context.Context) error { return r.reader.Collect(ctx, rm) })
if err != nil {
return fmt.Errorf("collect metrics: %w", err)
}
err = withBackoff(ctx, "export metrics", func(ctx context.Context) error { return r.exporter.Export(ctx, rm) })
if err != nil {
return fmt.Errorf("export metrics: %w", err)
}
return nil
}
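A sketch of the intended flow once the control plane requests telemetry (r is assumed to come from newMetricsReporter, and "sessions" is an illustrative producer name):

func exampleExport(ctx context.Context, r *metricsReporter) {
	if err := r.SetMetricProducerEnabled("sessions", true); err != nil {
		log.Ctx(ctx).Error().Err(err).Msg("unknown producer")
	}
	r.CollectAndExportMetrics(ctx) // runs in the background; the outcome is reported via pkg/health
}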


@@ -0,0 +1,99 @@
// Package reporter periodically submits metrics back to the cloud.
package reporter
import (
"context"
"fmt"
"os"
"time"
"github.com/cenkalti/backoff/v4"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/sdk/resource"
semconv "go.opentelemetry.io/otel/semconv/v1.4.0"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/version"
)
// Reporter aggregates the zero telemetry sub-reporters for metrics and health checks
type Reporter struct {
*metricsReporter
*healthCheckReporter
}
const (
serviceName = "pomerium-managed-core"
)
// New creates a new unstarted zero telemetry reporter
func New(
ctx context.Context,
conn *grpc.ClientConn,
opts ...Option,
) (*Reporter, error) {
cfg := getConfig(opts...)
resource := getResource()
metrics, err := newMetricsReporter(ctx, conn, resource, cfg.producers)
if err != nil {
return nil, fmt.Errorf("failed to create metrics reporter: %w", err)
}
healthChecks := newHealthCheckReporter(conn, resource)
return &Reporter{
metricsReporter: metrics,
healthCheckReporter: healthChecks,
}, nil
}
func (r *Reporter) Run(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { return withBackoff(ctx, "metrics reporter", r.metricsReporter.Run) })
eg.Go(func() error { return withBackoff(ctx, "health check reporter", r.healthCheckReporter.Run) })
return eg.Wait()
}
// Shutdown should be called after Run to cleanly shut down the reporter
func (r *Reporter) Shutdown(ctx context.Context) error {
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error { return r.metricsReporter.Shutdown(ctx) })
eg.Go(func() error { return r.healthCheckReporter.Shutdown(ctx) })
return eg.Wait()
}
func getResource() *resource.Resource {
attr := []attribute.KeyValue{
semconv.ServiceNameKey.String(serviceName),
semconv.ServiceVersionKey.String(version.FullVersion()),
}
hostname, err := os.Hostname()
if err == nil {
attr = append(attr, semconv.HostNameKey.String(hostname))
}
return resource.NewSchemaless(attr...)
}
func withBackoff(ctx context.Context, name string, f func(context.Context) error) error {
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
return backoff.RetryNotify(
func() error { return f(ctx) },
backoff.WithContext(bo, ctx),
func(err error, d time.Duration) {
log.Warn(ctx).
Str("name", name).
Err(err).
Dur("backoff", d).
Msg("retrying")
},
)
}
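Putting the pieces together, a hedged end-to-end sketch (the gRPC connection and the sessions producer are assumed to be supplied by the zero controller):

func runTelemetry(ctx context.Context, conn *grpc.ClientConn, sessionsProducer metric.Producer) error {
	r, err := New(ctx, conn, WithProducer("sessions", sessionsProducer))
	if err != nil {
		return err
	}
	defer func() {
		shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
		defer cancel()
		_ = r.Shutdown(shutdownCtx)
	}()
	return r.Run(ctx) // blocks until ctx is cancelled
}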


@@ -0,0 +1,27 @@
package reporter
import (
"context"
"errors"
"sync"
)
type singleTask struct {
lock sync.Mutex
cancel context.CancelCauseFunc
}
var ErrAnotherExecutionRequested = errors.New("another execution requested")
func (s *singleTask) Run(ctx context.Context, f func(context.Context)) {
s.lock.Lock()
defer s.lock.Unlock()
if s.cancel != nil {
s.cancel(ErrAnotherExecutionRequested)
}
ctx, cancel := context.WithCancelCause(ctx)
s.cancel = cancel
go f(ctx)
}
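An in-package sketch of the preemption semantics: a second Run cancels the first invocation's context with ErrAnotherExecutionRequested.

func exampleSingleTask(ctx context.Context) {
	var task singleTask
	task.Run(ctx, func(ctx context.Context) {
		<-ctx.Done()
		// context.Cause(ctx) == ErrAnotherExecutionRequested when preempted below
	})
	task.Run(ctx, func(ctx context.Context) {
		// the previous invocation has already been asked to stop
	})
}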


@@ -0,0 +1,80 @@
package sessions
import (
"time"
"github.com/pomerium/pomerium/pkg/counter"
)
const (
// activeUsersCap is the maximum number of active users we can track.
// For the counter to stay within its 1% error limit, the actual number of users should be at most 80% of this cap.
activeUsersCap = 10_000
)
// IntervalResetFunc is a function that determines if a counter should be reset
type IntervalResetFunc func(lastReset time.Time, now time.Time) bool
// ResetMonthlyUTC resets the counter on a monthly interval
func ResetMonthlyUTC(lastReset time.Time, now time.Time) bool {
lastResetUTC := lastReset.UTC()
nowUTC := now.UTC()
return lastResetUTC.Year() != nowUTC.Year() ||
lastResetUTC.Month() != nowUTC.Month()
}
// ResetDailyUTC resets the counter on a daily interval
func ResetDailyUTC(lastReset time.Time, now time.Time) bool {
lastResetUTC := lastReset.UTC()
nowUTC := now.UTC()
return lastResetUTC.Year() != nowUTC.Year() ||
lastResetUTC.Month() != nowUTC.Month() ||
lastResetUTC.Day() != nowUTC.Day()
}
// ActiveUsersCounter is a counter that resets on a given interval
type ActiveUsersCounter struct {
*counter.Counter
lastReset time.Time
needsReset IntervalResetFunc
}
// NewActiveUsersCounter creates a new active users counter
func NewActiveUsersCounter(needsReset IntervalResetFunc, now time.Time) *ActiveUsersCounter {
return &ActiveUsersCounter{
Counter: counter.New(activeUsersCap),
lastReset: now,
needsReset: needsReset,
}
}
// LoadActiveUsersCounter loads an active users counter from a binary state
func LoadActiveUsersCounter(state []byte, lastReset time.Time, resetFn IntervalResetFunc) (*ActiveUsersCounter, error) {
c, err := counter.FromBinary(state)
if err != nil {
return nil, err
}
return &ActiveUsersCounter{
Counter: c,
lastReset: lastReset,
needsReset: resetFn,
}, nil
}
// Update updates the counter with the current users
func (c *ActiveUsersCounter) Update(users []string, now time.Time) (current uint, wasReset bool) {
if c.needsReset(c.lastReset, now) {
c.Counter.Reset()
c.lastReset = now
wasReset = true
}
for _, user := range users {
c.Mark(user)
}
return c.Count(), wasReset
}
// GetLastReset returns the last time the counter was reset
func (c *ActiveUsersCounter) GetLastReset() time.Time {
return c.lastReset
}
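A small illustration of the reset predicates: ResetDailyUTC fires when the UTC calendar day changes (even if fewer than 24 hours have passed), while ResetMonthlyUTC waits for the month to roll over.

func exampleResetPredicates() {
	last := time.Date(2024, 6, 12, 23, 30, 0, 0, time.UTC)
	now := time.Date(2024, 6, 13, 0, 30, 0, 0, time.UTC)
	_ = ResetDailyUTC(last, now)   // true: the UTC day changed
	_ = ResetMonthlyUTC(last, now) // false: still June
}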


@@ -0,0 +1,37 @@
package sessions_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/zero/telemetry/sessions"
)
func TestActiveUsers(t *testing.T) {
t.Parallel()
startTime := time.Now().UTC()
// Create a new counter that resets on a daily interval
c := sessions.NewActiveUsersCounter(sessions.ResetDailyUTC, startTime)
count, wasReset := c.Update([]string{"user1", "user2"}, startTime.Add(time.Minute))
assert.False(t, wasReset)
assert.EqualValues(t, 2, count)
count, wasReset = c.Update([]string{"user1", "user2", "user3"}, startTime.Add(time.Minute*2))
assert.False(t, wasReset)
assert.EqualValues(t, 3, count)
// Update the counter with a new user after the daily interval lapses
count, wasReset = c.Update([]string{"user1", "user2", "user3", "user4"}, startTime.Add(time.Hour*25))
assert.True(t, wasReset)
assert.EqualValues(t, 4, count)
// After another lapse the counter resets again, leaving a single active user
count, wasReset = c.Update([]string{"user4"}, startTime.Add(time.Hour*25*2))
assert.True(t, wasReset)
assert.EqualValues(t, 1, count)
}


@@ -0,0 +1,118 @@
// Package sessions collects active user metrics and reports them to the cloud dashboard
package sessions
import (
"context"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Collect collects metrics and stores them in the databroker
func Collect(
ctx context.Context,
client databroker.DataBrokerServiceClient,
updateInterval time.Duration,
) error {
c := &collector{
client: client,
counters: make(map[string]*ActiveUsersCounter),
updateInterval: updateInterval,
}
return c.run(ctx)
}
type collector struct {
client databroker.DataBrokerServiceClient
counters map[string]*ActiveUsersCounter
updateInterval time.Duration
}
func (c *collector) run(ctx context.Context) error {
err := c.loadCounters(ctx)
if err != nil {
return fmt.Errorf("failed to load counters: %w", err)
}
err = c.runPeriodicUpdate(ctx)
if err != nil {
return fmt.Errorf("failed to run periodic update: %w", err)
}
return nil
}
func (c *collector) loadCounters(ctx context.Context) error {
now := time.Now()
for key, resetFn := range map[string]IntervalResetFunc{
"mau": ResetMonthlyUTC,
"dau": ResetDailyUTC,
} {
state, err := LoadMetricState(ctx, c.client, key)
if err != nil && !databroker.IsNotFound(err) {
return err
}
if state == nil {
c.counters[key] = NewActiveUsersCounter(resetFn, now)
continue
}
counter, err := LoadActiveUsersCounter(state.Data, state.LastReset, resetFn)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("metric", key).Msg("failed to load metric state, resetting")
counter = NewActiveUsersCounter(resetFn, now)
}
c.counters[key] = counter
}
return nil
}
func (c *collector) runPeriodicUpdate(ctx context.Context) error {
ticker := time.NewTicker(c.updateInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return nil
case <-ticker.C:
if err := c.update(ctx); err != nil {
return err
}
}
}
}
func (c *collector) update(ctx context.Context) error {
users, err := CurrentUsers(ctx, c.client)
if err != nil {
return fmt.Errorf("failed to get current users: %w", err)
}
now := time.Now()
for key, counter := range c.counters {
before := counter.Count()
after, _ := counter.Update(users, now)
if before == after {
log.Ctx(ctx).Debug().Msgf("metric %s not changed: %d", key, counter.Count())
continue
}
log.Ctx(ctx).Debug().Msgf("metric %s updated: %d", key, counter.Count())
data, err := counter.ToBinary()
if err != nil {
return fmt.Errorf("failed to marshal metric %s: %w", key, err)
}
err = SaveMetricState(ctx, c.client, key, data, after, counter.GetLastReset())
if err != nil {
return fmt.Errorf("failed to save metric %s: %w", key, err)
}
}
return nil
}
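A hedged sketch of starting the collector (the databroker client is assumed to be provided by the caller; the hourly interval is illustrative):

func runSessionAnalytics(ctx context.Context, client databroker.DataBrokerServiceClient) error {
	// Collect blocks, persisting the dau/mau counter state to the databroker every interval.
	return Collect(ctx, client, time.Hour)
}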


@@ -0,0 +1,47 @@
package sessions
import (
"context"
"go.opentelemetry.io/otel/metric"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// Metrics returns a list of registration functions for the metrics to be exported
func Metrics(
clientProvider func() databroker.DataBrokerServiceClient,
) []func(m metric.Meter) error {
return []func(m metric.Meter) error{
registerMetric("dau", clientProvider),
registerMetric("mau", clientProvider),
}
}
func registerMetric(
id string,
clientProvider func() databroker.DataBrokerServiceClient,
) func(m metric.Meter) error {
return func(m metric.Meter) error {
_, err := m.Int64ObservableGauge(id,
metric.WithInt64Callback(metricCallback(id, clientProvider)),
)
return err
}
}
func metricCallback(
id string,
clientProvider func() databroker.DataBrokerServiceClient,
) metric.Int64Callback {
return func(ctx context.Context, result metric.Int64Observer) error {
state, err := LoadMetricState(ctx, clientProvider(), id)
if err != nil {
log.Ctx(ctx).Error().Err(err).Str("metric", id).Msg("error loading metric state")
return nil // returning an error would block export of other metrics according to SDK design
}
result.Observe(int64(state.Count))
return nil
}
}
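A sketch of registering these gauges against an OpenTelemetry meter (the meter and the client accessor are assumed to be supplied elsewhere):

func registerSessionMetrics(meter metric.Meter, getClient func() databroker.DataBrokerServiceClient) error {
	for _, register := range Metrics(getClient) {
		if err := register(meter); err != nil {
			return err
		}
	}
	return nil
}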


@@ -0,0 +1,75 @@
package sessions
import (
"context"
"fmt"
"time"
"go.opentelemetry.io/otel/sdk/instrumentation"
"go.opentelemetry.io/otel/sdk/metric"
"go.opentelemetry.io/otel/sdk/metric/metricdata"
"golang.org/x/sync/errgroup"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type producer struct {
scope instrumentation.Scope
clientProvider func() (databroker.DataBrokerServiceClient, error)
}
// NewProducer creates a metric.Producer that reads session metric state from the databroker
func NewProducer(
scope instrumentation.Scope,
clientProvider func() (databroker.DataBrokerServiceClient, error),
) metric.Producer {
return &producer{
clientProvider: clientProvider,
scope: scope,
}
}
func (p *producer) Produce(ctx context.Context) ([]metricdata.ScopeMetrics, error) {
client, err := p.clientProvider()
if err != nil {
return nil, fmt.Errorf("error getting client: %w", err)
}
now := time.Now()
ids := []string{"dau", "mau"}
metrics := make([]metricdata.Metrics, len(ids))
eg, ctx := errgroup.WithContext(ctx)
for i := 0; i < len(ids); i++ {
i := i
eg.Go(func() error {
state, err := LoadMetricState(ctx, client, ids[i])
if err != nil {
return err
}
metrics[i] = metricdata.Metrics{
Name: ids[i],
Unit: "unique users",
Data: metricdata.Gauge[int64]{
DataPoints: []metricdata.DataPoint[int64]{
{
Time: now,
Value: int64(state.Count),
},
},
},
}
return nil
})
}
err = eg.Wait()
if err != nil {
return nil, err
}
return []metricdata.ScopeMetrics{
{
Scope: p.scope,
Metrics: metrics,
},
}, nil
}
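A sketch of wiring this producer into the zero telemetry reporter via reporter.WithProducer (the scope name and the getDataBrokerClient helper are illustrative):

// getDataBrokerClient func() (databroker.DataBrokerServiceClient, error) is assumed to exist.
scope := instrumentation.Scope{Name: "pomerium.io/zero/sessions"}
producer := sessions.NewProducer(scope, getDataBrokerClient)
r, err := reporter.New(ctx, conn, reporter.WithProducer("sessions", producer))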


@@ -0,0 +1,51 @@
package sessions
import (
"context"
"fmt"
"time"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
)
var sessionTypeURL = protoutil.GetTypeURL(new(session.Session))
// CurrentUsers returns a list of users active within the current UTC day
func CurrentUsers(
ctx context.Context,
client databroker.DataBrokerServiceClient,
) ([]string, error) {
records, _, _, err := databroker.InitialSync(ctx, client, &databroker.SyncLatestRequest{
Type: sessionTypeURL,
})
if err != nil {
return nil, fmt.Errorf("fetching sessions: %w", err)
}
users := sets.NewHash[string]()
utcNow := time.Now().UTC()
threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
for _, record := range records {
var s session.Session
err := record.GetData().UnmarshalTo(&s)
if err != nil {
return nil, fmt.Errorf("unmarshaling session: %w", err)
}
if s.UserId == "" { // session creation is in progress
continue
}
if s.AccessedAt == nil {
continue
}
if s.AccessedAt.AsTime().Before(threshold) {
continue
}
users.Add(s.UserId)
}
return users.Items(), nil
}


@@ -0,0 +1,126 @@
package sessions
import (
"context"
"encoding/base64"
"fmt"
"time"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
const (
metricStateTypeURL = "pomerium.io/ActiveUsersMetricState"
)
// SaveMetricState saves the state of a metric to the databroker
func SaveMetricState(
ctx context.Context,
client databroker.DataBrokerServiceClient,
id string,
data []byte,
value uint,
lastReset time.Time,
) error {
_, err := client.Put(ctx, &databroker.PutRequest{
Records: []*databroker.Record{{
Type: metricStateTypeURL,
Id: id,
Data: (&MetricState{
Data: data,
LastReset: lastReset,
Count: value,
}).ToAny(),
}},
})
return err
}
// LoadMetricState loads the state of a metric from the databroker
func LoadMetricState(
ctx context.Context, client databroker.DataBrokerServiceClient, id string,
) (*MetricState, error) {
resp, err := client.Get(ctx, &databroker.GetRequest{
Type: metricStateTypeURL,
Id: id,
})
if err != nil {
return nil, fmt.Errorf("load metric state: %w", err)
}
var state MetricState
err = state.FromAny(resp.GetRecord().GetData())
if err != nil {
return nil, fmt.Errorf("load metric state: %w", err)
}
return &state, nil
}
// MetricState is the persistent state of a metric
type MetricState struct {
Data []byte
LastReset time.Time
Count uint
}
const (
countKey = "count"
dataKey = "data"
lastResetKey = "last_reset"
)
// ToAny marshals a MetricState into an anypb.Any
func (r *MetricState) ToAny() *anypb.Any {
return protoutil.NewAny(&structpb.Struct{
Fields: map[string]*structpb.Value{
countKey: structpb.NewNumberValue(float64(r.Count)),
dataKey: structpb.NewStringValue(base64.StdEncoding.EncodeToString(r.Data)),
lastResetKey: structpb.NewStringValue(r.LastReset.Format(time.RFC3339)),
},
})
}
// FromAny unmarshals an anypb.Any into a MetricState
func (r *MetricState) FromAny(any *anypb.Any) error {
var s structpb.Struct
err := any.UnmarshalTo(&s)
if err != nil {
return fmt.Errorf("unmarshal struct: %w", err)
}
vData, ok := s.GetFields()[dataKey]
if !ok {
return fmt.Errorf("missing %s field", dataKey)
}
data, err := base64.StdEncoding.DecodeString(vData.GetStringValue())
if err != nil {
return fmt.Errorf("decode state: %w", err)
}
if len(data) == 0 {
return fmt.Errorf("empty data")
}
vLastReset, ok := s.GetFields()[lastResetKey]
if !ok {
return fmt.Errorf("missing %s field", lastResetKey)
}
lastReset, err := time.Parse(time.RFC3339, vLastReset.GetStringValue())
if err != nil {
return fmt.Errorf("parse last reset: %w", err)
}
vCount, ok := s.GetFields()[countKey]
if !ok {
return fmt.Errorf("missing %s field", countKey)
}
r.Data = data
r.LastReset = lastReset
r.Count = uint(vCount.GetNumberValue())
return nil
}


@@ -0,0 +1,29 @@
package sessions_test
import (
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/pomerium/pomerium/internal/zero/telemetry/sessions"
)
func TestStorage(t *testing.T) {
t.Parallel()
now := time.Date(2020, 1, 2, 3, 4, 5, 6, time.UTC)
state := &sessions.MetricState{
Data: []byte("data"),
LastReset: now,
}
pbany := state.ToAny()
assert.NotNil(t, pbany)
var newState sessions.MetricState
err := newState.FromAny(pbany)
assert.NoError(t, err)
assert.EqualValues(t, state.Data, newState.Data)
assert.EqualValues(t, state.LastReset.Truncate(time.Second), newState.LastReset.Truncate(time.Second))
}