From ae4d266fa4ff925572e6810e79409a5a4b58c687 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Mon, 19 Aug 2024 16:51:23 -0600 Subject: [PATCH] core/grpc: add IterateAll method --- internal/zero/telemetry/sessions/sessions.go | 25 ++--- pkg/grpc/databroker/generic.go | 81 ++++++++++++++ pkg/grpc/databroker/generic_test.go | 110 +++++++++++++++++++ 3 files changed, 198 insertions(+), 18 deletions(-) create mode 100644 pkg/grpc/databroker/generic.go create mode 100644 pkg/grpc/databroker/generic_test.go diff --git a/internal/zero/telemetry/sessions/sessions.go b/internal/zero/telemetry/sessions/sessions.go index bae7deb3c..a9dffdcc1 100644 --- a/internal/zero/telemetry/sessions/sessions.go +++ b/internal/zero/telemetry/sessions/sessions.go @@ -8,43 +8,32 @@ import ( "github.com/pomerium/pomerium/internal/sets" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" - "github.com/pomerium/pomerium/pkg/protoutil" ) -var sessionTypeURL = protoutil.GetTypeURL(new(session.Session)) - // CurrentUsers returns a list of users active within the current UTC day func CurrentUsers( ctx context.Context, client databroker.DataBrokerServiceClient, ) ([]string, error) { - records, _, _, err := databroker.InitialSync(ctx, client, &databroker.SyncLatestRequest{ - Type: sessionTypeURL, - }) - if err != nil { - return nil, fmt.Errorf("fetching sessions: %w", err) - } - users := sets.NewHash[string]() utcNow := time.Now().UTC() threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC) - for _, record := range records { - var s session.Session - err := record.GetData().UnmarshalTo(&s) + for s, err := range databroker.IterateAll[session.Session](ctx, client) { if err != nil { - return nil, fmt.Errorf("unmarshaling session: %w", err) + return nil, fmt.Errorf("error fetching sessions: %w", err) } - if s.UserId == "" { // session creation is in progress + + if s.Object.GetUserId() == "" { // session creation is in progress continue } - if s.AccessedAt == nil { + if s.Object.GetAccessedAt() == nil { continue } - if s.AccessedAt.AsTime().Before(threshold) { + if s.Object.GetAccessedAt().AsTime().Before(threshold) { continue } - users.Add(s.UserId) + users.Add(s.Object.GetUserId()) } return users.Items(), nil diff --git a/pkg/grpc/databroker/generic.go b/pkg/grpc/databroker/generic.go new file mode 100644 index 000000000..8e47bfc43 --- /dev/null +++ b/pkg/grpc/databroker/generic.go @@ -0,0 +1,81 @@ +package databroker + +import ( + "context" + "errors" + "fmt" + "io" + "iter" + + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +// A GenericRecord is a record with its associated unmarshaled object. +type GenericRecord[T proto.Message] struct { + *Record + Object T +} + +// IterateAll iterates through all the records using a SyncLatest call. +func IterateAll[T any, TMessage interface { + *T + proto.Message +}]( + ctx context.Context, + client DataBrokerServiceClient, +) iter.Seq2[GenericRecord[TMessage], error] { + var zero GenericRecord[TMessage] + return func(yield func(GenericRecord[TMessage], error) bool) { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + var msg any = new(T) + stream, err := client.SyncLatest(ctx, &SyncLatestRequest{ + Type: protoutil.GetTypeURL(msg.(TMessage)), + }) + if err != nil { + _ = yield(zero, err) + return + } + + for { + res, err := stream.Recv() + switch { + case errors.Is(err, io.EOF): + // all done + return + case err != nil: + _ = yield(zero, err) + return + } + + switch res := res.GetResponse().(type) { + case *SyncLatestResponse_Versions: + // ignore versions + continue + case *SyncLatestResponse_Record: + // ignore records with no data + if res.Record.GetData() == nil { + continue + } + + gr := GenericRecord[TMessage]{ + Record: res.Record, + } + var msg any = new(T) + gr.Object = msg.(TMessage) + err = res.Record.GetData().UnmarshalTo(gr.Object) + if err != nil { + log.Error(ctx).Err(err).Msg("databroker: unexpected object found in databroker record") + } else if !yield(gr, nil) { + return + } + default: + panic(fmt.Sprintf("unexpected response: %T", res)) + } + } + } +} diff --git a/pkg/grpc/databroker/generic_test.go b/pkg/grpc/databroker/generic_test.go new file mode 100644 index 000000000..651be3899 --- /dev/null +++ b/pkg/grpc/databroker/generic_test.go @@ -0,0 +1,110 @@ +package databroker_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + grpc "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + + "github.com/pomerium/pomerium/internal/testutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" +) + +func TestIterateAll(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + li, err := net.Listen("tcp", "127.0.0.1:0") + if !assert.NoError(t, err) { + return + } + defer li.Close() + + r1 := databroker.NewRecord(&session.Session{ + Id: "s1", + }) + r2 := databroker.NewRecord(&user.User{ + Id: "u1", + }) + r3 := databroker.NewRecord(&session.Session{ + Id: "s2", + }) + r4 := &databroker.Record{ + Id: "unknown1", + Type: "type.googleapis.com/session.Session", + } + + m := &mockServer{ + syncLatest: func(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error { + assert.Equal(t, "type.googleapis.com/session.Session", req.GetType()) + + require.NoError(t, stream.Send(&databroker.SyncLatestResponse{ + Response: &databroker.SyncLatestResponse_Versions{ + Versions: &databroker.Versions{ + ServerVersion: 123, + LatestRecordVersion: 1, + }, + }, + })) + + require.NoError(t, stream.Send(&databroker.SyncLatestResponse{ + Response: &databroker.SyncLatestResponse_Record{ + Record: r1, + }, + })) + require.NoError(t, stream.Send(&databroker.SyncLatestResponse{ + Response: &databroker.SyncLatestResponse_Record{ + Record: r2, + }, + })) + require.NoError(t, stream.Send(&databroker.SyncLatestResponse{ + Response: &databroker.SyncLatestResponse_Record{ + Record: r3, + }, + })) + require.NoError(t, stream.Send(&databroker.SyncLatestResponse{ + Response: &databroker.SyncLatestResponse_Record{ + Record: r4, + }, + })) + + return nil + }, + } + + srv := grpc.NewServer() + databroker.RegisterDataBrokerServiceServer(srv, m) + go srv.Serve(li) + + cc, err := grpc.NewClient(li.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + defer cc.Close() + + var records []*databroker.Record + c := databroker.NewDataBrokerServiceClient(cc) + for record, err := range databroker.IterateAll[session.Session](ctx, c) { + require.NoError(t, err) + records = append(records, record.Record) + } + + testutil.AssertProtoEqual(t, []*databroker.Record{r1, r3}, records) +} + +type mockServer struct { + databroker.DataBrokerServiceServer + + syncLatest func(*databroker.SyncLatestRequest, databroker.DataBrokerService_SyncLatestServer) error +} + +func (m *mockServer) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error { + return m.syncLatest(req, stream) +}