diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 8f33c24bb..c1861041d 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -26,9 +26,9 @@ func TestConfigSource(t *testing.T) { } defer li.Close() - db := New() + dataBrokerServer := newTestServer() srv := grpc.NewServer() - databroker.RegisterDataBrokerServiceServer(srv, db) + databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer) go func() { _ = srv.Serve(li) }() cfgs := make(chan *config.Config, 10) @@ -52,7 +52,7 @@ func TestConfigSource(t *testing.T) { }, }, }) - _, _ = db.Set(ctx, &databroker.SetRequest{ + _, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{ Type: configTypeURL, Id: "1", Data: data, diff --git a/internal/databroker/helper_no_redis.go b/internal/databroker/helper_no_redis.go new file mode 100644 index 000000000..04d422b83 --- /dev/null +++ b/internal/databroker/helper_no_redis.go @@ -0,0 +1,7 @@ +// +build !redis + +package databroker + +func newTestServer() *Server { + return New() +} diff --git a/internal/databroker/helper_redis.go b/internal/databroker/helper_redis.go new file mode 100644 index 000000000..20536600b --- /dev/null +++ b/internal/databroker/helper_redis.go @@ -0,0 +1,17 @@ +// +build redis + +package databroker + +import ( + "os" + + "github.com/pomerium/pomerium/pkg/storage/redis" +) + +func newTestServer() *Server { + address := ":6379" + if redisURL := os.Getenv("REDIS_URL"); redisURL != "" { + address = redisURL + } + return New(WithStorageType(redis.Name), WithStorageConnectionString(address)) +} diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 3e1356765..e243ac42b 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -18,6 +18,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" 
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" @@ -36,9 +37,9 @@ type Server struct { cfg *serverConfig log zerolog.Logger - mu sync.RWMutex - byType map[string]storage.Backend - onchange *Signal + mu sync.RWMutex + byType map[string]storage.Backend + onTypechange *signal.Signal } // New creates a new server. @@ -49,8 +50,8 @@ func New(options ...ServerOption) *Server { cfg: cfg, log: log.With().Str("service", "databroker").Logger(), - byType: make(map[string]storage.Backend), - onchange: NewSignal(), + byType: make(map[string]storage.Backend), + onTypechange: signal.New(), } srv.initVersion() @@ -110,8 +111,6 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (* Str("id", req.GetId()). Msg("delete") - defer srv.onchange.Broadcast() - db, err := srv.getDB(req.GetType()) if err != nil { return nil, err @@ -182,8 +181,6 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr Str("id", req.GetId()). Msg("set") - defer srv.onchange.Broadcast() - db, err := srv.getDB(req.GetType()) if err != nil { return nil, err @@ -201,6 +198,24 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr }, nil } +func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error { + updated, err := db.List(ctx, *recordVersion) + if err != nil { + return err + } + if len(updated) == 0 { + return nil + } + sort.Slice(updated, func(i, j int) bool { + return updated[i].Version < updated[j].Version + }) + *recordVersion = updated[len(updated)-1].Version + return stream.Send(&databroker.SyncResponse{ + ServerVersion: srv.version, + Records: updated, + }) +} + // Sync streams updates for the given record type. 
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error { _, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync") @@ -222,30 +237,20 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke return err } - ch := srv.onchange.Bind() - defer srv.onchange.Unbind(ch) - for { - updated, _ := db.List(context.Background(), recordVersion) - if len(updated) > 0 { - sort.Slice(updated, func(i, j int) bool { - return updated[i].Version < updated[j].Version - }) - recordVersion = updated[len(updated)-1].Version - err := stream.Send(&databroker.SyncResponse{ - ServerVersion: srv.version, - Records: updated, - }) - if err != nil { - return err - } - } + ctx := stream.Context() + ch := db.Watch(ctx) - select { - case <-stream.Context().Done(): - return stream.Context().Err() - case <-ch: + // Do first sync, so we won't missed anything. + if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + return err + } + + for range ch { + if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + return err } } + return nil } // GetTypes returns all the known record types. @@ -272,8 +277,8 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer srv.log.Info(). 
Msg("sync types") - ch := srv.onchange.Bind() - defer srv.onchange.Unbind(ch) + ch := srv.onTypechange.Bind() + defer srv.onTypechange.Unbind(ch) var prev []string for { diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index 8a013e2f7..e57effc08 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -20,8 +21,8 @@ func newServer(cfg *serverConfig) *Server { cfg: cfg, log: log.With().Str("service", "databroker").Logger(), - byType: make(map[string]storage.Backend), - onchange: NewSignal(), + byType: make(map[string]storage.Backend), + onTypechange: signal.New(), } } diff --git a/internal/databroker/signal.go b/internal/signal/signal.go similarity index 83% rename from internal/databroker/signal.go rename to internal/signal/signal.go index 14436871a..f460a3725 100644 --- a/internal/databroker/signal.go +++ b/internal/signal/signal.go @@ -1,6 +1,9 @@ -package databroker +// Package signal provides mechanism for notifying multiple listeners when something happened. +package signal -import "sync" +import ( + "sync" +) // A Signal is used to let multiple listeners know when something happened. type Signal struct { @@ -8,8 +11,8 @@ type Signal struct { chs map[chan struct{}]struct{} } -// NewSignal creates a new Signal. -func NewSignal() *Signal { +// New creates a new Signal. 
+func New() *Signal { return &Signal{ chs: make(map[chan struct{}]struct{}), } diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index 76010a354..f4a043ff4 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -49,14 +50,17 @@ type DB struct { byID *btree.BTree byVersion *btree.BTree deletedIDs []string + onchange *signal.Signal } // NewDB creates a new in-memory database for the given record type. func NewDB(recordType string, btreeDegree int) *DB { + s := signal.New() return &DB{ recordType: recordType, byID: btree.New(btreeDegree), byVersion: btree.New(btreeDegree), + onchange: s, } } @@ -81,6 +85,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { // Delete marks a record as deleted. func (db *DB) Delete(_ context.Context, id string) error { + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.DeletedAt = ptypes.TimestampNow() db.deletedIDs = append(db.deletedIDs, id) @@ -122,12 +127,25 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record // Put replaces or inserts a record in the db. func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.Data = data }) return nil } +// Watch returns the underlying signal.Signal binding channel to the caller. +// Then the caller can listen to the channel for detecting changes. 
+func (db *DB) Watch(ctx context.Context) chan struct{} {
+	ch := db.onchange.Bind()
+	go func() {
+		<-ctx.Done()           // the watch lasts exactly as long as ctx
+		db.onchange.Unbind(ch) // unbind before closing, so ch is no longer registered with the signal once it is closed
+		close(ch)              // NOTE(review): Broadcast is not visible here; if it sends on bound channels, closing first risked a send-on-closed-channel panic
+	}()
+	return ch
+}
+
 func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { db.mu.Lock() defer db.mu.Unlock() diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index e266a58c7..bcb879938 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -4,20 +4,24 @@ package redis import ( "context" "fmt" + "net" "strconv" "time" + "github.com/cenkalti/backoff/v4" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/gomodule/redigo/redis" "google.golang.org/protobuf/types/known/anypb" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) // Name is the storage type name for redis backend. const Name = "redis" +const watchAction = "zadd" var _ storage.Backend = (*DB)(nil) @@ -58,8 +62,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) { }, deletePermanentlyAfter: deletePermanentAfter, recordType: recordType, - versionSet: "version_set", - deletedSet: "deleted_set", + versionSet: recordType + "_version_set", + deletedSet: recordType + "_deleted_set", lastVersionKey: recordType + "_last_version", } return db, nil @@ -208,6 +212,94 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { } }
+// doNotify receives an event from redis and signals the channel that something happened.
+func doNotify(ctx context.Context, psc *redis.PubSubConn, ch chan struct{}) error {
+	switch v := psc.ReceiveWithTimeout(time.Second).(type) { // wait at most one second for the next pub/sub event
+	case redis.Message:
+		log.Debug().Str("action", string(v.Data)).Msg("got redis message")
+		if string(v.Data) != watchAction {
+			return nil // only watchAction events are forwarded
+		}
+		select {
+		case <-ctx.Done():
+			log.Warn().Err(ctx.Err()).Msg("unable to notify channel")
+			return ctx.Err()
+		case ch <- struct{}{}: // blocks until the watcher receives or ctx is done
+		}
+	case error:
+		log.Debug().Err(v).Msg("redis subscribe error")
+		return v
+	}
+	return nil
+}
+
+// doNotifyLoop runs doNotify until it returns an error.
+//
+// redis.PubSubConn does not take a context, so doNotify polls it with PubSubConn.ReceiveWithTimeout. When
+// that times out, doNotifyLoop waits for the next backoff interval (or for ctx to finish) and returns nil so
+// the caller can re-create the connection. Any other error is logged and returned unchanged.
+func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.PubSubConn, eb *backoff.ExponentialBackOff) error {
+	for {
+		err := doNotify(ctx, psc, ch) // keep the raw error: a failed type assertion would otherwise discard it
+		if nerr, ok := err.(net.Error); ok && nerr.Timeout() {
+			select {
+			case <-ctx.Done():
+				return ctx.Err()
+			case <-time.After(eb.NextBackOff()):
+			}
+			return nil // timeout: let the caller start over with a fresh connection
+		}
+		if err != nil { // every non-timeout failure ends the loop instead of being retried on the same connection
+			log.Error().Err(err).Msg("failed to notify channel")
+			return err
+		}
+	}
+}
+
+// watchLoop runs the doNotifyLoop forever.
+//
+// If doNotifyLoop returns a nil error, watchLoop re-creates the PubSubConn and starts a new iteration.
+func (db *DB) watchLoop(ctx context.Context, ch chan struct{}) {
+	eb := backoff.NewExponentialBackOff() // shared across reconnects and handed to doNotifyLoop
+	for {
+		psConn := db.pool.Get()
+		psc := redis.PubSubConn{Conn: psConn}
+		if err := psc.PSubscribe("__keyspace*__:" + db.versionSet); err != nil {
+			log.Error().Err(err).Msg("failed to subscribe to version set channel")
+			psConn.Close()
+			return
+		}
+		err := db.doNotifyLoop(ctx, ch, &psc, eb)
+		psConn.Close() // release the connection on every pass: a nil error means "reconnect", and the next iteration takes a new one from the pool
+		if err != nil {
+			return
+		}
+	}
+}
+
+// Watch returns a channel that receives a message whenever the version set changes.
+// The channel is closed when the watching goroutine stops.
+func (db *DB) Watch(ctx context.Context) chan struct{} {
+	ch := make(chan struct{})
+	go func() {
+		c := db.pool.Get()
+		defer func() {
+			close(ch)
+		}()
+
+		// Setup notifications, we only care about changes to db.version_set.
+		if _, err := c.Do("CONFIG", "SET", "notify-keyspace-events", "Kz"); err != nil {
+			log.Error().Err(err).Msg("failed to setup redis notification")
+			c.Close()
+			return
+		}
+		c.Close()
+		db.watchLoop(ctx, ch)
+	}()
+
+	return ch
+}
+
 func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) { c := db.pool.Get() defer c.Close() diff --git a/pkg/storage/redis/redis_test.go b/pkg/storage/redis/redis_test.go index c70ce8bad..8c8e66b4c 100644 --- a/pkg/storage/redis/redis_test.go +++ b/pkg/storage/redis/redis_test.go @@ -25,7 +25,8 @@ func cleanup(c redis.Conn, db *DB, t *testing.T) { } func TestDB(t *testing.T) { - ctx := context.Background() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() address := ":6379" if redisURL := os.Getenv("REDIS_URL"); redisURL != "" { address = redisURL @@ -41,6 +42,8 @@ func TestDB(t *testing.T) { _, err = c.Do("DEL", db.lastVersionKey) require.NoError(t, err) + ch := db.Watch(ctx) + t.Run("get missing record", func(t *testing.T) { record, err := db.Get(ctx, id) assert.Error(t, err) @@ -109,4 +112,13 @@ func
TestDB(t *testing.T) { assert.NoError(t, err) assert.Len(t, records, 0) }) + + expectedNumEvents := 14 + actualNumEvents := 0 + for range ch { + actualNumEvents++ + if actualNumEvents == expectedNumEvents { + cancelFunc() + } + } } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 97e7b2d99..a05fff749 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -29,4 +29,9 @@ type Backend interface { // ClearDeleted is used clear marked delete records. ClearDeleted(ctx context.Context, cutoff time.Time) + + // Watch returns a channel to the caller. The channel is used to notify + // about changes that happen in storage. When ctx is finished, Watch will close + // the channel. + Watch(ctx context.Context) chan struct{} }