pkg/storage: introduce storage.Backend Watch method (#1135)

Currently, we're doing "sync" in the databroker server. If we're going to
support multiple databroker server instances, this mechanism won't work.

This commit moves the "sync" to the storage backend by adding a new Watch
method. The Watch method returns a channel to the caller. Every time
something happens inside the storage, we notify the caller by sending a
message to this channel.
This commit is contained in:
Cuong Manh Le 2020-07-27 21:10:47 +07:00 committed by GitHub
parent d9711c8055
commit a7bd2caae9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 204 additions and 44 deletions

View file

@ -26,9 +26,9 @@ func TestConfigSource(t *testing.T) {
} }
defer li.Close() defer li.Close()
db := New() dataBrokerServer := newTestServer()
srv := grpc.NewServer() srv := grpc.NewServer()
databroker.RegisterDataBrokerServiceServer(srv, db) databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
go func() { _ = srv.Serve(li) }() go func() { _ = srv.Serve(li) }()
cfgs := make(chan *config.Config, 10) cfgs := make(chan *config.Config, 10)
@ -52,7 +52,7 @@ func TestConfigSource(t *testing.T) {
}, },
}, },
}) })
_, _ = db.Set(ctx, &databroker.SetRequest{ _, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{
Type: configTypeURL, Type: configTypeURL,
Id: "1", Id: "1",
Data: data, Data: data,

View file

@ -0,0 +1,7 @@
// +build !redis
package databroker
// newTestServer builds the databroker server used by tests when the redis
// build tag is not set. It calls New without any options.
func newTestServer() *Server {
	srv := New()
	return srv
}

View file

@ -0,0 +1,17 @@
// +build redis
package databroker
import (
"os"
"github.com/pomerium/pomerium/pkg/storage/redis"
)
// newTestServer builds a redis-backed databroker server for tests.
//
// The redis address is read from the REDIS_URL environment variable; when it
// is unset or empty, ":6379" is used instead.
func newTestServer() *Server {
	address := os.Getenv("REDIS_URL")
	if address == "" {
		address = ":6379"
	}
	return New(
		WithStorageType(redis.Name),
		WithStorageConnectionString(address),
	)
}

View file

@ -18,6 +18,7 @@ import (
"google.golang.org/protobuf/types/known/emptypb" "google.golang.org/protobuf/types/known/emptypb"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
@ -38,7 +39,7 @@ type Server struct {
mu sync.RWMutex mu sync.RWMutex
byType map[string]storage.Backend byType map[string]storage.Backend
onchange *Signal onTypechange *signal.Signal
} }
// New creates a new server. // New creates a new server.
@ -50,7 +51,7 @@ func New(options ...ServerOption) *Server {
log: log.With().Str("service", "databroker").Logger(), log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend), byType: make(map[string]storage.Backend),
onchange: NewSignal(), onTypechange: signal.New(),
} }
srv.initVersion() srv.initVersion()
@ -110,8 +111,6 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
Str("id", req.GetId()). Str("id", req.GetId()).
Msg("delete") Msg("delete")
defer srv.onchange.Broadcast()
db, err := srv.getDB(req.GetType()) db, err := srv.getDB(req.GetType())
if err != nil { if err != nil {
return nil, err return nil, err
@ -182,8 +181,6 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
Str("id", req.GetId()). Str("id", req.GetId()).
Msg("set") Msg("set")
defer srv.onchange.Broadcast()
db, err := srv.getDB(req.GetType()) db, err := srv.getDB(req.GetType())
if err != nil { if err != nil {
return nil, err return nil, err
@ -201,6 +198,24 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
}, nil }, nil
} }
// doSync sends every record newer than *recordVersion to the stream.
//
// It lists the records from db, returns early (nil) when there is nothing new,
// orders the records by Version, advances *recordVersion to the highest
// version in the batch and sends the batch as one SyncResponse. Any error from
// listing or sending is returned as-is.
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
	records, err := db.List(ctx, *recordVersion)
	switch {
	case err != nil:
		return err
	case len(records) == 0:
		return nil
	}

	sort.Slice(records, func(a, b int) bool {
		return records[a].Version < records[b].Version
	})

	// The batch is sorted, so its last element carries the newest version.
	newest := records[len(records)-1]
	*recordVersion = newest.Version

	resp := &databroker.SyncResponse{
		ServerVersion: srv.version,
		Records:       records,
	}
	return stream.Send(resp)
}
// Sync streams updates for the given record type. // Sync streams updates for the given record type.
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error { func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync") _, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
@ -222,30 +237,20 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
return err return err
} }
ch := srv.onchange.Bind() ctx := stream.Context()
defer srv.onchange.Unbind(ch) ch := db.Watch(ctx)
for {
updated, _ := db.List(context.Background(), recordVersion) // Do first sync, so we won't missed anything.
if len(updated) > 0 { if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
sort.Slice(updated, func(i, j int) bool { return err
return updated[i].Version < updated[j].Version }
})
recordVersion = updated[len(updated)-1].Version for range ch {
err := stream.Send(&databroker.SyncResponse{ if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
ServerVersion: srv.version,
Records: updated,
})
if err != nil {
return err return err
} }
} }
return nil
select {
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
}
}
} }
// GetTypes returns all the known record types. // GetTypes returns all the known record types.
@ -272,8 +277,8 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
srv.log.Info(). srv.log.Info().
Msg("sync types") Msg("sync types")
ch := srv.onchange.Bind() ch := srv.onTypechange.Bind()
defer srv.onchange.Unbind(ch) defer srv.onTypechange.Unbind(ch)
var prev []string var prev []string
for { for {

View file

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -21,7 +22,7 @@ func newServer(cfg *serverConfig) *Server {
log: log.With().Str("service", "databroker").Logger(), log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend), byType: make(map[string]storage.Backend),
onchange: NewSignal(), onTypechange: signal.New(),
} }
} }

View file

@ -1,6 +1,9 @@
package databroker // Package signal provides mechanism for notifying multiple listeners when something happened.
package signal
import "sync" import (
"sync"
)
// A Signal is used to let multiple listeners know when something happened. // A Signal is used to let multiple listeners know when something happened.
type Signal struct { type Signal struct {
@ -8,8 +11,8 @@ type Signal struct {
chs map[chan struct{}]struct{} chs map[chan struct{}]struct{}
} }
// NewSignal creates a new Signal. // New creates a new Signal.
func NewSignal() *Signal { func New() *Signal {
return &Signal{ return &Signal{
chs: make(map[chan struct{}]struct{}), chs: make(map[chan struct{}]struct{}),
} }

View file

@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
@ -49,14 +50,17 @@ type DB struct {
byID *btree.BTree byID *btree.BTree
byVersion *btree.BTree byVersion *btree.BTree
deletedIDs []string deletedIDs []string
onchange *signal.Signal
} }
// NewDB creates a new in-memory database for the given record type. // NewDB creates a new in-memory database for the given record type.
func NewDB(recordType string, btreeDegree int) *DB { func NewDB(recordType string, btreeDegree int) *DB {
s := signal.New()
return &DB{ return &DB{
recordType: recordType, recordType: recordType,
byID: btree.New(btreeDegree), byID: btree.New(btreeDegree),
byVersion: btree.New(btreeDegree), byVersion: btree.New(btreeDegree),
onchange: s,
} }
} }
@ -81,6 +85,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
// Delete marks a record as deleted. // Delete marks a record as deleted.
func (db *DB) Delete(_ context.Context, id string) error { func (db *DB) Delete(_ context.Context, id string) error {
defer db.onchange.Broadcast()
db.replaceOrInsert(id, func(record *databroker.Record) { db.replaceOrInsert(id, func(record *databroker.Record) {
record.DeletedAt = ptypes.TimestampNow() record.DeletedAt = ptypes.TimestampNow()
db.deletedIDs = append(db.deletedIDs, id) db.deletedIDs = append(db.deletedIDs, id)
@ -122,12 +127,25 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record
// Put replaces or inserts a record in the db. // Put replaces or inserts a record in the db.
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
defer db.onchange.Broadcast()
db.replaceOrInsert(id, func(record *databroker.Record) { db.replaceOrInsert(id, func(record *databroker.Record) {
record.Data = data record.Data = data
}) })
return nil return nil
} }
// Watch binds a new channel to the database's change signal and returns it to
// the caller, who can listen on it to detect changes (Put and Delete broadcast
// on that signal).
//
// When ctx is done, the channel is unbound from the signal and then closed.
func (db *DB) Watch(ctx context.Context) chan struct{} {
	ch := db.onchange.Bind()
	go func() {
		<-ctx.Done()
		// Unbind before closing: while the channel is still bound, a
		// concurrent Broadcast may try to send on it, and sending on a closed
		// channel panics. Once unbound, closing is safe.
		db.onchange.Unbind(ch)
		close(ch)
	}()
	return ch
}
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) { func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
db.mu.Lock() db.mu.Lock()
defer db.mu.Unlock() defer db.mu.Unlock()

View file

@ -4,20 +4,24 @@ package redis
import ( import (
"context" "context"
"fmt" "fmt"
"net"
"strconv" "strconv"
"time" "time"
"github.com/cenkalti/backoff/v4"
"github.com/golang/protobuf/proto" "github.com/golang/protobuf/proto"
"github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes"
"github.com/gomodule/redigo/redis" "github.com/gomodule/redigo/redis"
"google.golang.org/protobuf/types/known/anypb" "google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage"
) )
// Name is the storage type name for redis backend. // Name is the storage type name for redis backend.
const Name = "redis" const Name = "redis"
const watchAction = "zadd"
var _ storage.Backend = (*DB)(nil) var _ storage.Backend = (*DB)(nil)
@ -58,8 +62,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
}, },
deletePermanentlyAfter: deletePermanentAfter, deletePermanentlyAfter: deletePermanentAfter,
recordType: recordType, recordType: recordType,
versionSet: "version_set", versionSet: recordType + "_version_set",
deletedSet: "deleted_set", deletedSet: recordType + "_deleted_set",
lastVersionKey: recordType + "_last_version", lastVersionKey: recordType + "_last_version",
} }
return db, nil return db, nil
@ -208,6 +212,94 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
} }
} }
// doNotify waits up to one second for a single event from redis and signals ch
// when that event is a message whose payload equals watchAction.
//
// It returns the receive error if redis reports one, ctx.Err() if ctx is done
// before the notification could be delivered to ch, and nil otherwise
// (including for messages with a different payload and for any other event
// kind).
func doNotify(ctx context.Context, psc *redis.PubSubConn, ch chan struct{}) error {
	switch event := psc.ReceiveWithTimeout(time.Second).(type) {
	case error:
		log.Debug().Err(event).Msg("redis subscribe error")
		return event
	case redis.Message:
		action := string(event.Data)
		log.Debug().Str("action", action).Msg("got redis message")
		if action != watchAction {
			return nil
		}
		// Deliver the notification, unless the caller has gone away first.
		select {
		case ch <- struct{}{}:
			return nil
		case <-ctx.Done():
			log.Warn().Err(ctx.Err()).Msg("unable to notify channel")
			return ctx.Err()
		}
	}
	return nil
}
// doNotifyLoop runs doNotify repeatedly on psc until it fails.
//
// redis.PubSubConn does not support context, so a plain receive would block
// until an event arrives and could not be cancelled. doNotify therefore uses
// PubSubConn.ReceiveWithTimeout. When a receive times out, doNotifyLoop waits
// for the next backoff interval and returns nil, so the caller can re-create
// the connection and start a new loop. If ctx is done during that wait,
// ctx.Err() is returned instead.
//
// Any other error (including ctx.Err() reported by doNotify) is logged and
// returned, which tells the caller to stop watching.
func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.PubSubConn, eb *backoff.ExponentialBackOff) error {
	for {
		err := doNotify(ctx, psc, ch)
		if err == nil {
			continue
		}

		// Inspect the error only after the nil check: asserting first would
		// replace a non-net error with a nil net.Error and silently drop it.
		if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
			select {
			case <-ctx.Done():
				return ctx.Err()
			case <-time.After(eb.NextBackOff()):
			}
			return nil
		}

		log.Error().Err(err).Msg("failed to notify channel")
		return err
	}
}
// watchLoop subscribes to keyspace notifications for the version set and runs
// doNotifyLoop on that subscription, forever.
//
// Each iteration takes a fresh connection from the pool. If doNotifyLoop
// returns a nil error, the connection is released and a new iteration starts
// with a new PubSubConn. If subscribing fails or doNotifyLoop returns an
// error, the connection is released and watchLoop returns.
func (db *DB) watchLoop(ctx context.Context, ch chan struct{}) {
	eb := backoff.NewExponentialBackOff()
	for {
		psConn := db.pool.Get()
		psc := redis.PubSubConn{Conn: psConn}
		if err := psc.PSubscribe("__keyspace*__:" + db.versionSet); err != nil {
			log.Error().Err(err).Msg("failed to subscribe to version set channel")
			psConn.Close()
			return
		}

		err := db.doNotifyLoop(ctx, ch, &psc, eb)
		// Always release the connection before the next iteration; otherwise
		// every retry would leak one pooled connection.
		psConn.Close()
		if err != nil {
			return
		}
	}
}
// Watch returns a channel that receives a message whenever the version set
// changes.
//
// A background goroutine first enables redis keyspace notifications for
// sorted-set events ("Kz") and then runs watchLoop. The channel is closed when
// that goroutine exits, either because the setup failed or because watchLoop
// returned.
func (db *DB) Watch(ctx context.Context) chan struct{} {
	notify := make(chan struct{})
	go func() {
		defer close(notify)

		// Setup notifications, we only care about changes to db.version_set.
		// The setup connection is released before the watch loop starts.
		conn := db.pool.Get()
		_, err := conn.Do("CONFIG", "SET", "notify-keyspace-events", "Kz")
		if err != nil {
			log.Error().Err(err).Msg("failed to setup redis notification")
		}
		conn.Close()
		if err != nil {
			return
		}

		db.watchLoop(ctx, notify)
	}()
	return notify
}
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) { func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
c := db.pool.Get() c := db.pool.Get()
defer c.Close() defer c.Close()

View file

@ -25,7 +25,8 @@ func cleanup(c redis.Conn, db *DB, t *testing.T) {
} }
func TestDB(t *testing.T) { func TestDB(t *testing.T) {
ctx := context.Background() ctx, cancelFunc := context.WithCancel(context.Background())
defer cancelFunc()
address := ":6379" address := ":6379"
if redisURL := os.Getenv("REDIS_URL"); redisURL != "" { if redisURL := os.Getenv("REDIS_URL"); redisURL != "" {
address = redisURL address = redisURL
@ -41,6 +42,8 @@ func TestDB(t *testing.T) {
_, err = c.Do("DEL", db.lastVersionKey) _, err = c.Do("DEL", db.lastVersionKey)
require.NoError(t, err) require.NoError(t, err)
ch := db.Watch(ctx)
t.Run("get missing record", func(t *testing.T) { t.Run("get missing record", func(t *testing.T) {
record, err := db.Get(ctx, id) record, err := db.Get(ctx, id)
assert.Error(t, err) assert.Error(t, err)
@ -109,4 +112,13 @@ func TestDB(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.Len(t, records, 0) assert.Len(t, records, 0)
}) })
expectedNumEvents := 14
actualNumEvents := 0
for range ch {
actualNumEvents++
if actualNumEvents == expectedNumEvents {
cancelFunc()
}
}
} }

View file

@ -29,4 +29,9 @@ type Backend interface {
// ClearDeleted is used clear marked delete records. // ClearDeleted is used clear marked delete records.
ClearDeleted(ctx context.Context, cutoff time.Time) ClearDeleted(ctx context.Context, cutoff time.Time)
// Watch returns a channel to the caller. The channel is used to notify
// about changes that happen in storage. When ctx is finished, Watch will close
// the channel.
Watch(ctx context.Context) chan struct{}
} }