pomerium/internal/zero/analytics/storage.go
2023-12-11 13:37:01 -05:00

126 lines
2.9 KiB
Go

package analytics
import (
	"context"
	"encoding/base64"
	"fmt"
	"math"
	"time"

	"google.golang.org/protobuf/types/known/anypb"
	"google.golang.org/protobuf/types/known/structpb"

	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/protoutil"
)
// metricStateTypeURL is the databroker record type under which metric state
// is written by SaveMetricState and read back by LoadMetricState.
const metricStateTypeURL = "pomerium.io/ActiveUsersMetricState"
// SaveMetricState persists the state of the metric identified by id to the
// databroker as a single record of type metricStateTypeURL.
//
// data, value and lastReset are wrapped in a MetricState and encoded with
// MetricState.ToAny. Any error from the databroker Put call is returned
// unchanged.
func SaveMetricState(
	ctx context.Context,
	client databroker.DataBrokerServiceClient,
	id string,
	data []byte,
	value uint,
	lastReset time.Time,
) error {
	state := MetricState{
		Data:      data,
		LastReset: lastReset,
		Count:     value,
	}
	record := &databroker.Record{
		Type: metricStateTypeURL,
		Id:   id,
		Data: state.ToAny(),
	}
	if _, err := client.Put(ctx, &databroker.PutRequest{
		Records: []*databroker.Record{record},
	}); err != nil {
		return err
	}
	return nil
}
// LoadMetricState loads the state of a metric from the databroker
func LoadMetricState(
ctx context.Context, client databroker.DataBrokerServiceClient, id string,
) (*MetricState, error) {
resp, err := client.Get(ctx, &databroker.GetRequest{
Type: metricStateTypeURL,
Id: id,
})
if err != nil {
return nil, fmt.Errorf("load metric state: %w", err)
}
var state MetricState
err = state.FromAny(resp.GetRecord().GetData())
if err != nil {
return nil, fmt.Errorf("load metric state: %w", err)
}
return &state, nil
}
// MetricState is the persistent state of a metric, as written by
// SaveMetricState and returned by LoadMetricState.
type MetricState struct {
	// Data is the opaque metric payload; ToAny stores it as a standard
	// base64 string and FromAny rejects it when it decodes to zero bytes.
	Data []byte
	// LastReset is when the metric was last reset; ToAny stores it in
	// RFC 3339 form, i.e. with whole-second precision.
	LastReset time.Time
	// Count is the metric's counter value; ToAny stores it as a protobuf
	// number (float64).
	Count uint
}
// Keys of the structpb.Struct fields that ToAny writes and FromAny reads.
const (
	// dataKey holds MetricState.Data as a standard base64 string.
	dataKey = "data"
	// lastResetKey holds MetricState.LastReset as an RFC 3339 string.
	lastResetKey = "last_reset"
	// countKey holds MetricState.Count as a number.
	countKey = "count"
)
// ToAny encodes r as a structpb.Struct wrapped in an anypb.Any.
//
// Count is stored as a number, Data as a standard base64 string, and
// LastReset as an RFC 3339 string; the RFC 3339 layout has no fractional
// seconds, so any sub-second part of LastReset is not persisted.
func (r *MetricState) ToAny() *anypb.Any {
	encodedData := base64.StdEncoding.EncodeToString(r.Data)
	formattedReset := r.LastReset.Format(time.RFC3339)

	fields := make(map[string]*structpb.Value, 3)
	fields[countKey] = structpb.NewNumberValue(float64(r.Count))
	fields[dataKey] = structpb.NewStringValue(encodedData)
	fields[lastResetKey] = structpb.NewStringValue(formattedReset)

	return protoutil.NewAny(&structpb.Struct{Fields: fields})
}
// FromAny decodes an anypb.Any produced by ToAny into r.
//
// src must wrap a structpb.Struct containing:
//   - "data": a standard base64 string that decodes to a non-empty byte slice;
//   - "last_reset": an RFC 3339 timestamp string;
//   - "count": a non-negative number that fits in a uint.
//
// r is only modified after every field has been validated, so on error it
// keeps its previous contents.
func (r *MetricState) FromAny(src *anypb.Any) error {
	var s structpb.Struct
	if err := src.UnmarshalTo(&s); err != nil {
		return fmt.Errorf("unmarshal struct: %w", err)
	}
	fields := s.GetFields()

	vData, ok := fields[dataKey]
	if !ok {
		return fmt.Errorf("missing %s field", dataKey)
	}
	data, err := base64.StdEncoding.DecodeString(vData.GetStringValue())
	if err != nil {
		return fmt.Errorf("decode state: %w", err)
	}
	if len(data) == 0 {
		return fmt.Errorf("empty data")
	}

	vLastReset, ok := fields[lastResetKey]
	if !ok {
		return fmt.Errorf("missing %s field", lastResetKey)
	}
	lastReset, err := time.Parse(time.RFC3339, vLastReset.GetStringValue())
	if err != nil {
		return fmt.Errorf("parse last reset: %w", err)
	}

	vCount, ok := fields[countKey]
	if !ok {
		return fmt.Errorf("missing %s field", countKey)
	}
	// GetNumberValue returns 0 for non-number values, which would silently
	// reset the counter; reject such values instead.
	if _, isNumber := vCount.GetKind().(*structpb.Value_NumberValue); !isNumber {
		return fmt.Errorf("invalid %s field: not a number", countKey)
	}
	count := vCount.GetNumberValue()
	// Converting NaN, a negative, or an out-of-range float64 to uint yields an
	// implementation-dependent value, so check the range first.
	// float64(math.MaxUint) rounds up to 2^64 on 64-bit platforms, which makes
	// ">=" the correct exclusive upper bound there.
	if math.IsNaN(count) || count < 0 || count >= float64(math.MaxUint) {
		return fmt.Errorf("invalid %s value: %v", countKey, count)
	}

	r.Data = data
	r.LastReset = lastReset
	r.Count = uint(count)
	return nil
}