mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 18:36:30 +02:00
107 lines
2.1 KiB
Go
107 lines
2.1 KiB
Go
package databroker
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
|
|
"github.com/pomerium/pomerium/internal/log"
|
|
"github.com/pomerium/pomerium/pkg/slices"
|
|
)
|
|
|
|
// fastForwardHandler will skip
|
|
type fastForwardHandler struct {
|
|
handler SyncerHandler
|
|
pending chan ffCmd
|
|
|
|
mu sync.Mutex
|
|
}
|
|
|
|
type ffCmd struct {
|
|
clearRecords bool
|
|
serverVersion uint64
|
|
records []*Record
|
|
}
|
|
|
|
func newFastForwardHandler(ctx context.Context, handler SyncerHandler) SyncerHandler {
|
|
ff := &fastForwardHandler{
|
|
handler: handler,
|
|
pending: make(chan ffCmd, 1),
|
|
}
|
|
go ff.run(ctx)
|
|
return ff
|
|
}
|
|
|
|
func (ff *fastForwardHandler) run(ctx context.Context) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case cmd := <-ff.pending:
|
|
if cmd.clearRecords {
|
|
ff.handler.ClearRecords(ctx)
|
|
} else {
|
|
ff.handler.UpdateRecords(ctx, cmd.serverVersion, cmd.records)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (ff *fastForwardHandler) GetDataBrokerServiceClient() DataBrokerServiceClient {
|
|
return ff.handler.GetDataBrokerServiceClient()
|
|
}
|
|
|
|
func (ff *fastForwardHandler) ClearRecords(ctx context.Context) {
|
|
ff.mu.Lock()
|
|
defer ff.mu.Unlock()
|
|
|
|
var cmd ffCmd
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case cmd = <-ff.pending:
|
|
default:
|
|
}
|
|
cmd.clearRecords = true
|
|
cmd.records = nil
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
case ff.pending <- cmd:
|
|
}
|
|
}
|
|
|
|
func (ff *fastForwardHandler) UpdateRecords(ctx context.Context, serverVersion uint64, records []*Record) {
|
|
ff.mu.Lock()
|
|
defer ff.mu.Unlock()
|
|
|
|
var cmd ffCmd
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case cmd = <-ff.pending:
|
|
default:
|
|
}
|
|
|
|
records = append(cmd.records, records...)
|
|
// reverse, so that when we get the unique records, the newest take precedence
|
|
slices.Reverse(records)
|
|
cnt := len(records)
|
|
records = slices.UniqueBy(records, func(record *Record) [2]string {
|
|
return [2]string{record.GetType(), record.GetId()}
|
|
})
|
|
dropped := cnt - len(records)
|
|
if dropped > 0 {
|
|
log.Info(ctx).Msgf("databroker: fast-forwarded %d records", dropped)
|
|
}
|
|
// reverse back so they appear in the order they were delivered
|
|
slices.Reverse(records)
|
|
|
|
cmd.clearRecords = false
|
|
cmd.serverVersion = serverVersion
|
|
cmd.records = records
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
case ff.pending <- cmd:
|
|
}
|
|
}
|