diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go
index 3f6ec2570..68185e282 100644
--- a/internal/databroker/config_source.go
+++ b/internal/databroker/config_source.go
@@ -81,7 +81,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
 	// start the updater
 	src.runUpdater(cfg)
 
-	seen := map[uint64]struct{}{}
+	seen := map[uint64]string{}
 	for _, policy := range cfg.Options.GetAllPolicies() {
 		id, err := policy.RouteID()
 		if err != nil {
@@ -90,7 +90,7 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
 				Msg("databroker: invalid policy config, ignoring")
 			return
 		}
-		seen[id] = struct{}{}
+		seen[id] = ""
 	}
 
 	var additionalPolicies []config.Policy
@@ -145,11 +145,12 @@ func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
 			errCount++
 			log.Warn(ctx).Err(err).
 				Str("db_config_id", id).
+				Str("seen-in", seen[routeID]).
 				Str("policy", policy.String()).
 				Msg("databroker: duplicate policy detected, ignoring")
 			continue
 		}
-		seen[routeID] = struct{}{}
+		seen[routeID] = id
 		additionalPolicies = append(additionalPolicies, *policy)
 	}
 
diff --git a/pkg/grpc/databroker/fast_forward.go b/pkg/grpc/databroker/fast_forward.go
index 633c14e38..205074925 100644
--- a/pkg/grpc/databroker/fast_forward.go
+++ b/pkg/grpc/databroker/fast_forward.go
@@ -2,15 +2,18 @@ package databroker
 
 import (
 	"context"
+	"sync"
 
 	"github.com/pomerium/pomerium/internal/log"
+	"github.com/pomerium/pomerium/pkg/slices"
 )
 
 // fastForwardHandler will skip
 type fastForwardHandler struct {
 	handler SyncerHandler
-	in      chan *ffCmd
-	exec    chan *ffCmd
+	pending chan ffCmd
+
+	mu sync.Mutex
 }
 
 type ffCmd struct {
@@ -22,52 +25,23 @@ type ffCmd struct {
 func newFastForwardHandler(ctx context.Context, handler SyncerHandler) SyncerHandler {
 	ff := &fastForwardHandler{
 		handler: handler,
-		in:      make(chan *ffCmd, 20),
-		exec:    make(chan *ffCmd),
+		pending: make(chan ffCmd, 1),
 	}
-	go ff.runSelect(ctx)
-	go ff.runExec(ctx)
-
+	go ff.run(ctx)
 	return ff
 }
 
-func (ff *fastForwardHandler) update(ctx context.Context, c *ffCmd) {
-	ff.handler.UpdateRecords(ctx, c.serverVersion, c.records)
-}
-
-func (ff *fastForwardHandler) runSelect(ctx context.Context) {
-	var update *ffCmd
-
-	for {
-		if update == nil {
-			select {
-			case <-ctx.Done():
-				return
-			case update = <-ff.in:
-			}
-		} else {
-			select {
-			case <-ctx.Done():
-				return
-			case update = <-ff.in:
-			case ff.exec <- update:
-				update = nil
-			}
-		}
-	}
-}
-
-func (ff *fastForwardHandler) runExec(ctx context.Context) {
+func (ff *fastForwardHandler) run(ctx context.Context) {
 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case update := <-ff.exec:
-			if update.clearRecords {
+		case cmd := <-ff.pending:
+			if cmd.clearRecords {
 				ff.handler.ClearRecords(ctx)
-				continue
+			} else {
+				ff.handler.UpdateRecords(ctx, cmd.serverVersion, cmd.records)
 			}
-			ff.update(ctx, update)
 		}
 	}
 }
@@ -77,19 +51,57 @@ func (ff *fastForwardHandler) GetDataBrokerServiceClient() DataBrokerServiceClie
 }
 
 func (ff *fastForwardHandler) ClearRecords(ctx context.Context) {
+	ff.mu.Lock()
+	defer ff.mu.Unlock()
+
+	var cmd ffCmd
 	select {
 	case <-ctx.Done():
-		log.Error(ctx).
-			Msg("ff_handler: ClearRecords: context canceled")
-	case ff.exec <- &ffCmd{clearRecords: true}:
+		return
+	case cmd = <-ff.pending:
+	default:
+	}
+	cmd.clearRecords = true
+	cmd.records = nil
+
+	select {
+	case <-ctx.Done():
+	case ff.pending <- cmd:
 	}
 }
 
 func (ff *fastForwardHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) {
+	ff.mu.Lock()
+	defer ff.mu.Unlock()
+
+	var cmd ffCmd
 	select {
 	case <-ctx.Done():
-		log.Error(ctx).
-			Msg("ff_handler: UpdateRecords: context canceled")
-	case ff.in <- &ffCmd{serverVersion: serverVersion, records: records}:
+		return
+	case cmd = <-ff.pending:
+	default:
+	}
+
+	records = append(cmd.records, records...)
+	// reverse, so that when we get the unique records, the newest take precedence
+	slices.Reverse(records)
+	cnt := len(records)
+	records = slices.UniqueBy(records, func(record *Record) [2]string {
+		return [2]string{record.GetType(), record.GetId()}
+	})
+	dropped := cnt - len(records)
+	if dropped > 0 {
+		log.Info(ctx).Msgf("databroker: fast-forwarded %d records", dropped)
+	}
+	// reverse back so they appear in the order they were delivered
+	slices.Reverse(records)
+
+	cmd.clearRecords = false
+	cmd.serverVersion = serverVersion
+	cmd.records = records
+
+	select {
+	case <-ctx.Done():
+	case ff.pending <- cmd:
 	}
 }
diff --git a/pkg/slices/slices.go b/pkg/slices/slices.go
index 3a3317505..6508caa74 100644
--- a/pkg/slices/slices.go
+++ b/pkg/slices/slices.go
@@ -33,6 +33,13 @@ func Remove[S ~[]E, E comparable](s S, e E) S {
 	return ns
 }
 
+// Reverse reverses a slice's order.
+func Reverse[S ~[]E, E comparable](s S) {
+	for i := 0; i < len(s)/2; i++ {
+		s[i], s[len(s)-1-i] = s[len(s)-1-i], s[i]
+	}
+}
+
 // Unique returns the unique elements of s.
 func Unique[S ~[]E, E comparable](s S) S {
 	var ns S
@@ -45,3 +52,17 @@ func Unique[S ~[]E, E comparable](s S) S {
 	}
 	return ns
 }
+
+// UniqueBy returns the unique elements of s using a function to map elements.
+func UniqueBy[S ~[]E, E any, V comparable](s S, by func(E) V) S {
+	var ns S
+	h := map[V]struct{}{}
+	for _, el := range s {
+		v := by(el)
+		if _, ok := h[v]; !ok {
+			h[v] = struct{}{}
+			ns = append(ns, el)
+		}
+	}
+	return ns
+}
diff --git a/pkg/slices/slices_test.go b/pkg/slices/slices_test.go
new file mode 100644
index 000000000..f3b79bd68
--- /dev/null
+++ b/pkg/slices/slices_test.go
@@ -0,0 +1,32 @@
+package slices
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestReverse(t *testing.T) {
+	t.Parallel()
+
+	for _, tc := range []struct {
+		in     []int
+		expect []int
+	}{
+		{in: []int{1, 2, 3}, expect: []int{3, 2, 1}},
+		{in: []int{1, 2}, expect: []int{2, 1}},
+		{in: []int{1}, expect: []int{1}},
+	} {
+		s := make([]int, len(tc.in))
+		copy(s, tc.in)
+		Reverse(s)
+		assert.Equal(t, tc.expect, s)
+	}
+}
+
+func TestUniqueBy(t *testing.T) {
+	t.Parallel()
+
+	s := UniqueBy([]int{1, 2, 3, 4, 3, 1, 1, 4, 2}, func(i int) int { return i % 3 })
+	assert.Equal(t, []int{1, 2, 3}, s)
+}