mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 21:04:39 +02:00
Revert "core/grpc: add IterateAll method" (#5234)
Revert "core/grpc: add IterateAll method (#5227)"
This reverts commit 3961098681
.
This commit is contained in:
parent
99d7a73cef
commit
98cea10421
3 changed files with 18 additions and 198 deletions
|
@ -8,32 +8,43 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/sets"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
var sessionTypeURL = protoutil.GetTypeURL(new(session.Session))
|
||||
|
||||
// CurrentUsers returns a list of users active within the current UTC day
|
||||
func CurrentUsers(
|
||||
ctx context.Context,
|
||||
client databroker.DataBrokerServiceClient,
|
||||
) ([]string, error) {
|
||||
records, _, _, err := databroker.InitialSync(ctx, client, &databroker.SyncLatestRequest{
|
||||
Type: sessionTypeURL,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("fetching sessions: %w", err)
|
||||
}
|
||||
|
||||
users := sets.NewHash[string]()
|
||||
utcNow := time.Now().UTC()
|
||||
threshold := time.Date(utcNow.Year(), utcNow.Month(), utcNow.Day(), 0, 0, 0, 0, time.UTC)
|
||||
|
||||
for s, err := range databroker.IterateAll[session.Session](ctx, client) {
|
||||
for _, record := range records {
|
||||
var s session.Session
|
||||
err := record.GetData().UnmarshalTo(&s)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error fetching sessions: %w", err)
|
||||
return nil, fmt.Errorf("unmarshaling session: %w", err)
|
||||
}
|
||||
|
||||
if s.Object.GetUserId() == "" { // session creation is in progress
|
||||
if s.UserId == "" { // session creation is in progress
|
||||
continue
|
||||
}
|
||||
if s.Object.GetAccessedAt() == nil {
|
||||
if s.AccessedAt == nil {
|
||||
continue
|
||||
}
|
||||
if s.Object.GetAccessedAt().AsTime().Before(threshold) {
|
||||
if s.AccessedAt.AsTime().Before(threshold) {
|
||||
continue
|
||||
}
|
||||
users.Add(s.Object.GetUserId())
|
||||
users.Add(s.UserId)
|
||||
}
|
||||
|
||||
return users.Items(), nil
|
||||
|
|
|
@ -1,81 +0,0 @@
|
|||
package databroker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"iter"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
// A GenericRecord is a record with its associated unmarshaled object.
|
||||
type GenericRecord[T proto.Message] struct {
|
||||
*Record
|
||||
Object T
|
||||
}
|
||||
|
||||
// IterateAll iterates through all the records using a SyncLatest call.
|
||||
func IterateAll[T any, TMessage interface {
|
||||
*T
|
||||
proto.Message
|
||||
}](
|
||||
ctx context.Context,
|
||||
client DataBrokerServiceClient,
|
||||
) iter.Seq2[GenericRecord[TMessage], error] {
|
||||
var zero GenericRecord[TMessage]
|
||||
return func(yield func(GenericRecord[TMessage], error) bool) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
var msg any = new(T)
|
||||
stream, err := client.SyncLatest(ctx, &SyncLatestRequest{
|
||||
Type: protoutil.GetTypeURL(msg.(TMessage)),
|
||||
})
|
||||
if err != nil {
|
||||
_ = yield(zero, err)
|
||||
return
|
||||
}
|
||||
|
||||
for {
|
||||
res, err := stream.Recv()
|
||||
switch {
|
||||
case errors.Is(err, io.EOF):
|
||||
// all done
|
||||
return
|
||||
case err != nil:
|
||||
_ = yield(zero, err)
|
||||
return
|
||||
}
|
||||
|
||||
switch res := res.GetResponse().(type) {
|
||||
case *SyncLatestResponse_Versions:
|
||||
// ignore versions
|
||||
continue
|
||||
case *SyncLatestResponse_Record:
|
||||
// ignore records with no data
|
||||
if res.Record.GetData() == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
gr := GenericRecord[TMessage]{
|
||||
Record: res.Record,
|
||||
}
|
||||
var msg any = new(T)
|
||||
gr.Object = msg.(TMessage)
|
||||
err = res.Record.GetData().UnmarshalTo(gr.Object)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Str("record-type", res.Record.GetType()).Str("record-id", res.Record.GetId()).Msg("databroker: unexpected object found in databroker record")
|
||||
} else if !yield(gr, nil) {
|
||||
return
|
||||
}
|
||||
default:
|
||||
panic(fmt.Sprintf("unexpected response: %T", res))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
|
@ -1,110 +0,0 @@
|
|||
package databroker_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
grpc "google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/user"
|
||||
)
|
||||
|
||||
func TestIterateAll(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
defer li.Close()
|
||||
|
||||
r1 := databroker.NewRecord(&session.Session{
|
||||
Id: "s1",
|
||||
})
|
||||
r2 := databroker.NewRecord(&user.User{
|
||||
Id: "u1",
|
||||
})
|
||||
r3 := databroker.NewRecord(&session.Session{
|
||||
Id: "s2",
|
||||
})
|
||||
r4 := &databroker.Record{
|
||||
Id: "unknown1",
|
||||
Type: "type.googleapis.com/session.Session",
|
||||
}
|
||||
|
||||
m := &mockServer{
|
||||
syncLatest: func(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
||||
assert.Equal(t, "type.googleapis.com/session.Session", req.GetType())
|
||||
|
||||
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Versions{
|
||||
Versions: &databroker.Versions{
|
||||
ServerVersion: 123,
|
||||
LatestRecordVersion: 1,
|
||||
},
|
||||
},
|
||||
}))
|
||||
|
||||
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Record{
|
||||
Record: r1,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Record{
|
||||
Record: r2,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Record{
|
||||
Record: r3,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, stream.Send(&databroker.SyncLatestResponse{
|
||||
Response: &databroker.SyncLatestResponse_Record{
|
||||
Record: r4,
|
||||
},
|
||||
}))
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
srv := grpc.NewServer()
|
||||
databroker.RegisterDataBrokerServiceServer(srv, m)
|
||||
go srv.Serve(li)
|
||||
|
||||
cc, err := grpc.NewClient(li.Addr().String(), grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer cc.Close()
|
||||
|
||||
var records []*databroker.Record
|
||||
c := databroker.NewDataBrokerServiceClient(cc)
|
||||
for record, err := range databroker.IterateAll[session.Session](ctx, c) {
|
||||
require.NoError(t, err)
|
||||
records = append(records, record.Record)
|
||||
}
|
||||
|
||||
testutil.AssertProtoEqual(t, []*databroker.Record{r1, r3}, records)
|
||||
}
|
||||
|
||||
type mockServer struct {
|
||||
databroker.DataBrokerServiceServer
|
||||
|
||||
syncLatest func(*databroker.SyncLatestRequest, databroker.DataBrokerService_SyncLatestServer) error
|
||||
}
|
||||
|
||||
func (m *mockServer) SyncLatest(req *databroker.SyncLatestRequest, stream databroker.DataBrokerService_SyncLatestServer) error {
|
||||
return m.syncLatest(req, stream)
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue