mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-18 11:37:08 +02:00
core/zero: add usage reporter (#5281)
* wip * add response * handle empty email * use set, update log * add test * add coalesce, comments, test * add test, fix bug * use builtin cmp.Or * remove wait ready call * use api error
This commit is contained in:
parent
82a9dbe42a
commit
146efc1b13
13 changed files with 697 additions and 2 deletions
1
go.mod
1
go.mod
|
@ -33,6 +33,7 @@ require (
|
||||||
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
|
github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79
|
||||||
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0
|
github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0
|
||||||
github.com/hashicorp/go-multierror v1.1.1
|
github.com/hashicorp/go-multierror v1.1.1
|
||||||
|
github.com/hashicorp/go-set/v3 v3.0.0-alpha.1
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7
|
github.com/hashicorp/golang-lru/v2 v2.0.7
|
||||||
github.com/jackc/pgx/v5 v5.6.0
|
github.com/jackc/pgx/v5 v5.6.0
|
||||||
github.com/jxskiss/base62 v1.1.0
|
github.com/jxskiss/base62 v1.1.0
|
||||||
|
|
6
go.sum
6
go.sum
|
@ -375,6 +375,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY
|
||||||
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
|
||||||
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
|
||||||
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
|
||||||
|
github.com/hashicorp/go-set/v3 v3.0.0-alpha.1 h1:dPUtuqKJGgxtF7YO42oE+NdUONXi5nfLMKH2NpBffIM=
|
||||||
|
github.com/hashicorp/go-set/v3 v3.0.0-alpha.1/go.mod h1:7bJRgsF3EL3AtRTzcKXdjAFbYGSef+1gHXhglGGO52k=
|
||||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||||
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k=
|
||||||
|
@ -588,8 +590,8 @@ github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/i
|
||||||
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM=
|
||||||
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ=
|
||||||
github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU=
|
github.com/shoenig/test v1.8.2 h1:WDlty8UBqJRdmgdJX8lMwvCq97tiN7Um/GZD2vBDuug=
|
||||||
github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k=
|
github.com/shoenig/test v1.8.2/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI=
|
||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||||
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88=
|
||||||
|
|
44
internal/testutil/grpc.go
Normal file
44
internal/testutil/grpc.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
"google.golang.org/grpc/test/bufconn"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewGRPCServer starts a gRPC server and returns a client connection to it.
|
||||||
|
func NewGRPCServer(t testing.TB, register func(s *grpc.Server)) *grpc.ClientConn {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
li := bufconn.Listen(1024 * 1024)
|
||||||
|
s := grpc.NewServer()
|
||||||
|
register(s)
|
||||||
|
go func() {
|
||||||
|
err := s.Serve(li)
|
||||||
|
if errors.Is(err, grpc.ErrServerStopped) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
t.Cleanup(func() {
|
||||||
|
s.Stop()
|
||||||
|
})
|
||||||
|
|
||||||
|
cc, err := grpc.NewClient("passthrough://bufnet",
|
||||||
|
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
|
||||||
|
return li.Dial()
|
||||||
|
}),
|
||||||
|
grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
cc.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
return cc
|
||||||
|
}
|
|
@ -119,3 +119,10 @@ func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.Get
|
||||||
func (api *API) GetTelemetryConn() *grpc.ClientConn {
|
func (api *API) GetTelemetryConn() *grpc.ClientConn {
|
||||||
return api.telemetryConn
|
return api.telemetryConn
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (api *API) ReportUsage(ctx context.Context, req cluster_api.ReportUsageRequest) error {
|
||||||
|
_, err := apierror.CheckResponse(
|
||||||
|
api.cluster.ReportUsageWithResponse(ctx, req),
|
||||||
|
)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
|
@ -17,6 +17,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/zero/bootstrap"
|
"github.com/pomerium/pomerium/internal/zero/bootstrap"
|
||||||
"github.com/pomerium/pomerium/internal/zero/bootstrap/writers"
|
"github.com/pomerium/pomerium/internal/zero/bootstrap/writers"
|
||||||
connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux"
|
connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux"
|
||||||
|
"github.com/pomerium/pomerium/internal/zero/controller/usagereporter"
|
||||||
"github.com/pomerium/pomerium/internal/zero/healthcheck"
|
"github.com/pomerium/pomerium/internal/zero/healthcheck"
|
||||||
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
"github.com/pomerium/pomerium/internal/zero/reconciler"
|
||||||
"github.com/pomerium/pomerium/internal/zero/telemetry"
|
"github.com/pomerium/pomerium/internal/zero/telemetry"
|
||||||
|
@ -160,6 +161,7 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error {
|
||||||
c.runSessionAnalyticsLeased,
|
c.runSessionAnalyticsLeased,
|
||||||
c.runHealthChecksLeased,
|
c.runHealthChecksLeased,
|
||||||
leaseStatus.MonitorLease,
|
leaseStatus.MonitorLease,
|
||||||
|
c.runUsageReporter,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
@ -208,6 +210,14 @@ func (c *controller) runHealthChecksLeased(ctx context.Context, client databroke
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *controller) runUsageReporter(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
||||||
|
ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID, time.Minute)
|
||||||
|
return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error {
|
||||||
|
// start the usage reporter
|
||||||
|
return ur.Run(ctx, client)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (c *controller) getEnvoyScrapeURL() string {
|
func (c *controller) getEnvoyScrapeURL() string {
|
||||||
return (&url.URL{
|
return (&url.URL{
|
||||||
Scheme: "http",
|
Scheme: "http",
|
||||||
|
|
217
internal/zero/controller/usagereporter/usagereporter.go
Normal file
217
internal/zero/controller/usagereporter/usagereporter.go
Normal file
|
@ -0,0 +1,217 @@
|
||||||
|
// Package usagereporter reports usage for a cluster.
|
||||||
|
//
|
||||||
|
// Usage is determined from session and user records in the databroker. The usage reporter
|
||||||
|
// uses SyncLatest and Sync to retrieve this data, builds a collection of records and then
|
||||||
|
// sends them to the Zero Cluster API every minute.
|
||||||
|
//
|
||||||
|
// All usage users are reported on start but only the changed users are reported while running.
|
||||||
|
// The Zero Cluster API is tolerant of redundant data.
|
||||||
|
package usagereporter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"cmp"
|
||||||
|
"context"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
backoff "github.com/cenkalti/backoff/v4"
|
||||||
|
set "github.com/hashicorp/go-set/v3"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||||
|
)
|
||||||
|
|
||||||
|
// API is the part of the Zero Cluster API used to report usage.
|
||||||
|
type API interface {
|
||||||
|
ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type usageReporterRecord struct {
|
||||||
|
userID string
|
||||||
|
userEmail string
|
||||||
|
lastSignedInAt time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// A UsageReporter reports usage to the zero api.
|
||||||
|
type UsageReporter struct {
|
||||||
|
api API
|
||||||
|
organizationID string
|
||||||
|
reportInterval time.Duration
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
byUserID map[string]usageReporterRecord
|
||||||
|
updates *set.Set[string]
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new UsageReporter.
|
||||||
|
func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter {
|
||||||
|
return &UsageReporter{
|
||||||
|
api: api,
|
||||||
|
organizationID: organizationID,
|
||||||
|
reportInterval: reportInterval,
|
||||||
|
|
||||||
|
byUserID: make(map[string]usageReporterRecord),
|
||||||
|
updates: set.New[string](0),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Run runs the usage reporter.
|
||||||
|
func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerServiceClient) error {
|
||||||
|
ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx)
|
||||||
|
|
||||||
|
// first initialize the user collection
|
||||||
|
serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// run the continuous sync calls and periodically report usage
|
||||||
|
return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error {
|
||||||
|
req := cluster.ReportUsageRequest{
|
||||||
|
Users: convertUsageReporterRecords(ur.organizationID, records),
|
||||||
|
}
|
||||||
|
return backoff.Retry(func() error {
|
||||||
|
log.Debug(ctx).Int("updated-users", len(req.Users)).Msg("reporting usage")
|
||||||
|
err := ur.api.ReportUsage(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Warn(ctx).Err(err).Msg("error reporting usage")
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}, backoff.WithContext(backoff.NewExponentialBackOff(), ctx))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) runInit(
|
||||||
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) {
|
||||||
|
_, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) runSync(
|
||||||
|
ctx context.Context,
|
||||||
|
client databroker.DataBrokerServiceClient,
|
||||||
|
serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64,
|
||||||
|
) error {
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
return ur.runReporter(ctx)
|
||||||
|
})
|
||||||
|
return eg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) runReporter(ctx context.Context) error {
|
||||||
|
// every minute collect any updates and submit them to the API
|
||||||
|
timer := time.NewTicker(ur.reportInterval)
|
||||||
|
defer timer.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
// collect the updated records since last run
|
||||||
|
ur.mu.Lock()
|
||||||
|
records := make([]usageReporterRecord, 0, ur.updates.Size())
|
||||||
|
for userID := range ur.updates.Items() {
|
||||||
|
records = append(records, ur.byUserID[userID])
|
||||||
|
}
|
||||||
|
ur.updates = set.New[string](0)
|
||||||
|
ur.mu.Unlock()
|
||||||
|
|
||||||
|
if len(records) > 0 {
|
||||||
|
err := ur.report(ctx, records)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-timer.C:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) onUpdateSession(s *session.Session) {
|
||||||
|
userID := s.GetUserId()
|
||||||
|
if userID == "" {
|
||||||
|
// ignore sessions without a user id
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ur.mu.Lock()
|
||||||
|
defer ur.mu.Unlock()
|
||||||
|
|
||||||
|
r := ur.byUserID[userID]
|
||||||
|
nr := r
|
||||||
|
nr.lastSignedInAt = latest(nr.lastSignedInAt, s.GetIssuedAt().AsTime())
|
||||||
|
nr.userID = userID
|
||||||
|
if nr != r {
|
||||||
|
ur.byUserID[userID] = nr
|
||||||
|
ur.updates.Insert(userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ur *UsageReporter) onUpdateUser(u *user.User) {
|
||||||
|
userID := u.GetId()
|
||||||
|
if userID == "" {
|
||||||
|
// ignore users without a user id
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ur.mu.Lock()
|
||||||
|
defer ur.mu.Unlock()
|
||||||
|
|
||||||
|
r := ur.byUserID[userID]
|
||||||
|
nr := r
|
||||||
|
nr.userID = userID
|
||||||
|
nr.userEmail = cmp.Or(nr.userEmail, u.GetEmail())
|
||||||
|
if nr != r {
|
||||||
|
ur.byUserID[userID] = nr
|
||||||
|
ur.updates.Insert(userID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertUsageReporterRecords(organizationID string, records []usageReporterRecord) []cluster.ReportUsageUser {
|
||||||
|
var users []cluster.ReportUsageUser
|
||||||
|
for _, record := range records {
|
||||||
|
u := cluster.ReportUsageUser{
|
||||||
|
LastSignedInAt: record.lastSignedInAt,
|
||||||
|
PseudonymousId: cryptutil.Pseudonymize(organizationID, record.userID),
|
||||||
|
}
|
||||||
|
if record.userEmail != "" {
|
||||||
|
u.PseudonymousEmail = cryptutil.Pseudonymize(organizationID, record.userEmail)
|
||||||
|
}
|
||||||
|
users = append(users, u)
|
||||||
|
}
|
||||||
|
return users
|
||||||
|
}
|
||||||
|
|
||||||
|
// latest returns the latest time.
|
||||||
|
func latest(t1, t2 time.Time) time.Time {
|
||||||
|
if t2.After(t1) {
|
||||||
|
return t2
|
||||||
|
}
|
||||||
|
return t1
|
||||||
|
}
|
156
internal/zero/controller/usagereporter/usagereporter_test.go
Normal file
156
internal/zero/controller/usagereporter/usagereporter_test.go
Normal file
|
@ -0,0 +1,156 @@
|
||||||
|
package usagereporter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"golang.org/x/sync/errgroup"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/zero/cluster"
|
||||||
|
)
|
||||||
|
|
||||||
|
type mockAPI struct {
|
||||||
|
reportUsage func(ctx context.Context, req cluster.ReportUsageRequest) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m mockAPI) ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error {
|
||||||
|
return m.reportUsage(ctx, req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUsageReporter(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||||
|
t.Cleanup(clearTimeout)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
t.Cleanup(cancel)
|
||||||
|
|
||||||
|
cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New())
|
||||||
|
})
|
||||||
|
t.Cleanup(func() { cc.Close() })
|
||||||
|
|
||||||
|
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
requests := make(chan cluster.ReportUsageRequest, 1)
|
||||||
|
|
||||||
|
client := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
ur := New(mockAPI{
|
||||||
|
reportUsage: func(ctx context.Context, req cluster.ReportUsageRequest) error {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case requests <- req:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}, "bQjwPpxcwJRbvsSMFgbZFkXmxFJ", time.Millisecond*100)
|
||||||
|
|
||||||
|
eg, ctx := errgroup.WithContext(ctx)
|
||||||
|
eg.Go(func() error {
|
||||||
|
return ur.Run(ctx, client)
|
||||||
|
})
|
||||||
|
eg.Go(func() error {
|
||||||
|
_, err := databrokerpb.Put(ctx, client,
|
||||||
|
&session.Session{
|
||||||
|
Id: "S1a",
|
||||||
|
UserId: "U1",
|
||||||
|
IssuedAt: timestamppb.New(tm1),
|
||||||
|
},
|
||||||
|
&session.Session{
|
||||||
|
Id: "S1b",
|
||||||
|
UserId: "U1",
|
||||||
|
IssuedAt: timestamppb.New(tm1),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case req := <-requests:
|
||||||
|
assert.Equal(t, cluster.ReportUsageRequest{
|
||||||
|
Users: []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
|
||||||
|
}},
|
||||||
|
}, req, "should send a single usage record")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = databrokerpb.Put(ctx, client,
|
||||||
|
&user.User{
|
||||||
|
Id: "U1",
|
||||||
|
Email: "u1@example.com",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case req := <-requests:
|
||||||
|
assert.Equal(t, cluster.ReportUsageRequest{
|
||||||
|
Users: []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousEmail: "iq8/fj+uZaKitkWY12JIQgKJ5KIP+E0Cmy/HpxpdBXY=",
|
||||||
|
PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=",
|
||||||
|
}},
|
||||||
|
}, req, "should send another usage record with the email set")
|
||||||
|
}
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
err := eg.Wait()
|
||||||
|
if err != nil && !errors.Is(ctx.Err(), context.Canceled) {
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_convertUsageReporterRecords(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
assert.Empty(t, convertUsageReporterRecords("XXX", nil))
|
||||||
|
assert.Equal(t, []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousId: "T9V1yL/UueF/LVuF6XjoSNde0INElXG10zKepmyPke8=",
|
||||||
|
PseudonymousEmail: "8w5rtnZyv0EGkpHmTlkmupgb1jCzn/IxGCfvpdGGnvI=",
|
||||||
|
}}, convertUsageReporterRecords("XXX", []usageReporterRecord{{
|
||||||
|
userID: "ID",
|
||||||
|
userEmail: "EMAIL@example.com",
|
||||||
|
lastSignedInAt: tm1,
|
||||||
|
}}))
|
||||||
|
assert.Equal(t, []cluster.ReportUsageUser{{
|
||||||
|
LastSignedInAt: tm1,
|
||||||
|
PseudonymousId: "T9V1yL/UueF/LVuF6XjoSNde0INElXG10zKepmyPke8=",
|
||||||
|
}}, convertUsageReporterRecords("XXX", []usageReporterRecord{{
|
||||||
|
userID: "ID",
|
||||||
|
lastSignedInAt: tm1,
|
||||||
|
}}), "should leave empty email")
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_latest(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC)
|
||||||
|
tm2 := time.Date(2024, time.September, 12, 11, 56, 0, 0, time.UTC)
|
||||||
|
|
||||||
|
assert.Equal(t, tm2, latest(tm1, tm2))
|
||||||
|
assert.Equal(t, tm2, latest(tm2, tm1), "should ignore ordering")
|
||||||
|
assert.Equal(t, tm1, latest(tm1, time.Time{}), "should handle zero time")
|
||||||
|
}
|
16
pkg/cryptutil/pseudonymize.go
Normal file
16
pkg/cryptutil/pseudonymize.go
Normal file
|
@ -0,0 +1,16 @@
|
||||||
|
package cryptutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/hmac"
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/base64"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Pseudonymize pseudonymizes data by computing the HMAC-SHA256 of the data.
|
||||||
|
func Pseudonymize(organizationID string, data string) string {
|
||||||
|
h := hmac.New(sha256.New, []byte(organizationID))
|
||||||
|
_, _ = io.WriteString(h, data)
|
||||||
|
bs := h.Sum(nil)
|
||||||
|
return base64.StdEncoding.EncodeToString(bs)
|
||||||
|
}
|
110
pkg/grpc/databroker/sync.go
Normal file
110
pkg/grpc/databroker/sync.go
Normal file
|
@ -0,0 +1,110 @@
|
||||||
|
package databroker
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SyncRecords calls fn for every record using Sync.
|
||||||
|
func SyncRecords[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
client DataBrokerServiceClient,
|
||||||
|
serverVersion, latestRecordVersion uint64,
|
||||||
|
fn func(TMessage),
|
||||||
|
) error {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var msg TMessage = new(T)
|
||||||
|
stream, err := client.Sync(ctx, &SyncRequest{
|
||||||
|
Type: protoutil.GetTypeURL(msg),
|
||||||
|
ServerVersion: serverVersion,
|
||||||
|
RecordVersion: latestRecordVersion,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error syncing %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
return nil
|
||||||
|
case err != nil:
|
||||||
|
return fmt.Errorf("error receiving record for %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = new(T)
|
||||||
|
err = res.GetRecord().GetData().UnmarshalTo(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Error().Err(err).
|
||||||
|
Str("record-type", res.Record.Type).
|
||||||
|
Str("record-id", res.Record.GetId()).
|
||||||
|
Msgf("unexpected data in %T stream", msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SyncLatestRecords calls fn for every record using SyncLatest.
|
||||||
|
func SyncLatestRecords[T any, TMessage interface {
|
||||||
|
*T
|
||||||
|
proto.Message
|
||||||
|
}](
|
||||||
|
ctx context.Context,
|
||||||
|
client DataBrokerServiceClient,
|
||||||
|
fn func(TMessage),
|
||||||
|
) (serverVersion, latestRecordVersion uint64, err error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
var msg TMessage = new(T)
|
||||||
|
stream, err := client.SyncLatest(ctx, &SyncLatestRequest{
|
||||||
|
Type: protoutil.GetTypeURL(msg),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
res, err := stream.Recv()
|
||||||
|
switch {
|
||||||
|
case errors.Is(err, io.EOF):
|
||||||
|
return serverVersion, latestRecordVersion, nil
|
||||||
|
case err != nil:
|
||||||
|
return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch res := res.GetResponse().(type) {
|
||||||
|
case *SyncLatestResponse_Versions:
|
||||||
|
serverVersion = res.Versions.GetServerVersion()
|
||||||
|
latestRecordVersion = res.Versions.GetLatestRecordVersion()
|
||||||
|
case *SyncLatestResponse_Record:
|
||||||
|
msg = new(T)
|
||||||
|
err = res.Record.GetData().UnmarshalTo(msg)
|
||||||
|
if err != nil {
|
||||||
|
log.Ctx(ctx).Error().Err(err).
|
||||||
|
Str("record-type", res.Record.Type).
|
||||||
|
Str("record-id", res.Record.GetId()).
|
||||||
|
Msgf("unexpected data in latest %T stream", msg)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
fn(msg)
|
||||||
|
default:
|
||||||
|
panic(fmt.Sprintf("unexpected response: %T", res))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
67
pkg/grpc/databroker/sync_test.go
Normal file
67
pkg/grpc/databroker/sync_test.go
Normal file
|
@ -0,0 +1,67 @@
|
||||||
|
package databroker_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
grpc "google.golang.org/grpc"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||||
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_SyncLatestRecords(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute)
|
||||||
|
defer clearTimeout()
|
||||||
|
|
||||||
|
cc := testutil.NewGRPCServer(t, func(s *grpc.Server) {
|
||||||
|
databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New())
|
||||||
|
})
|
||||||
|
|
||||||
|
c := databrokerpb.NewDataBrokerServiceClient(cc)
|
||||||
|
|
||||||
|
expected := []*user.User{
|
||||||
|
{Id: "u1"},
|
||||||
|
{Id: "u2"},
|
||||||
|
{Id: "u3"},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, u := range expected {
|
||||||
|
_, err := c.Put(ctx, &databrokerpb.PutRequest{
|
||||||
|
Records: []*databrokerpb.Record{
|
||||||
|
databrokerpb.NewRecord(u),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// add a non-user record to make sure it gets ignored
|
||||||
|
_, err := c.Put(ctx, &databrokerpb.PutRequest{
|
||||||
|
Records: []*databrokerpb.Record{
|
||||||
|
{
|
||||||
|
Id: "u4",
|
||||||
|
Type: protoutil.GetTypeURL(new(user.User)),
|
||||||
|
Data: protoutil.NewAny(&session.Session{Id: "s1"}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var actual []*user.User
|
||||||
|
serverVersion, latestRecordVersion, err := databrokerpb.SyncLatestRecords(context.Background(), c, func(u *user.User) {
|
||||||
|
actual = append(actual, u)
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotZero(t, serverVersion)
|
||||||
|
assert.Equal(t, uint64(4), latestRecordVersion)
|
||||||
|
testutil.AssertProtoEqual(t, expected, actual)
|
||||||
|
}
|
|
@ -628,6 +628,8 @@ func (r ExchangeClusterIdentityTokenResp) StatusCode() int {
|
||||||
type ReportUsageResp struct {
|
type ReportUsageResp struct {
|
||||||
Body []byte
|
Body []byte
|
||||||
HTTPResponse *http.Response
|
HTTPResponse *http.Response
|
||||||
|
JSON400 *ErrorResponse
|
||||||
|
JSON500 *ErrorResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
// Status returns HTTPResponse.Status
|
// Status returns HTTPResponse.Status
|
||||||
|
@ -937,6 +939,23 @@ func ParseReportUsageResp(rsp *http.Response) (*ReportUsageResp, error) {
|
||||||
HTTPResponse: rsp,
|
HTTPResponse: rsp,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400:
|
||||||
|
var dest ErrorResponse
|
||||||
|
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response.JSON400 = &dest
|
||||||
|
|
||||||
|
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500:
|
||||||
|
var dest ErrorResponse
|
||||||
|
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
response.JSON500 = &dest
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1085,6 +1104,22 @@ func (r *ReportUsageResp) GetHTTPResponse() *http.Response {
|
||||||
return r.HTTPResponse
|
return r.HTTPResponse
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetBadRequestError implements apierror.APIResponse
|
||||||
|
func (r *ReportUsageResp) GetBadRequestError() (string, bool) {
|
||||||
|
if r.JSON400 == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return r.JSON400.Error, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetInternalServerError implements apierror.APIResponse
|
||||||
|
func (r *ReportUsageResp) GetInternalServerError() (string, bool) {
|
||||||
|
if r.JSON500 == nil {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
return r.JSON500.Error, true
|
||||||
|
}
|
||||||
|
|
||||||
// GetValue implements apierror.APIResponse
|
// GetValue implements apierror.APIResponse
|
||||||
func (r *ReportUsageResp) GetValue() *EmptyResponse {
|
func (r *ReportUsageResp) GetValue() *EmptyResponse {
|
||||||
if r.StatusCode()/100 != 2 {
|
if r.StatusCode()/100 != 2 {
|
||||||
|
|
|
@ -163,6 +163,18 @@ paths:
|
||||||
responses:
|
responses:
|
||||||
"204":
|
"204":
|
||||||
description: OK
|
description: OK
|
||||||
|
"400":
|
||||||
|
description: Bad Request
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ErrorResponse"
|
||||||
|
"500":
|
||||||
|
description: Internal Server Error
|
||||||
|
content:
|
||||||
|
application/json:
|
||||||
|
schema:
|
||||||
|
$ref: "#/components/schemas/ErrorResponse"
|
||||||
|
|
||||||
components:
|
components:
|
||||||
parameters:
|
parameters:
|
||||||
|
|
|
@ -534,6 +534,24 @@ func (response ReportUsage204Response) VisitReportUsageResponse(w http.ResponseW
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ReportUsage400JSONResponse ErrorResponse
|
||||||
|
|
||||||
|
func (response ReportUsage400JSONResponse) VisitReportUsageResponse(w http.ResponseWriter) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(400)
|
||||||
|
|
||||||
|
return json.NewEncoder(w).Encode(response)
|
||||||
|
}
|
||||||
|
|
||||||
|
type ReportUsage500JSONResponse ErrorResponse
|
||||||
|
|
||||||
|
func (response ReportUsage500JSONResponse) VisitReportUsageResponse(w http.ResponseWriter) error {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(500)
|
||||||
|
|
||||||
|
return json.NewEncoder(w).Encode(response)
|
||||||
|
}
|
||||||
|
|
||||||
// StrictServerInterface represents all server handlers.
|
// StrictServerInterface represents all server handlers.
|
||||||
type StrictServerInterface interface {
|
type StrictServerInterface interface {
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue