pomerium/pkg/grpc/databroker/sync.go
Caleb Doxsey 146efc1b13
core/zero: add usage reporter (#5281)
* wip

* add response

* handle empty email

* use set, update log

* add test

* add coalesce, comments, test

* add test, fix bug

* use builtin cmp.Or

* remove wait ready call

* use api error
2024-09-12 15:45:54 -06:00

110 lines
2.5 KiB
Go

package databroker
import (
"context"
"errors"
"fmt"
"io"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/protoutil"
)
// SyncRecords calls fn for every record of type T received from the
// databroker's Sync stream. The stream starts after the given server version
// and record version; these are typically the values returned by
// SyncLatestRecords.
//
// A record whose data cannot be unmarshaled into T is logged and skipped. A
// response that carries no record is handled the same way.
//
// SyncRecords returns nil when the server closes the stream (io.EOF). It
// returns a wrapped error if the stream cannot be opened or a receive fails.
func SyncRecords[T any, TMessage interface {
	*T
	proto.Message
}](
	ctx context.Context,
	client DataBrokerServiceClient,
	serverVersion, latestRecordVersion uint64,
	fn func(TMessage),
) error {
	// Cancel the derived context on return so the stream is torn down even
	// when we exit early because of an error.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// msg is used here only for its dynamic type. That type determines the
	// requested type URL and is what %T prints in the error messages below.
	var msg TMessage = new(T)
	stream, err := client.Sync(ctx, &SyncRequest{
		Type:          protoutil.GetTypeURL(msg),
		ServerVersion: serverVersion,
		RecordVersion: latestRecordVersion,
	})
	if err != nil {
		return fmt.Errorf("error syncing %T: %w", msg, err)
	}

	for {
		res, err := stream.Recv()
		switch {
		case errors.Is(err, io.EOF):
			return nil
		case err != nil:
			return fmt.Errorf("error receiving record for %T: %w", msg, err)
		}

		// Allocate a fresh message per record so that fn can keep a reference
		// to it without it being overwritten by the next record.
		msg = new(T)
		record := res.GetRecord()
		if err := record.GetData().UnmarshalTo(msg); err != nil {
			// Use the nil-safe getters here. A response without a record also
			// fails to unmarshal, and dereferencing it directly would panic.
			log.Ctx(ctx).Error().Err(err).
				Str("record-type", record.GetType()).
				Str("record-id", record.GetId()).
				Msgf("unexpected data in %T stream", msg)
			continue
		}
		fn(msg)
	}
}
// SyncLatestRecords calls fn for every record of type T received from the
// databroker's SyncLatest stream.
//
// It returns the server version and latest record version reported by the
// stream. If several Versions responses arrive, the last one wins. If none
// arrives, both values are 0. The returned values are suitable for passing
// to SyncRecords to continue syncing from this point.
//
// A record whose data cannot be unmarshaled into T is logged and skipped.
// A response of an unrecognized type, or an empty response, ends the sync
// with an error.
func SyncLatestRecords[T any, TMessage interface {
	*T
	proto.Message
}](
	ctx context.Context,
	client DataBrokerServiceClient,
	fn func(TMessage),
) (serverVersion, latestRecordVersion uint64, err error) {
	// Cancel the derived context on return so the stream is torn down even
	// when we exit early because of an error.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	// msg is used here only for its dynamic type. That type determines the
	// requested type URL and is what %T prints in the error messages below.
	var msg TMessage = new(T)
	stream, err := client.SyncLatest(ctx, &SyncLatestRequest{
		Type: protoutil.GetTypeURL(msg),
	})
	if err != nil {
		return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err)
	}

	for {
		res, err := stream.Recv()
		switch {
		case errors.Is(err, io.EOF):
			return serverVersion, latestRecordVersion, nil
		case err != nil:
			return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err)
		}

		switch res := res.GetResponse().(type) {
		case *SyncLatestResponse_Versions:
			serverVersion = res.Versions.GetServerVersion()
			latestRecordVersion = res.Versions.GetLatestRecordVersion()
		case *SyncLatestResponse_Record:
			// Allocate a fresh message per record so that fn can keep a
			// reference to it without it being overwritten by the next record.
			msg = new(T)
			if err := res.Record.GetData().UnmarshalTo(msg); err != nil {
				// Use the nil-safe getters here. A nil Record also fails to
				// unmarshal, and dereferencing it directly would panic.
				log.Ctx(ctx).Error().Err(err).
					Str("record-type", res.Record.GetType()).
					Str("record-id", res.Record.GetId()).
					Msgf("unexpected data in latest %T stream", msg)
				continue
			}
			fn(msg)
		default:
			// An unknown oneof variant or an empty response (res is nil here)
			// breaks the protocol. Report it to the caller instead of
			// panicking, so the process is not crashed.
			return 0, 0, fmt.Errorf("unexpected response in latest %T stream: %T", msg, res)
		}
	}
}