mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
* wip * add response * handle empty email * use set, update log * add test * add coalesce, comments, test * add test, fix bug * use builtin cmp.Or * remove wait ready call * use api error
110 lines
2.5 KiB
Go
110 lines
2.5 KiB
Go
package databroker
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
|
|
"google.golang.org/protobuf/proto"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
)
|
|
|
|
// SyncRecords calls fn for every record using Sync.
|
|
func SyncRecords[T any, TMessage interface {
|
|
*T
|
|
proto.Message
|
|
}](
|
|
ctx context.Context,
|
|
client DataBrokerServiceClient,
|
|
serverVersion, latestRecordVersion uint64,
|
|
fn func(TMessage),
|
|
) error {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var msg TMessage = new(T)
|
|
stream, err := client.Sync(ctx, &SyncRequest{
|
|
Type: protoutil.GetTypeURL(msg),
|
|
ServerVersion: serverVersion,
|
|
RecordVersion: latestRecordVersion,
|
|
})
|
|
if err != nil {
|
|
return fmt.Errorf("error syncing %T: %w", msg, err)
|
|
}
|
|
|
|
for {
|
|
res, err := stream.Recv()
|
|
switch {
|
|
case errors.Is(err, io.EOF):
|
|
return nil
|
|
case err != nil:
|
|
return fmt.Errorf("error receiving record for %T: %w", msg, err)
|
|
}
|
|
|
|
msg = new(T)
|
|
err = res.GetRecord().GetData().UnmarshalTo(msg)
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).
|
|
Str("record-type", res.Record.Type).
|
|
Str("record-id", res.Record.GetId()).
|
|
Msgf("unexpected data in %T stream", msg)
|
|
continue
|
|
}
|
|
|
|
fn(msg)
|
|
}
|
|
}
|
|
|
|
// SyncLatestRecords calls fn for every record using SyncLatest.
|
|
func SyncLatestRecords[T any, TMessage interface {
|
|
*T
|
|
proto.Message
|
|
}](
|
|
ctx context.Context,
|
|
client DataBrokerServiceClient,
|
|
fn func(TMessage),
|
|
) (serverVersion, latestRecordVersion uint64, err error) {
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
defer cancel()
|
|
|
|
var msg TMessage = new(T)
|
|
stream, err := client.SyncLatest(ctx, &SyncLatestRequest{
|
|
Type: protoutil.GetTypeURL(msg),
|
|
})
|
|
if err != nil {
|
|
return 0, 0, fmt.Errorf("error syncing latest %T: %w", msg, err)
|
|
}
|
|
|
|
for {
|
|
res, err := stream.Recv()
|
|
switch {
|
|
case errors.Is(err, io.EOF):
|
|
return serverVersion, latestRecordVersion, nil
|
|
case err != nil:
|
|
return 0, 0, fmt.Errorf("error receiving record for latest %T: %w", msg, err)
|
|
}
|
|
|
|
switch res := res.GetResponse().(type) {
|
|
case *SyncLatestResponse_Versions:
|
|
serverVersion = res.Versions.GetServerVersion()
|
|
latestRecordVersion = res.Versions.GetLatestRecordVersion()
|
|
case *SyncLatestResponse_Record:
|
|
msg = new(T)
|
|
err = res.Record.GetData().UnmarshalTo(msg)
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).
|
|
Str("record-type", res.Record.Type).
|
|
Str("record-id", res.Record.GetId()).
|
|
Msgf("unexpected data in latest %T stream", msg)
|
|
continue
|
|
}
|
|
|
|
fn(msg)
|
|
default:
|
|
panic(fmt.Sprintf("unexpected response: %T", res))
|
|
}
|
|
}
|
|
}
|