Revert "core/grpc: add IterateAll method" (#5234)

Revert "core/grpc: add IterateAll method (#5227)"

This reverts commit 3961098681.
This commit is contained in:
Caleb Doxsey 2024-08-23 10:35:46 -06:00 committed by GitHub
parent 99d7a73cef
commit 98cea10421
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 18 additions and 198 deletions

View file

@ -8,32 +8,43 @@ import (
"github.com/pomerium/pomerium/internal/sets" "github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
) )
var sessionTypeURL = protoutil.GetTypeURL(new(session.Session))
// CurrentUsers returns a list of users active within the current UTC day // CurrentUsers returns a list of users active within the current UTC day
func CurrentUsers( func CurrentUsers(
ctx context.Context, ctx context.Context,
client databroker.DataBrokerServiceClient, client databroker.DataBrokerServiceClient,
) ([]string, error) { ) ([]string, error) {
records, _, _, err := databroker.InitialSync(ctx, client, &databroker.SyncLatestRequest{
Type: sessionTypeURL,
})
if err != nil {
return nil, fmt.Errorf("fetching sessions: %w", err)
}
users := sets.NewHash[string]() users := sets.NewHash[string]()
utcNow := time.Now().UTC() utcNow := time.Now().UTC()
threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC) threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
for s, err := range databroker.IterateAll[session.Session](ctx, client) { for _, record := range records {
var s session.Session
err := record.GetData().UnmarshalTo(&s)
if err != nil { if err != nil {
return nil, fmt.Errorf("error fetching sessions: %w", err) return nil, fmt.Errorf("unmarshaling session: %w", err)
} }
if s.UserId == "" { // session creation is in progress
if s.Object.GetUserId() == "" { // session creation is in progress
continue continue
} }
if s.Object.GetAccessedAt() == nil { if s.AccessedAt == nil {
continue continue
} }
if s.Object.GetAccessedAt().AsTime().Before(threshold) { if s.AccessedAt.AsTime().Before(threshold) {
continue continue
} }
users.Add(s.Object.GetUserId()) users.Add(s.UserId)
} }
return users.Items(), nil return users.Items(), nil

View file

@ -1,81 +0,0 @@
package databroker
import (
"context"
"errors"
"fmt"
"io"
"iter"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// A GenericRecord is a record with its associated unmarshaled object.
type GenericRecord[T proto.Message] struct {
*Record
Object T
}
// IterateAll iterates through all the records using a SyncLatest call.
func IterateAll[T any, TMessage interface {
*T
proto.Message
}](
ctx context.Context,
client DataBrokerServiceClient,
) iter.Seq2[GenericRecord[TMessage], error] {
var zero GenericRecord[TMessage]
return func(yield func(GenericRecord[TMessage], error) bool) {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
var msg any = new(T)
stream, err := client.SyncLatest(ctx, &SyncLatestRequest{
Type: protoutil.GetTypeURL(msg.(TMessage)),
})
if err != nil {
_ = yield(zero, err)
return
}
for {
res, err := stream.Recv()
switch {
case errors.Is(err, io.EOF):
// all done
return
case err != nil:
_ = yield(zero, err)
return
}
switch res := res.GetResponse().(type) {
case *SyncLatestResponse_Versions:
// ignore versions
continue
case *SyncLatestResponse_Record:
// ignore records with no data
if res.Record.GetData() == nil {
continue
}
gr := GenericRecord[TMessage]{
Record: res.Record,
}
var msg any = new(T)
gr.Object = msg.(TMessage)
err = res.Record.GetData().UnmarshalTo(gr.Object)
if err != nil {
log.Error(ctx).Err(err).Str("record-type", res.Record.GetType()).Str("record-id", res.Record.GetId()).Msg("databroker: unexpected object found in databroker record")
} else if !yield(gr, nil) {
return
}
default:
panic(fmt.Sprintf("unexpected response: %T", res))
}
}
}
}

View file

@ -1,110 +0,0 @@
package databroker_test
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
)
func TestIterateAll(t *testing.T) {
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
li, err := net.Listen("tcp", "127.0.0.1:0")
if !assert.NoError(t, err) {
return
}
defer li.Close()
r1 := databroker.NewRecord(&session.Session{
Id: "s1",
})
r2 := databroker.NewRecord(&user.User{
Id: "u1",
})
r3 := databroker.NewRecord(&session.Session{
Id: "s2",
})
r4 := &databroker.Record{
Id: "unknown1",
Type: "type.googleapis.com/session.Session",
}
m := &mockServer{
syncLatest: func(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
assert.Equal(t, "type.googleapis.com/session.Session", req.GetType())
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Versions{
Versions: &databroker.Versions{
ServerVersion: 123,
LatestRecordVersion: 1,
},
},
}))
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Record{
Record: r1,
},
}))
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Record{
Record: r2,
},
}))
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Record{
Record: r3,
},
}))
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
Response: &databroker.SyncLatestResponse_Record{
Record: r4,
},
}))
return nil
},
}
srv := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(srv, m)
go srv.Serve(li)
cc, err := grpc.NewClient(li.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer cc.Close()
var records []*databroker.Record
c := databroker.NewDataBrokerServiceClient(cc)
for record, err := range databroker.IterateAll[session.Session](ctx, c) {
require.NoError(t, err)
records = append(records, record.Record)
}
testutil.AssertProtoEqual(t, []*databroker.Record{r1, r3}, records)
}
type mockServer struct {
databroker.DataBrokerServiceServer
syncLatest func(*databroker.SyncLatestRequest, databroker.DataBrokerService_SyncLatestServer) error
}
func (m *mockServer) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
return m.syncLatest(req, stream)
}