package databroker import ( "context" "errors" "fmt" "io" "google.golang.org/protobuf/proto" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/protoutil" ) // SyncRecords calls fn for every record using Sync. func SyncRecords[T any, TMessage interface { *T proto.Message }]( ctx context.Context, client DataBrokerServiceClient, serverVersion, latestRecordVersion uint64, fn func(TMessage), ) error { ctx, cancel := context.WithCancel(ctx) defer cancel() var msg TMessage = new(T) stream, err := client.Sync(ctx, &SyncRequest{ Type: protoutil.GetTypeURL(msg), ServerVersion: serverVersion, RecordVersion: latestRecordVersion, }) if err != nil { return fmt.Errorf("error syncing %T: %w", msg, err) } for { res, err := stream.Recv() switch { case errors.Is(err, io.EOF): return nil case err != nil: return fmt.Errorf("error receiving record for %T: %w", msg, err) } msg = new(T) err = res.GetRecord().GetData().UnmarshalTo(msg) if err != nil { log.Ctx(ctx).Error().Err(err). Str("record-type", res.Record.Type). Str("record-id", res.Record.GetId()). Msgf("unexpected data in %T stream", msg) continue } fn(msg) } } // SyncLatestRecords calls fn for every record using SyncLatest. func SyncLatestRecords[T any, TMessage interface { *T proto.Message }]( ctx context.Context, client DataBrokerServiceClient, fn func(TMessage), ) (serverVersion, latestRecordVersion uint64, err error) { ctx, cancel := context.WithCancel(ctx) defer cancel() var msg TMessage = new(T) stream, err := client.SyncLatest(ctx, &SyncLatestRequest{ Type: protoutil.GetTypeURL(msg), }) if err != nil { return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err) } for { res, err := stream.Recv() switch { case errors.Is(err, io.EOF): return serverVersion, latestRecordVersion, nil case err != nil: return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err) } switch res := res.GetResponse().(type) { case *SyncLatestResponse_Versions: serverVersion = res.Versions.GetServerVersion() latestRecordVersion = res.Versions.GetLatestRecordVersion() case *SyncLatestResponse_Record: msg = new(T) err = res.Record.GetData().UnmarshalTo(msg) if err != nil { log.Ctx(ctx).Error().Err(err). Str("record-type", res.Record.Type). Str("record-id", res.Record.GetId()). Msgf("unexpected data in latest %T stream", msg) continue } fn(msg) default: panic(fmt.Sprintf("unexpected response: %T", res)) } } }