From a7bd2caae9f9414cbb3331f95630ad981583c964 Mon Sep 17 00:00:00 2001 From: Cuong Manh Le Date: Mon, 27 Jul 2020 21:10:47 +0700 Subject: [PATCH] pkg/storage: introduce storage.Backend Watch method (#1135) Currently, we're doing "sync" in databroker server. If we're going to support multiple databroker servers instance, this mechanism won't work. This commit moves the "sync" to storage backend, by adding new Watch method. The Watch method will return a channel for the caller. Everytime something happens inside the storage, we notify the caller by sending a message to this channel. --- internal/databroker/config_source_test.go | 6 +- internal/databroker/helper_no_redis.go | 7 ++ internal/databroker/helper_redis.go | 17 ++++ internal/databroker/server.go | 69 ++++++++-------- internal/databroker/server_test.go | 5 +- internal/{databroker => signal}/signal.go | 11 ++- pkg/storage/inmemory/inmemory.go | 18 +++++ pkg/storage/redis/redis.go | 96 ++++++++++++++++++++++- pkg/storage/redis/redis_test.go | 14 +++- pkg/storage/storage.go | 5 ++ 10 files changed, 204 insertions(+), 44 deletions(-) create mode 100644 internal/databroker/helper_no_redis.go create mode 100644 internal/databroker/helper_redis.go rename internal/{databroker => signal}/signal.go (83%) diff --git a/internal/databroker/config_source_test.go b/internal/databroker/config_source_test.go index 8f33c24bb..c1861041d 100644 --- a/internal/databroker/config_source_test.go +++ b/internal/databroker/config_source_test.go @@ -26,9 +26,9 @@ func TestConfigSource(t *testing.T) { } defer li.Close() - db := New() + dataBrokerServer := newTestServer() srv := grpc.NewServer() - databroker.RegisterDataBrokerServiceServer(srv, db) + databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer) go func() { _ = srv.Serve(li) }() cfgs := make(chan *config.Config, 10) @@ -52,7 +52,7 @@ func TestConfigSource(t *testing.T) { }, }, }) - _, _ = db.Set(ctx, &databroker.SetRequest{ + _, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{ Type: configTypeURL, Id: "1", Data: data, diff --git a/internal/databroker/helper_no_redis.go b/internal/databroker/helper_no_redis.go new file mode 100644 index 000000000..04d422b83 --- /dev/null +++ b/internal/databroker/helper_no_redis.go @@ -0,0 +1,7 @@ +// +build !redis + +package databroker + +func newTestServer() *Server { + return New() +} diff --git a/internal/databroker/helper_redis.go b/internal/databroker/helper_redis.go new file mode 100644 index 000000000..20536600b --- /dev/null +++ b/internal/databroker/helper_redis.go @@ -0,0 +1,17 @@ +// +build redis + +package databroker + +import ( + "os" + + "github.com/pomerium/pomerium/pkg/storage/redis" +) + +func newTestServer() *Server { + address := ":6379" + if redisURL := os.Getenv("REDIS_URL"); redisURL != "" { + address = redisURL + } + return New(WithStorageType(redis.Name), WithStorageConnectionString(address)) +} diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 3e1356765..e243ac42b 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -18,6 +18,7 @@ import ( "google.golang.org/protobuf/types/known/emptypb" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" @@ -36,9 +37,9 @@ type Server struct { cfg *serverConfig log zerolog.Logger - mu sync.RWMutex - byType map[string]storage.Backend - onchange *Signal + mu sync.RWMutex + byType map[string]storage.Backend + onTypechange *signal.Signal } // New creates a new server. @@ -49,8 +50,8 @@ func New(options ...ServerOption) *Server { cfg: cfg, log: log.With().Str("service", "databroker").Logger(), - byType: make(map[string]storage.Backend), - onchange: NewSignal(), + byType: make(map[string]storage.Backend), + onTypechange: signal.New(), } srv.initVersion() @@ -110,8 +111,6 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (* Str("id", req.GetId()). Msg("delete") - defer srv.onchange.Broadcast() - db, err := srv.getDB(req.GetType()) if err != nil { return nil, err @@ -182,8 +181,6 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr Str("id", req.GetId()). Msg("set") - defer srv.onchange.Broadcast() - db, err := srv.getDB(req.GetType()) if err != nil { return nil, err @@ -201,6 +198,24 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr }, nil } +func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error { + updated, err := db.List(ctx, *recordVersion) + if err != nil { + return err + } + if len(updated) == 0 { + return nil + } + sort.Slice(updated, func(i, j int) bool { + return updated[i].Version < updated[j].Version + }) + *recordVersion = updated[len(updated)-1].Version + return stream.Send(&databroker.SyncResponse{ + ServerVersion: srv.version, + Records: updated, + }) +} + // Sync streams updates for the given record type. func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error { _, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync") @@ -222,30 +237,20 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke return err } - ch := srv.onchange.Bind() - defer srv.onchange.Unbind(ch) - for { - updated, _ := db.List(context.Background(), recordVersion) - if len(updated) > 0 { - sort.Slice(updated, func(i, j int) bool { - return updated[i].Version < updated[j].Version - }) - recordVersion = updated[len(updated)-1].Version - err := stream.Send(&databroker.SyncResponse{ - ServerVersion: srv.version, - Records: updated, - }) - if err != nil { - return err - } - } + ctx := stream.Context() + ch := db.Watch(ctx) - select { - case <-stream.Context().Done(): - return stream.Context().Err() - case <-ch: + // Do first sync, so we won't missed anything. + if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + return err + } + + for range ch { + if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + return err } } + return nil } // GetTypes returns all the known record types. @@ -272,8 +277,8 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer srv.log.Info(). Msg("sync types") - ch := srv.onchange.Bind() - defer srv.onchange.Unbind(ch) + ch := srv.onTypechange.Bind() + defer srv.onTypechange.Unbind(ch) var prev []string for { diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index 8a013e2f7..e57effc08 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -20,8 +21,8 @@ func newServer(cfg *serverConfig) *Server { cfg: cfg, log: log.With().Str("service", "databroker").Logger(), - byType: make(map[string]storage.Backend), - onchange: NewSignal(), + byType: make(map[string]storage.Backend), + onTypechange: signal.New(), } } diff --git a/internal/databroker/signal.go b/internal/signal/signal.go similarity index 83% rename from internal/databroker/signal.go rename to internal/signal/signal.go index 14436871a..f460a3725 100644 --- a/internal/databroker/signal.go +++ b/internal/signal/signal.go @@ -1,6 +1,9 @@ -package databroker +// Package signal provides mechanism for notifying multiple listeners when something happened. +package signal -import "sync" +import ( + "sync" +) // A Signal is used to let multiple listeners know when something happened. type Signal struct { @@ -8,8 +11,8 @@ type Signal struct { chs map[chan struct{}]struct{} } -// NewSignal creates a new Signal. -func NewSignal() *Signal { +// New creates a new Signal. +func New() *Signal { return &Signal{ chs: make(map[chan struct{}]struct{}), } diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index 76010a354..f4a043ff4 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/anypb" + "github.com/pomerium/pomerium/internal/signal" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) @@ -49,14 +50,17 @@ type DB struct { byID *btree.BTree byVersion *btree.BTree deletedIDs []string + onchange *signal.Signal } // NewDB creates a new in-memory database for the given record type. func NewDB(recordType string, btreeDegree int) *DB { + s := signal.New() return &DB{ recordType: recordType, byID: btree.New(btreeDegree), byVersion: btree.New(btreeDegree), + onchange: s, } } @@ -81,6 +85,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { // Delete marks a record as deleted. func (db *DB) Delete(_ context.Context, id string) error { + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.DeletedAt = ptypes.TimestampNow() db.deletedIDs = append(db.deletedIDs, id) @@ -122,12 +127,25 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record // Put replaces or inserts a record in the db. func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { + defer db.onchange.Broadcast() db.replaceOrInsert(id, func(record *databroker.Record) { record.Data = data }) return nil } +// Watch returns the underlying signal.Signal binding channel to the caller. +// Then the caller can listen to the channel for detecting changes. +func (db *DB) Watch(ctx context.Context) chan struct{} { + ch := db.onchange.Bind() + go func() { + <-ctx.Done() + close(ch) + db.onchange.Unbind(ch) + }() + return ch +} + func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { db.mu.Lock() defer db.mu.Unlock() diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index e266a58c7..bcb879938 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -4,20 +4,24 @@ package redis import ( "context" "fmt" + "net" "strconv" "time" + "github.com/cenkalti/backoff/v4" "github.com/golang/protobuf/proto" "github.com/golang/protobuf/ptypes" "github.com/gomodule/redigo/redis" "google.golang.org/protobuf/types/known/anypb" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" ) // Name is the storage type name for redis backend. const Name = "redis" +const watchAction = "zadd" var _ storage.Backend = (*DB)(nil) @@ -58,8 +62,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) { }, deletePermanentlyAfter: deletePermanentAfter, recordType: recordType, - versionSet: "version_set", - deletedSet: "deleted_set", + versionSet: recordType + "_version_set", + deletedSet: recordType + "_deleted_set", lastVersionKey: recordType + "_last_version", } return db, nil @@ -208,6 +212,94 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { } } +// doNotify receives event from redis and signal the channel that something happenned. +func doNotify(ctx context.Context, psc *redis.PubSubConn, ch chan struct{}) error { + switch v := psc.ReceiveWithTimeout(time.Second).(type) { + case redis.Message: + log.Debug().Str("action", string(v.Data)).Msg("got redis message") + if string(v.Data) != watchAction { + return nil + } + select { + case <-ctx.Done(): + log.Warn().Err(ctx.Err()).Msg("unable to notify channel") + return ctx.Err() + case ch <- struct{}{}: + } + case error: + log.Debug().Err(v).Msg("redis subscribe error") + return v + } + return nil +} + +// doNotifyLoop tries to run doNotify forever. +// +// Because redis.PubSubConn does not support context, so it will block until it receives event, we can not use +// context to signal it stops. We mitigate this case by using PubSubConn.ReceiveWithTimeout. In case of timeout +// occurred, we return a nil error, so the caller of doNotifyLoop will re-create new connection to start new loop. +func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.PubSubConn, eb *backoff.ExponentialBackOff) error { + for { + err, ok := doNotify(ctx, psc, ch).(net.Error) + if !ok && err != nil { + log.Error().Err(ctx.Err()).Msg("failed to notify channel") + return err + } + if ok && err.Timeout() { + select { + case <-ctx.Done(): + return ctx.Err() + case <-time.After(eb.NextBackOff()): + } + return nil + } + } +} + +// watchLoop runs the doNotifyLoop forever. +// +// If doNotifyLoop returns a nil error, watchLoop re-create the PubSubConn and start new iteration. +func (db *DB) watchLoop(ctx context.Context, ch chan struct{}) { + var psConn redis.Conn + eb := backoff.NewExponentialBackOff() + for { + psConn = db.pool.Get() + psc := redis.PubSubConn{Conn: psConn} + if err := psc.PSubscribe("__keyspace*__:" + db.versionSet); err != nil { + log.Error().Err(err).Msg("failed to subscribe to version set channel") + psConn.Close() + return + } + if err := db.doNotifyLoop(ctx, ch, &psc, eb); err != nil { + psConn.Close() + return + } + } +} + +// Watch returns a channel to the caller, when there is a change to the version set, +// sending message to the channel to notify the caller. +func (db *DB) Watch(ctx context.Context) chan struct{} { + ch := make(chan struct{}) + go func() { + c := db.pool.Get() + defer func() { + close(ch) + }() + + // Setup notifications, we only care about changes to db.version_set. + if _, err := c.Do("CONFIG", "SET", "notify-keyspace-events", "Kz"); err != nil { + log.Error().Err(err).Msg("failed to setup redis notification") + c.Close() + return + } + c.Close() + db.watchLoop(ctx, ch) + }() + + return ch +} + func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) { c := db.pool.Get() defer c.Close() diff --git a/pkg/storage/redis/redis_test.go b/pkg/storage/redis/redis_test.go index c70ce8bad..8c8e66b4c 100644 --- a/pkg/storage/redis/redis_test.go +++ b/pkg/storage/redis/redis_test.go @@ -25,7 +25,8 @@ func cleanup(c redis.Conn, db *DB, t *testing.T) { } func TestDB(t *testing.T) { - ctx := context.Background() + ctx, cancelFunc := context.WithCancel(context.Background()) + defer cancelFunc() address := ":6379" if redisURL := os.Getenv("REDIS_URL"); redisURL != "" { address = redisURL @@ -41,6 +42,8 @@ func TestDB(t *testing.T) { _, err = c.Do("DEL", db.lastVersionKey) require.NoError(t, err) + ch := db.Watch(ctx) + t.Run("get missing record", func(t *testing.T) { record, err := db.Get(ctx, id) assert.Error(t, err) @@ -109,4 +112,13 @@ func TestDB(t *testing.T) { assert.NoError(t, err) assert.Len(t, records, 0) }) + + expectedNumEvents := 14 + actualNumEvents := 0 + for range ch { + actualNumEvents++ + if actualNumEvents == expectedNumEvents { + cancelFunc() + } + } } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index 97e7b2d99..a05fff749 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -29,4 +29,9 @@ type Backend interface { // ClearDeleted is used clear marked delete records. ClearDeleted(ctx context.Context, cutoff time.Time) + + // Watch returns a channel to the caller. The channel is used to notify + // about changes that happen in storage. When ctx is finished, Watch will close + // the channel. + Watch(ctx context.Context) chan struct{} }