From 146efc1b1397296794aadf2ec5124d414be9aec5 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 12 Sep 2024 15:45:54 -0600 Subject: [PATCH] core/zero: add usage reporter (#5281) * wip * add response * handle empty email * use set, update log * add test * add coalesce, comments, test * add test, fix bug * use builtin cmp.Or * remove wait ready call * use api error --- go.mod | 1 + go.sum | 6 +- internal/testutil/grpc.go | 44 ++++ internal/zero/api/api.go | 7 + internal/zero/controller/controller.go | 10 + .../controller/usagereporter/usagereporter.go | 217 ++++++++++++++++++ .../usagereporter/usagereporter_test.go | 156 +++++++++++++ pkg/cryptutil/pseudonymize.go | 16 ++ pkg/grpc/databroker/sync.go | 110 +++++++++ pkg/grpc/databroker/sync_test.go | 67 ++++++ pkg/zero/cluster/client.gen.go | 35 +++ pkg/zero/cluster/openapi.yaml | 12 + pkg/zero/cluster/server.gen.go | 18 ++ 13 files changed, 697 insertions(+), 2 deletions(-) create mode 100644 internal/testutil/grpc.go create mode 100644 internal/zero/controller/usagereporter/usagereporter.go create mode 100644 internal/zero/controller/usagereporter/usagereporter_test.go create mode 100644 pkg/cryptutil/pseudonymize.go create mode 100644 pkg/grpc/databroker/sync.go create mode 100644 pkg/grpc/databroker/sync_test.go diff --git a/go.mod b/go.mod index 46dea1804..c6299a894 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( github.com/gregjones/httpcache v0.0.0-20190611155906-901d90724c79 github.com/grpc-ecosystem/go-grpc-middleware/v2 v2.1.0 github.com/hashicorp/go-multierror v1.1.1 + github.com/hashicorp/go-set/v3 v3.0.0-alpha.1 github.com/hashicorp/golang-lru/v2 v2.0.7 github.com/jackc/pgx/v5 v5.6.0 github.com/jxskiss/base62 v1.1.0 diff --git a/go.sum b/go.sum index a8a0203cb..4021f4fcd 100644 --- a/go.sum +++ b/go.sum @@ -375,6 +375,8 @@ github.com/hashicorp/errwrap v1.1.0 h1:OxrOeh75EUXMY8TBjag2fzXGZ40LB6IKw45YeGUDY github.com/hashicorp/errwrap v1.1.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4= github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo= github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM= +github.com/hashicorp/go-set/v3 v3.0.0-alpha.1 h1:dPUtuqKJGgxtF7YO42oE+NdUONXi5nfLMKH2NpBffIM= +github.com/hashicorp/go-set/v3 v3.0.0-alpha.1/go.mod h1:7bJRgsF3EL3AtRTzcKXdjAFbYGSef+1gHXhglGGO52k= github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8= github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= @@ -588,8 +590,8 @@ github.com/shirou/gopsutil/v3 v3.24.5 h1:i0t8kL+kQTvpAYToeuiVk3TgDeKOFioZO3Ztz/i github.com/shirou/gopsutil/v3 v3.24.5/go.mod h1:bsoOS1aStSs9ErQ1WWfxllSeS1K5D+U30r2NfcubMVk= github.com/shoenig/go-m1cpu v0.1.6 h1:nxdKQNcEB6vzgA2E2bvzKIYRuNj7XNJ4S/aRSwKzFtM= github.com/shoenig/go-m1cpu v0.1.6/go.mod h1:1JJMcUBvfNwpq05QDQVAnx3gUHr9IYF7GNg9SUEw2VQ= -github.com/shoenig/test v0.6.4 h1:kVTaSd7WLz5WZ2IaoM0RSzRsUD+m8wRR+5qvntpn4LU= -github.com/shoenig/test v0.6.4/go.mod h1:byHiCGXqrVaflBLAMq/srcZIHynQPQgeyvkvXnjqq0k= +github.com/shoenig/test v1.8.2 h1:WDlty8UBqJRdmgdJX8lMwvCq97tiN7Um/GZD2vBDuug= +github.com/shoenig/test v1.8.2/go.mod h1:UxJ6u/x2v/TNs/LoLxBNJRV9DiwBBKYxXSyczsBHFoI= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= github.com/sirupsen/logrus v1.6.0/go.mod h1:7uNnSEd1DgxDLC74fIahvMZmmYsHGZGEOFrfsX/uA88= diff --git a/internal/testutil/grpc.go b/internal/testutil/grpc.go new file mode 100644 index 000000000..3b34f4575 --- /dev/null +++ b/internal/testutil/grpc.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +// NewGRPCServer starts a gRPC server and returns a client connection to it. +func NewGRPCServer(t testing.TB, register func(s *grpc.Server)) *grpc.ClientConn { + t.Helper() + + li := bufconn.Listen(1024 * 1024) + s := grpc.NewServer() + register(s) + go func() { + err := s.Serve(li) + if errors.Is(err, grpc.ErrServerStopped) { + err = nil + } + require.NoError(t, err) + }() + t.Cleanup(func() { + s.Stop() + }) + + cc, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return li.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { + cc.Close() + }) + + return cc +} diff --git a/internal/zero/api/api.go b/internal/zero/api/api.go index b83d9b1a7..cfb437faa 100644 --- a/internal/zero/api/api.go +++ b/internal/zero/api/api.go @@ -119,3 +119,10 @@ func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.Get func (api *API) GetTelemetryConn() *grpc.ClientConn { return api.telemetryConn } + +func (api *API) ReportUsage(ctx context.Context, req cluster_api.ReportUsageRequest) error { + _, err := apierror.CheckResponse( + api.cluster.ReportUsageWithResponse(ctx, req), + ) + return err +} diff --git a/internal/zero/controller/controller.go b/internal/zero/controller/controller.go index f53590611..591111a76 100644 --- a/internal/zero/controller/controller.go +++ b/internal/zero/controller/controller.go @@ -17,6 +17,7 @@ import ( "github.com/pomerium/pomerium/internal/zero/bootstrap" "github.com/pomerium/pomerium/internal/zero/bootstrap/writers" connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux" + "github.com/pomerium/pomerium/internal/zero/controller/usagereporter" "github.com/pomerium/pomerium/internal/zero/healthcheck" "github.com/pomerium/pomerium/internal/zero/reconciler" "github.com/pomerium/pomerium/internal/zero/telemetry" @@ -160,6 +161,7 @@ func (c *controller) runZeroControlLoop(ctx context.Context) error { c.runSessionAnalyticsLeased, c.runHealthChecksLeased, leaseStatus.MonitorLease, + c.runUsageReporter, ), ) }) @@ -208,6 +210,14 @@ func (c *controller) runHealthChecksLeased(ctx context.Context, client databroke }) } +func (c *controller) runUsageReporter(ctx context.Context, client databroker.DataBrokerServiceClient) error { + ur := usagereporter.New(c.api, c.bootstrapConfig.GetConfig().ZeroOrganizationID, time.Minute) + return retry.WithBackoff(ctx, "zero-usage-reporter", func(ctx context.Context) error { + // start the usage reporter + return ur.Run(ctx, client) + }) +} + func (c *controller) getEnvoyScrapeURL() string { return (&url.URL{ Scheme: "http", diff --git a/internal/zero/controller/usagereporter/usagereporter.go b/internal/zero/controller/usagereporter/usagereporter.go new file mode 100644 index 000000000..5ad6a039f --- /dev/null +++ b/internal/zero/controller/usagereporter/usagereporter.go @@ -0,0 +1,217 @@ +// Package usagereporter reports usage for a cluster. +// +// Usage is determined from session and user records in the databroker. The usage reporter +// uses SyncLatest and Sync to retrieve this data, builds a collection of records and then +// sends them to the Zero Cluster API every minute. +// +// All usage users are reported on start but only the changed users are reported while running. +// The Zero Cluster API is tolerant of redundant data. +package usagereporter + +import ( + "cmp" + "context" + "sync" + "time" + + backoff "github.com/cenkalti/backoff/v4" + set "github.com/hashicorp/go-set/v3" + "golang.org/x/sync/errgroup" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/zero/cluster" +) + +// API is the part of the Zero Cluster API used to report usage. +type API interface { + ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error +} + +type usageReporterRecord struct { + userID string + userEmail string + lastSignedInAt time.Time +} + +// A UsageReporter reports usage to the zero api. +type UsageReporter struct { + api API + organizationID string + reportInterval time.Duration + + mu sync.Mutex + byUserID map[string]usageReporterRecord + updates *set.Set[string] +} + +// New creates a new UsageReporter. +func New(api API, organizationID string, reportInterval time.Duration) *UsageReporter { + return &UsageReporter{ + api: api, + organizationID: organizationID, + reportInterval: reportInterval, + + byUserID: make(map[string]usageReporterRecord), + updates: set.New[string](0), + } +} + +// Run runs the usage reporter. +func (ur *UsageReporter) Run(ctx context.Context, client databroker.DataBrokerServiceClient) error { + ctx = log.Ctx(ctx).With().Str("organization-id", ur.organizationID).Logger().WithContext(ctx) + + // first initialize the user collection + serverVersion, latestSessionRecordVersion, latestUserRecordVersion, err := ur.runInit(ctx, client) + if err != nil { + return err + } + + // run the continuous sync calls and periodically report usage + return ur.runSync(ctx, client, serverVersion, latestSessionRecordVersion, latestUserRecordVersion) +} + +func (ur *UsageReporter) report(ctx context.Context, records []usageReporterRecord) error { + req := cluster.ReportUsageRequest{ + Users: convertUsageReporterRecords(ur.organizationID, records), + } + return backoff.Retry(func() error { + log.Debug(ctx).Int("updated-users", len(req.Users)).Msg("reporting usage") + err := ur.api.ReportUsage(ctx, req) + if err != nil { + log.Warn(ctx).Err(err).Msg("error reporting usage") + } + return err + }, backoff.WithContext(backoff.NewExponentialBackOff(), ctx)) +} + +func (ur *UsageReporter) runInit( + ctx context.Context, + client databroker.DataBrokerServiceClient, +) (serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, err error) { + _, latestSessionRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateSession) + if err != nil { + return 0, 0, 0, err + } + + serverVersion, latestUserRecordVersion, err = databroker.SyncLatestRecords(ctx, client, ur.onUpdateUser) + if err != nil { + return 0, 0, 0, err + } + + return serverVersion, latestSessionRecordVersion, latestUserRecordVersion, nil +} + +func (ur *UsageReporter) runSync( + ctx context.Context, + client databroker.DataBrokerServiceClient, + serverVersion, latestSessionRecordVersion, latestUserRecordVersion uint64, +) error { + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return databroker.SyncRecords(ctx, client, serverVersion, latestSessionRecordVersion, ur.onUpdateSession) + }) + eg.Go(func() error { + return databroker.SyncRecords(ctx, client, serverVersion, latestUserRecordVersion, ur.onUpdateUser) + }) + eg.Go(func() error { + return ur.runReporter(ctx) + }) + return eg.Wait() +} + +func (ur *UsageReporter) runReporter(ctx context.Context) error { + // every minute collect any updates and submit them to the API + timer := time.NewTicker(ur.reportInterval) + defer timer.Stop() + + for { + // collect the updated records since last run + ur.mu.Lock() + records := make([]usageReporterRecord, 0, ur.updates.Size()) + for userID := range ur.updates.Items() { + records = append(records, ur.byUserID[userID]) + } + ur.updates = set.New[string](0) + ur.mu.Unlock() + + if len(records) > 0 { + err := ur.report(ctx, records) + if err != nil { + return err + } + } + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + } + } +} + +func (ur *UsageReporter) onUpdateSession(s *session.Session) { + userID := s.GetUserId() + if userID == "" { + // ignore sessions without a user id + return + } + + ur.mu.Lock() + defer ur.mu.Unlock() + + r := ur.byUserID[userID] + nr := r + nr.lastSignedInAt = latest(nr.lastSignedInAt, s.GetIssuedAt().AsTime()) + nr.userID = userID + if nr != r { + ur.byUserID[userID] = nr + ur.updates.Insert(userID) + } +} + +func (ur *UsageReporter) onUpdateUser(u *user.User) { + userID := u.GetId() + if userID == "" { + // ignore users without a user id + return + } + + ur.mu.Lock() + defer ur.mu.Unlock() + + r := ur.byUserID[userID] + nr := r + nr.userID = userID + nr.userEmail = cmp.Or(nr.userEmail, u.GetEmail()) + if nr != r { + ur.byUserID[userID] = nr + ur.updates.Insert(userID) + } +} + +func convertUsageReporterRecords(organizationID string, records []usageReporterRecord) []cluster.ReportUsageUser { + var users []cluster.ReportUsageUser + for _, record := range records { + u := cluster.ReportUsageUser{ + LastSignedInAt: record.lastSignedInAt, + PseudonymousId: cryptutil.Pseudonymize(organizationID, record.userID), + } + if record.userEmail != "" { + u.PseudonymousEmail = cryptutil.Pseudonymize(organizationID, record.userEmail) + } + users = append(users, u) + } + return users +} + +// latest returns the latest time. +func latest(t1, t2 time.Time) time.Time { + if t2.After(t1) { + return t2 + } + return t1 +} diff --git a/internal/zero/controller/usagereporter/usagereporter_test.go b/internal/zero/controller/usagereporter/usagereporter_test.go new file mode 100644 index 000000000..6f031fd95 --- /dev/null +++ b/internal/zero/controller/usagereporter/usagereporter_test.go @@ -0,0 +1,156 @@ +package usagereporter + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/zero/cluster" +) + +type mockAPI struct { + reportUsage func(ctx context.Context, req cluster.ReportUsageRequest) error +} + +func (m mockAPI) ReportUsage(ctx context.Context, req cluster.ReportUsageRequest) error { + return m.reportUsage(ctx, req) +} + +func TestUsageReporter(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + t.Cleanup(clearTimeout) + + ctx, cancel := context.WithCancel(ctx) + t.Cleanup(cancel) + + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + }) + t.Cleanup(func() { cc.Close() }) + + tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC) + + requests := make(chan cluster.ReportUsageRequest, 1) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + ur := New(mockAPI{ + reportUsage: func(ctx context.Context, req cluster.ReportUsageRequest) error { + select { + case <-ctx.Done(): + return ctx.Err() + case requests <- req: + } + return nil + }, + }, "bQjwPpxcwJRbvsSMFgbZFkXmxFJ", time.Millisecond*100) + + eg, ctx := errgroup.WithContext(ctx) + eg.Go(func() error { + return ur.Run(ctx, client) + }) + eg.Go(func() error { + _, err := databrokerpb.Put(ctx, client, + &session.Session{ + Id: "S1a", + UserId: "U1", + IssuedAt: timestamppb.New(tm1), + }, + &session.Session{ + Id: "S1b", + UserId: "U1", + IssuedAt: timestamppb.New(tm1), + }) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requests: + assert.Equal(t, cluster.ReportUsageRequest{ + Users: []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=", + }}, + }, req, "should send a single usage record") + } + + _, err = databrokerpb.Put(ctx, client, + &user.User{ + Id: "U1", + Email: "u1@example.com", + }) + if err != nil { + return err + } + + select { + case <-ctx.Done(): + return ctx.Err() + case req := <-requests: + assert.Equal(t, cluster.ReportUsageRequest{ + Users: []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousEmail: "iq8/fj+uZaKitkWY12JIQgKJ5KIP+E0Cmy/HpxpdBXY=", + PseudonymousId: "095xqqsjEEgYf5Yf+TAjWjooMQyh6jSV5SCPGe9eqvg=", + }}, + }, req, "should send another usage record with the email set") + } + + cancel() + return nil + }) + err := eg.Wait() + if err != nil && !errors.Is(ctx.Err(), context.Canceled) { + assert.NoError(t, err) + } +} + +func Test_convertUsageReporterRecords(t *testing.T) { + t.Parallel() + + tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC) + + assert.Empty(t, convertUsageReporterRecords("XXX", nil)) + assert.Equal(t, []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousId: "T9V1yL/UueF/LVuF6XjoSNde0INElXG10zKepmyPke8=", + PseudonymousEmail: "8w5rtnZyv0EGkpHmTlkmupgb1jCzn/IxGCfvpdGGnvI=", + }}, convertUsageReporterRecords("XXX", []usageReporterRecord{{ + userID: "ID", + userEmail: "EMAIL@example.com", + lastSignedInAt: tm1, + }})) + assert.Equal(t, []cluster.ReportUsageUser{{ + LastSignedInAt: tm1, + PseudonymousId: "T9V1yL/UueF/LVuF6XjoSNde0INElXG10zKepmyPke8=", + }}, convertUsageReporterRecords("XXX", []usageReporterRecord{{ + userID: "ID", + lastSignedInAt: tm1, + }}), "should leave empty email") +} + +func Test_latest(t *testing.T) { + t.Parallel() + + tm1 := time.Date(2024, time.September, 11, 11, 56, 0, 0, time.UTC) + tm2 := time.Date(2024, time.September, 12, 11, 56, 0, 0, time.UTC) + + assert.Equal(t, tm2, latest(tm1, tm2)) + assert.Equal(t, tm2, latest(tm2, tm1), "should ignore ordering") + assert.Equal(t, tm1, latest(tm1, time.Time{}), "should handle zero time") +} diff --git a/pkg/cryptutil/pseudonymize.go b/pkg/cryptutil/pseudonymize.go new file mode 100644 index 000000000..c791a5696 --- /dev/null +++ b/pkg/cryptutil/pseudonymize.go @@ -0,0 +1,16 @@ +package cryptutil + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/base64" + "io" +) + +// Pseudonymize pseudonymizes data by computing the HMAC-SHA256 of the data. +func Pseudonymize(organizationID string, data string) string { + h := hmac.New(sha256.New, []byte(organizationID)) + _, _ = io.WriteString(h, data) + bs := h.Sum(nil) + return base64.StdEncoding.EncodeToString(bs) +} diff --git a/pkg/grpc/databroker/sync.go b/pkg/grpc/databroker/sync.go new file mode 100644 index 000000000..afbc66f84 --- /dev/null +++ b/pkg/grpc/databroker/sync.go @@ -0,0 +1,110 @@ +package databroker + +import ( + "context" + "errors" + "fmt" + "io" + + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +// SyncRecords calls fn for every record using Sync. +func SyncRecords[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + client DataBrokerServiceClient, + serverVersion, latestRecordVersion uint64, + fn func(TMessage), +) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var msg TMessage = new(T) + stream, err := client.Sync(ctx, &SyncRequest{ + Type: protoutil.GetTypeURL(msg), + ServerVersion: serverVersion, + RecordVersion: latestRecordVersion, + }) + if err != nil { + return fmt.Errorf("error syncing %T: %w", msg, err) + } + + for { + res, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + return nil + case err != nil: + return fmt.Errorf("error receiving record for %T: %w", msg, err) + } + + msg = new(T) + err = res.GetRecord().GetData().UnmarshalTo(msg) + if err != nil { + log.Ctx(ctx).Error().Err(err). + Str("record-type", res.Record.Type). + Str("record-id", res.Record.GetId()). + Msgf("unexpected data in %T stream", msg) + continue + } + + fn(msg) + } +} + +// SyncLatestRecords calls fn for every record using SyncLatest. +func SyncLatestRecords[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + client DataBrokerServiceClient, + fn func(TMessage), +) (serverVersion, latestRecordVersion uint64, err error) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var msg TMessage = new(T) + stream, err := client.SyncLatest(ctx, &SyncLatestRequest{ + Type: protoutil.GetTypeURL(msg), + }) + if err != nil { + return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err) + } + + for { + res, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + return serverVersion, latestRecordVersion, nil + case err != nil: + return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err) + } + + switch res := res.GetResponse().(type) { + case *SyncLatestResponse_Versions: + serverVersion = res.Versions.GetServerVersion() + latestRecordVersion = res.Versions.GetLatestRecordVersion() + case *SyncLatestResponse_Record: + msg = new(T) + err = res.Record.GetData().UnmarshalTo(msg) + if err != nil { + log.Ctx(ctx).Error().Err(err). + Str("record-type", res.Record.Type). + Str("record-id", res.Record.GetId()). + Msgf("unexpected data in latest %T stream", msg) + continue + } + + fn(msg) + default: + panic(fmt.Sprintf("unexpected response: %T", res)) + } + } +} diff --git a/pkg/grpc/databroker/sync_test.go b/pkg/grpc/databroker/sync_test.go new file mode 100644 index 000000000..85a454876 --- /dev/null +++ b/pkg/grpc/databroker/sync_test.go @@ -0,0 +1,67 @@ +package databroker_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + grpc "google.golang.org/grpc" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/testutil" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func Test_SyncLatestRecords(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute) + defer clearTimeout() + + cc := testutil.NewGRPCServer(t, func(s *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(s, databroker.New()) + }) + + c := databrokerpb.NewDataBrokerServiceClient(cc) + + expected := []*user.User{ + {Id: "u1"}, + {Id: "u2"}, + {Id: "u3"}, + } + + for _, u := range expected { + _, err := c.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{ + databrokerpb.NewRecord(u), + }, + }) + require.NoError(t, err) + } + + // add a non-user record to make sure it gets ignored + _, err := c.Put(ctx, &databrokerpb.PutRequest{ + Records: []*databrokerpb.Record{ + { + Id: "u4", + Type: protoutil.GetTypeURL(new(user.User)), + Data: protoutil.NewAny(&session.Session{Id: "s1"}), + }, + }, + }) + require.NoError(t, err) + + var actual []*user.User + serverVersion, latestRecordVersion, err := databrokerpb.SyncLatestRecords(context.Background(), c, func(u *user.User) { + actual = append(actual, u) + }) + assert.NoError(t, err) + assert.NotZero(t, serverVersion) + assert.Equal(t, uint64(4), latestRecordVersion) + testutil.AssertProtoEqual(t, expected, actual) +} diff --git a/pkg/zero/cluster/client.gen.go b/pkg/zero/cluster/client.gen.go index 4ec9ec035..591a435a1 100644 --- a/pkg/zero/cluster/client.gen.go +++ b/pkg/zero/cluster/client.gen.go @@ -628,6 +628,8 @@ func (r ExchangeClusterIdentityTokenResp) StatusCode() int { type ReportUsageResp struct { Body []byte HTTPResponse *http.Response + JSON400 *ErrorResponse + JSON500 *ErrorResponse } // Status returns HTTPResponse.Status @@ -937,6 +939,23 @@ func ParseReportUsageResp(rsp *http.Response) (*ReportUsageResp, error) { HTTPResponse: rsp, } + switch { + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400: + var dest ErrorResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON400 = &dest + + case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500: + var dest ErrorResponse + if err := json.Unmarshal(bodyBytes, &dest); err != nil { + return nil, err + } + response.JSON500 = &dest + + } + return response, nil } @@ -1085,6 +1104,22 @@ func (r *ReportUsageResp) GetHTTPResponse() *http.Response { return r.HTTPResponse } +// GetBadRequestError implements apierror.APIResponse +func (r *ReportUsageResp) GetBadRequestError() (string, bool) { + if r.JSON400 == nil { + return "", false + } + return r.JSON400.Error, true +} + +// GetInternalServerError implements apierror.APIResponse +func (r *ReportUsageResp) GetInternalServerError() (string, bool) { + if r.JSON500 == nil { + return "", false + } + return r.JSON500.Error, true +} + // GetValue implements apierror.APIResponse func (r *ReportUsageResp) GetValue() *EmptyResponse { if r.StatusCode()/100 != 2 { diff --git a/pkg/zero/cluster/openapi.yaml b/pkg/zero/cluster/openapi.yaml index 10d20b0fb..451da1fdd 100644 --- a/pkg/zero/cluster/openapi.yaml +++ b/pkg/zero/cluster/openapi.yaml @@ -163,6 +163,18 @@ paths: responses: "204": description: OK + "400": + description: Bad Request + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + "500": + description: Internal Server Error + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" components: parameters: diff --git a/pkg/zero/cluster/server.gen.go b/pkg/zero/cluster/server.gen.go index d8cd60eec..b97bbaa15 100644 --- a/pkg/zero/cluster/server.gen.go +++ b/pkg/zero/cluster/server.gen.go @@ -534,6 +534,24 @@ func (response ReportUsage204Response) VisitReportUsageResponse(w http.ResponseW return nil } +type ReportUsage400JSONResponse ErrorResponse + +func (response ReportUsage400JSONResponse) VisitReportUsageResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(400) + + return json.NewEncoder(w).Encode(response) +} + +type ReportUsage500JSONResponse ErrorResponse + +func (response ReportUsage500JSONResponse) VisitReportUsageResponse(w http.ResponseWriter) error { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(500) + + return json.NewEncoder(w).Encode(response) +} + // StrictServerInterface represents all server handlers. type StrictServerInterface interface {