mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 13:07:13 +02:00
pkg/storage: introduce storage.Backend Watch method (#1135)
Currently, we're doing "sync" in databroker server. If we're going to support multiple databroker servers instance, this mechanism won't work. This commit moves the "sync" to storage backend, by adding new Watch method. The Watch method will return a channel for the caller. Everytime something happens inside the storage, we notify the caller by sending a message to this channel.
This commit is contained in:
parent
d9711c8055
commit
a7bd2caae9
10 changed files with 204 additions and 44 deletions
|
@ -26,9 +26,9 @@ func TestConfigSource(t *testing.T) {
|
|||
}
|
||||
defer li.Close()
|
||||
|
||||
db := New()
|
||||
dataBrokerServer := newTestServer()
|
||||
srv := grpc.NewServer()
|
||||
databroker.RegisterDataBrokerServiceServer(srv, db)
|
||||
databroker.RegisterDataBrokerServiceServer(srv, dataBrokerServer)
|
||||
go func() { _ = srv.Serve(li) }()
|
||||
|
||||
cfgs := make(chan *config.Config, 10)
|
||||
|
@ -52,7 +52,7 @@ func TestConfigSource(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
_, _ = db.Set(ctx, &databroker.SetRequest{
|
||||
_, _ = dataBrokerServer.Set(ctx, &databroker.SetRequest{
|
||||
Type: configTypeURL,
|
||||
Id: "1",
|
||||
Data: data,
|
||||
|
|
7
internal/databroker/helper_no_redis.go
Normal file
7
internal/databroker/helper_no_redis.go
Normal file
|
@ -0,0 +1,7 @@
|
|||
// +build !redis
|
||||
|
||||
package databroker
|
||||
|
||||
func newTestServer() *Server {
|
||||
return New()
|
||||
}
|
17
internal/databroker/helper_redis.go
Normal file
17
internal/databroker/helper_redis.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
// +build redis
|
||||
|
||||
package databroker
|
||||
|
||||
import (
|
||||
"os"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
||||
)
|
||||
|
||||
func newTestServer() *Server {
|
||||
address := ":6379"
|
||||
if redisURL := os.Getenv("REDIS_URL"); redisURL != "" {
|
||||
address = redisURL
|
||||
}
|
||||
return New(WithStorageType(redis.Name), WithStorageConnectionString(address))
|
||||
}
|
|
@ -18,6 +18,7 @@ import (
|
|||
"google.golang.org/protobuf/types/known/emptypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -38,7 +39,7 @@ type Server struct {
|
|||
|
||||
mu sync.RWMutex
|
||||
byType map[string]storage.Backend
|
||||
onchange *Signal
|
||||
onTypechange *signal.Signal
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
|
@ -50,7 +51,7 @@ func New(options ...ServerOption) *Server {
|
|||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onchange: NewSignal(),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
srv.initVersion()
|
||||
|
||||
|
@ -110,8 +111,6 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
|||
Str("id", req.GetId()).
|
||||
Msg("delete")
|
||||
|
||||
defer srv.onchange.Broadcast()
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -182,8 +181,6 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("set")
|
||||
|
||||
defer srv.onchange.Broadcast()
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -201,6 +198,24 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
updated, err := db.List(ctx, *recordVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(updated) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Slice(updated, func(i, j int) bool {
|
||||
return updated[i].Version < updated[j].Version
|
||||
})
|
||||
*recordVersion = updated[len(updated)-1].Version
|
||||
return stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
Records: updated,
|
||||
})
|
||||
}
|
||||
|
||||
// Sync streams updates for the given record type.
|
||||
func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBrokerService_SyncServer) error {
|
||||
_, span := trace.StartSpan(stream.Context(), "databroker.grpc.Sync")
|
||||
|
@ -222,30 +237,20 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
return err
|
||||
}
|
||||
|
||||
ch := srv.onchange.Bind()
|
||||
defer srv.onchange.Unbind(ch)
|
||||
for {
|
||||
updated, _ := db.List(context.Background(), recordVersion)
|
||||
if len(updated) > 0 {
|
||||
sort.Slice(updated, func(i, j int) bool {
|
||||
return updated[i].Version < updated[j].Version
|
||||
})
|
||||
recordVersion = updated[len(updated)-1].Version
|
||||
err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
Records: updated,
|
||||
})
|
||||
if err != nil {
|
||||
ctx := stream.Context()
|
||||
ch := db.Watch(ctx)
|
||||
|
||||
// Do first sync, so we won't missed anything.
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for range ch {
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
select {
|
||||
case <-stream.Context().Done():
|
||||
return stream.Context().Err()
|
||||
case <-ch:
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTypes returns all the known record types.
|
||||
|
@ -272,8 +277,8 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
|||
srv.log.Info().
|
||||
Msg("sync types")
|
||||
|
||||
ch := srv.onchange.Bind()
|
||||
defer srv.onchange.Unbind(ch)
|
||||
ch := srv.onTypechange.Bind()
|
||||
defer srv.onTypechange.Unbind(ch)
|
||||
|
||||
var prev []string
|
||||
for {
|
||||
|
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
@ -21,7 +22,7 @@ func newServer(cfg *serverConfig) *Server {
|
|||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onchange: NewSignal(),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
package databroker
|
||||
// Package signal provides mechanism for notifying multiple listeners when something happened.
|
||||
package signal
|
||||
|
||||
import "sync"
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A Signal is used to let multiple listeners know when something happened.
|
||||
type Signal struct {
|
||||
|
@ -8,8 +11,8 @@ type Signal struct {
|
|||
chs map[chan struct{}]struct{}
|
||||
}
|
||||
|
||||
// NewSignal creates a new Signal.
|
||||
func NewSignal() *Signal {
|
||||
// New creates a new Signal.
|
||||
func New() *Signal {
|
||||
return &Signal{
|
||||
chs: make(map[chan struct{}]struct{}),
|
||||
}
|
|
@ -14,6 +14,7 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
@ -49,14 +50,17 @@ type DB struct {
|
|||
byID *btree.BTree
|
||||
byVersion *btree.BTree
|
||||
deletedIDs []string
|
||||
onchange *signal.Signal
|
||||
}
|
||||
|
||||
// NewDB creates a new in-memory database for the given record type.
|
||||
func NewDB(recordType string, btreeDegree int) *DB {
|
||||
s := signal.New()
|
||||
return &DB{
|
||||
recordType: recordType,
|
||||
byID: btree.New(btreeDegree),
|
||||
byVersion: btree.New(btreeDegree),
|
||||
onchange: s,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -81,6 +85,7 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
|||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(_ context.Context, id string) error {
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.DeletedAt = ptypes.TimestampNow()
|
||||
db.deletedIDs = append(db.deletedIDs, id)
|
||||
|
@ -122,12 +127,25 @@ func (db *DB) List(_ context.Context, sinceVersion string) ([]*databroker.Record
|
|||
|
||||
// Put replaces or inserts a record in the db.
|
||||
func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||
defer db.onchange.Broadcast()
|
||||
db.replaceOrInsert(id, func(record *databroker.Record) {
|
||||
record.Data = data
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Watch returns the underlying signal.Signal binding channel to the caller.
|
||||
// Then the caller can listen to the channel for detecting changes.
|
||||
func (db *DB) Watch(ctx context.Context) chan struct{} {
|
||||
ch := db.onchange.Bind()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
close(ch)
|
||||
db.onchange.Unbind(ch)
|
||||
}()
|
||||
return ch
|
||||
}
|
||||
|
||||
func (db *DB) replaceOrInsert(id string, f func(record *databroker.Record)) {
|
||||
db.mu.Lock()
|
||||
defer db.mu.Unlock()
|
||||
|
|
|
@ -4,20 +4,24 @@ package redis
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/gomodule/redigo/redis"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// Name is the storage type name for redis backend.
|
||||
const Name = "redis"
|
||||
const watchAction = "zadd"
|
||||
|
||||
var _ storage.Backend = (*DB)(nil)
|
||||
|
||||
|
@ -58,8 +62,8 @@ func New(address, recordType string, deletePermanentAfter int64) (*DB, error) {
|
|||
},
|
||||
deletePermanentlyAfter: deletePermanentAfter,
|
||||
recordType: recordType,
|
||||
versionSet: "version_set",
|
||||
deletedSet: "deleted_set",
|
||||
versionSet: recordType + "_version_set",
|
||||
deletedSet: recordType + "_deleted_set",
|
||||
lastVersionKey: recordType + "_last_version",
|
||||
}
|
||||
return db, nil
|
||||
|
@ -208,6 +212,94 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
|||
}
|
||||
}
|
||||
|
||||
// doNotify receives event from redis and signal the channel that something happenned.
|
||||
func doNotify(ctx context.Context, psc *redis.PubSubConn, ch chan struct{}) error {
|
||||
switch v := psc.ReceiveWithTimeout(time.Second).(type) {
|
||||
case redis.Message:
|
||||
log.Debug().Str("action", string(v.Data)).Msg("got redis message")
|
||||
if string(v.Data) != watchAction {
|
||||
return nil
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Warn().Err(ctx.Err()).Msg("unable to notify channel")
|
||||
return ctx.Err()
|
||||
case ch <- struct{}{}:
|
||||
}
|
||||
case error:
|
||||
log.Debug().Err(v).Msg("redis subscribe error")
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// doNotifyLoop tries to run doNotify forever.
|
||||
//
|
||||
// Because redis.PubSubConn does not support context, so it will block until it receives event, we can not use
|
||||
// context to signal it stops. We mitigate this case by using PubSubConn.ReceiveWithTimeout. In case of timeout
|
||||
// occurred, we return a nil error, so the caller of doNotifyLoop will re-create new connection to start new loop.
|
||||
func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}, psc *redis.PubSubConn, eb *backoff.ExponentialBackOff) error {
|
||||
for {
|
||||
err, ok := doNotify(ctx, psc, ch).(net.Error)
|
||||
if !ok && err != nil {
|
||||
log.Error().Err(ctx.Err()).Msg("failed to notify channel")
|
||||
return err
|
||||
}
|
||||
if ok && err.Timeout() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(eb.NextBackOff()):
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// watchLoop runs the doNotifyLoop forever.
|
||||
//
|
||||
// If doNotifyLoop returns a nil error, watchLoop re-create the PubSubConn and start new iteration.
|
||||
func (db *DB) watchLoop(ctx context.Context, ch chan struct{}) {
|
||||
var psConn redis.Conn
|
||||
eb := backoff.NewExponentialBackOff()
|
||||
for {
|
||||
psConn = db.pool.Get()
|
||||
psc := redis.PubSubConn{Conn: psConn}
|
||||
if err := psc.PSubscribe("__keyspace*__:" + db.versionSet); err != nil {
|
||||
log.Error().Err(err).Msg("failed to subscribe to version set channel")
|
||||
psConn.Close()
|
||||
return
|
||||
}
|
||||
if err := db.doNotifyLoop(ctx, ch, &psc, eb); err != nil {
|
||||
psConn.Close()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Watch returns a channel to the caller, when there is a change to the version set,
|
||||
// sending message to the channel to notify the caller.
|
||||
func (db *DB) Watch(ctx context.Context) chan struct{} {
|
||||
ch := make(chan struct{})
|
||||
go func() {
|
||||
c := db.pool.Get()
|
||||
defer func() {
|
||||
close(ch)
|
||||
}()
|
||||
|
||||
// Setup notifications, we only care about changes to db.version_set.
|
||||
if _, err := c.Do("CONFIG", "SET", "notify-keyspace-events", "Kz"); err != nil {
|
||||
log.Error().Err(err).Msg("failed to setup redis notification")
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
c.Close()
|
||||
db.watchLoop(ctx, ch)
|
||||
}()
|
||||
|
||||
return ch
|
||||
}
|
||||
|
||||
func (db *DB) getAll(_ context.Context, filter func(record *databroker.Record) bool) ([]*databroker.Record, error) {
|
||||
c := db.pool.Get()
|
||||
defer c.Close()
|
||||
|
|
|
@ -25,7 +25,8 @@ func cleanup(c redis.Conn, db *DB, t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDB(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ctx, cancelFunc := context.WithCancel(context.Background())
|
||||
defer cancelFunc()
|
||||
address := ":6379"
|
||||
if redisURL := os.Getenv("REDIS_URL"); redisURL != "" {
|
||||
address = redisURL
|
||||
|
@ -41,6 +42,8 @@ func TestDB(t *testing.T) {
|
|||
_, err = c.Do("DEL", db.lastVersionKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
ch := db.Watch(ctx)
|
||||
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := db.Get(ctx, id)
|
||||
assert.Error(t, err)
|
||||
|
@ -109,4 +112,13 @@ func TestDB(t *testing.T) {
|
|||
assert.NoError(t, err)
|
||||
assert.Len(t, records, 0)
|
||||
})
|
||||
|
||||
expectedNumEvents := 14
|
||||
actualNumEvents := 0
|
||||
for range ch {
|
||||
actualNumEvents++
|
||||
if actualNumEvents == expectedNumEvents {
|
||||
cancelFunc()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -29,4 +29,9 @@ type Backend interface {
|
|||
|
||||
// ClearDeleted is used clear marked delete records.
|
||||
ClearDeleted(ctx context.Context, cutoff time.Time)
|
||||
|
||||
// Watch returns a channel to the caller. The channel is used to notify
|
||||
// about changes that happen in storage. When ctx is finished, Watch will close
|
||||
// the channel.
|
||||
Watch(ctx context.Context) chan struct{}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue