package authorize import ( "context" "io" "time" "github.com/pomerium/pomerium/internal/telemetry/trace" backoff "github.com/cenkalti/backoff/v4" "golang.org/x/sync/errgroup" "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/databroker" ) // Run runs the authorize server. func (a *Authorize) Run(ctx context.Context) error { eg, ctx := errgroup.WithContext(ctx) updateTypes := make(chan []string) eg.Go(func() error { return a.runTypesSyncer(ctx, updateTypes) }) eg.Go(func() error { return a.runDataSyncer(ctx, updateTypes) }) return eg.Wait() } func (a *Authorize) runTypesSyncer(ctx context.Context, updateTypes chan<- []string) error { log.Info().Msg("starting type sync") return tryForever(ctx, func(backoff interface{ Reset() }) error { ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync") defer span.End() stream, err := a.dataBrokerClient.SyncTypes(ctx, new(emptypb.Empty)) if err != nil { return err } for { res, err := stream.Recv() if err == io.EOF { return nil } else if err != nil { return err } backoff.Reset() select { case <-stream.Context().Done(): return stream.Context().Err() case updateTypes <- res.GetTypes(): } } }) } func (a *Authorize) runDataSyncer(ctx context.Context, updateTypes <-chan []string) error { eg, ctx := errgroup.WithContext(ctx) eg.Go(func() error { seen := map[string]struct{}{} for { select { case <-ctx.Done(): return ctx.Err() case types := <-updateTypes: for _, dataType := range types { dataType := dataType if _, ok := seen[dataType]; !ok { eg.Go(func() error { return a.runDataTypeSyncer(ctx, dataType) }) seen[dataType] = struct{}{} } } } } }) return eg.Wait() } func (a *Authorize) runDataTypeSyncer(ctx context.Context, typeURL string) error { var serverVersion, recordVersion string log.Info().Str("type_url", typeURL).Msg("starting data initial load") ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.GetAll") backoff := backoff.NewExponentialBackOff() for { res, err := a.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{ Type: typeURL, }) if err != nil { log.Warn().Err(err).Str("type_url", typeURL).Msg("error getting data") select { case <-ctx.Done(): return ctx.Err() case <-time.After(backoff.NextBackOff()): } continue } serverVersion = res.GetServerVersion() recordVersion = res.GetRecordVersion() for _, record := range res.GetRecords() { a.updateRecord(record) } break } span.End() log.Info().Str("type_url", typeURL).Msg("starting data syncer") return tryForever(ctx, func(backoff interface{ Reset() }) error { ctx, span := trace.StartSpan(ctx, "authorize.dataBrokerClient.Sync") defer span.End() stream, err := a.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ ServerVersion: serverVersion, RecordVersion: recordVersion, Type: typeURL, }) if err != nil { return err } for { res, err := stream.Recv() if err == io.EOF { return nil } else if err != nil { return err } backoff.Reset() if res.GetServerVersion() != serverVersion { log.Info(). Str("old_version", serverVersion). Str("new_version", res.GetServerVersion()). Str("type_url", typeURL). Msg("detected new server version, clearing data") serverVersion = res.GetServerVersion() recordVersion = "" a.clearRecords(typeURL) } for _, record := range res.GetRecords() { if record.GetVersion() > recordVersion { recordVersion = record.GetVersion() } } for _, record := range res.GetRecords() { a.updateRecord(record) } } }) } func (a *Authorize) clearRecords(typeURL string) { a.store.ClearRecords(typeURL) a.dataBrokerDataLock.Lock() a.dataBrokerData.Clear(typeURL) a.dataBrokerDataLock.Unlock() } func (a *Authorize) updateRecord(record *databroker.Record) { a.store.UpdateRecord(record) a.dataBrokerDataLock.Lock() a.dataBrokerData.Update(record) a.dataBrokerDataLock.Unlock() } func tryForever(ctx context.Context, callback func(onSuccess interface{ Reset() }) error) error { backoff := backoff.NewExponentialBackOff() for { err := callback(backoff) if err != nil { log.Warn().Err(err).Msg("sync error") } select { case <-ctx.Done(): return ctx.Err() case <-time.After(backoff.NextBackOff()): } } }