mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-04 02:18:42 +02:00
databroker: remove redis storage backend
Remove the Redis databroker backend. According to https://www.pomerium.com/docs/internals/data-storage#redis it has been discouraged since Pomerium v0.18.
This commit is contained in:
parent
fd8cb18c44
commit
c56042f51f
13 changed files with 4 additions and 1664 deletions
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
||||||
"github.com/pomerium/pomerium/internal/registry/redis"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -110,16 +109,6 @@ func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interfac
|
||||||
case config.StorageInMemoryName:
|
case config.StorageInMemoryName:
|
||||||
log.Info(ctx).Msg("using in-memory registry")
|
log.Info(ctx).Msg("using in-memory registry")
|
||||||
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
||||||
case config.StorageRedisName:
|
|
||||||
log.Info(ctx).Msg("using redis registry")
|
|
||||||
r, err := redis.New(
|
|
||||||
srv.cfg.storageConnectionString,
|
|
||||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
|
|
||||||
}
|
|
||||||
return r, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
||||||
|
|
|
@ -24,7 +24,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server implements the databroker service using an in memory database.
|
// Server implements the databroker service using an in memory database.
|
||||||
|
@ -426,21 +425,6 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||||
case config.StoragePostgresName:
|
case config.StoragePostgresName:
|
||||||
log.Info(ctx).Msg("using postgres store")
|
log.Info(ctx).Msg("using postgres store")
|
||||||
backend = postgres.New(srv.cfg.storageConnectionString)
|
backend = postgres.New(srv.cfg.storageConnectionString)
|
||||||
case config.StorageRedisName:
|
|
||||||
log.Info(ctx).Msg("using redis store")
|
|
||||||
backend, err = redis.New(
|
|
||||||
srv.cfg.storageConnectionString,
|
|
||||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
|
||||||
}
|
|
||||||
if srv.cfg.secret != nil {
|
|
||||||
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
@ -287,12 +286,11 @@ func TestServerInvalidStorage(t *testing.T) {
|
||||||
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
|
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerRedis(t *testing.T) {
|
func TestServerPostgres(t *testing.T) {
|
||||||
testutil.WithTestRedis(false, func(rawURL string) error {
|
testutil.WithTestPostgres(func(dsn string) error {
|
||||||
srv := newServer(&serverConfig{
|
srv := newServer(&serverConfig{
|
||||||
storageType: "redis",
|
storageType: "postgres",
|
||||||
storageConnectionString: rawURL,
|
storageConnectionString: dsn,
|
||||||
secret: cryptutil.NewKey(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
s := new(session.Session)
|
s := new(session.Session)
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
// Package lua contains lua source code.
|
|
||||||
package lua
|
|
||||||
|
|
||||||
import (
|
|
||||||
"embed"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed registry.lua
|
|
||||||
var fs embed.FS
|
|
||||||
|
|
||||||
// Registry is the registry lua script
|
|
||||||
var Registry string
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
bs, err := fs.ReadFile("registry.lua")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
Registry = string(bs)
|
|
||||||
}
|
|
|
@ -1,28 +0,0 @@
|
||||||
-- ARGV = [current time in seconds, ttl in seconds, services ...]
|
|
||||||
local current_time = ARGV[1]
|
|
||||||
local ttl = ARGV[2]
|
|
||||||
local changed = false
|
|
||||||
|
|
||||||
-- update the service list
|
|
||||||
for i = 3, #ARGV, 1 do
|
|
||||||
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
|
|
||||||
changed = true
|
|
||||||
end
|
|
||||||
|
|
||||||
-- retrieve all the services, removing any that have expired
|
|
||||||
local svcs = {}
|
|
||||||
local kvs = redis.call('HGETALL', KEYS[1])
|
|
||||||
for i = 1, #kvs, 2 do
|
|
||||||
if kvs[i + 1] < current_time then
|
|
||||||
redis.call('HDEL', KEYS[1], kvs[i])
|
|
||||||
changed = true
|
|
||||||
else
|
|
||||||
table.insert(svcs, kvs[i])
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if changed then
|
|
||||||
redis.call('PUBLISH', KEYS[2], current_time)
|
|
||||||
end
|
|
||||||
|
|
||||||
return svcs
|
|
|
@ -1,48 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultTTL = time.Second * 30
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
tls *tls.Config
|
|
||||||
ttl time.Duration
|
|
||||||
getNow func() time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// An Option modifies the config..
|
|
||||||
type Option func(*config)
|
|
||||||
|
|
||||||
// WithGetNow sets the time.Now function in the config.
|
|
||||||
func WithGetNow(getNow func() time.Time) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.getNow = getNow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTLSConfig sets the tls.Config in the config.
|
|
||||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.tls = tlsConfig
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTTL sets the ttl in the config.
|
|
||||||
func WithTTL(ttl time.Duration) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.ttl = ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
|
||||||
cfg := new(config)
|
|
||||||
WithGetNow(time.Now)(cfg)
|
|
||||||
WithTTL(defaultTTL)(cfg)
|
|
||||||
for _, o := range options {
|
|
||||||
o(cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
|
@ -1,250 +0,0 @@
|
||||||
// Package redis implements a registry in redis.
|
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/redisutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
|
||||||
"github.com/pomerium/pomerium/internal/registry/redis/lua"
|
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
registryKey = redisutil.KeyPrefix + "registry"
|
|
||||||
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
|
|
||||||
|
|
||||||
pollInterval = time.Second * 30
|
|
||||||
)
|
|
||||||
|
|
||||||
type impl struct {
|
|
||||||
cfg *config
|
|
||||||
|
|
||||||
client redis.UniversalClient
|
|
||||||
onChange *signal.Signal
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new registry implementation backend by redis.
|
|
||||||
func New(rawURL string, options ...Option) (registry.Interface, error) {
|
|
||||||
cfg := getConfig(options...)
|
|
||||||
|
|
||||||
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i := &impl{
|
|
||||||
cfg: cfg,
|
|
||||||
client: client,
|
|
||||||
onChange: signal.New(),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
go i.listenForChanges(context.Background())
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
|
||||||
_, err := i.runReport(ctx, req.GetServices())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ®istrypb.RegisterResponse{
|
|
||||||
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
|
||||||
all, err := i.runReport(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
include := map[registrypb.ServiceKind]struct{}{}
|
|
||||||
for _, kind := range req.GetKinds() {
|
|
||||||
include[kind] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
filtered := make([]*registrypb.Service, 0, len(all))
|
|
||||||
for _, svc := range all {
|
|
||||||
if _, ok := include[svc.GetKind()]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
filtered = append(filtered, svc)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(filtered, func(i, j int) bool {
|
|
||||||
{
|
|
||||||
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
|
|
||||||
switch {
|
|
||||||
case iv < jv:
|
|
||||||
return true
|
|
||||||
case jv < iv:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
|
|
||||||
switch {
|
|
||||||
case iv < jv:
|
|
||||||
return true
|
|
||||||
case jv < iv:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
return ®istrypb.ServiceList{
|
|
||||||
Services: filtered,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
|
||||||
// listen for changes
|
|
||||||
ch := i.onChange.Bind()
|
|
||||||
defer i.onChange.Unbind(ch)
|
|
||||||
|
|
||||||
// force a check periodically
|
|
||||||
poll := time.NewTicker(pollInterval)
|
|
||||||
defer poll.Stop()
|
|
||||||
|
|
||||||
var prev *registrypb.ServiceList
|
|
||||||
for {
|
|
||||||
// retrieve the most recent list of services
|
|
||||||
lst, err := i.List(stream.Context(), req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// only send a new list if something changed
|
|
||||||
if !proto.Equal(prev, lst) {
|
|
||||||
err = stream.Send(lst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prev = lst
|
|
||||||
|
|
||||||
// wait for an update
|
|
||||||
select {
|
|
||||||
case <-i.closed:
|
|
||||||
return nil
|
|
||||||
case <-stream.Context().Done():
|
|
||||||
return stream.Context().Err()
|
|
||||||
case <-ch:
|
|
||||||
case <-poll.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Close() error {
|
|
||||||
var err error
|
|
||||||
i.closeOnce.Do(func() {
|
|
||||||
err = i.client.Close()
|
|
||||||
close(i.closed)
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) listenForChanges(ctx context.Context) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
go func() {
|
|
||||||
<-i.closed
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
|
|
||||||
outer:
|
|
||||||
for {
|
|
||||||
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
|
|
||||||
for {
|
|
||||||
msg, err := pubsub.Receive(ctx)
|
|
||||||
if err != nil {
|
|
||||||
_ = pubsub.Close()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue outer
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
switch msg.(type) {
|
|
||||||
case *redis.Message:
|
|
||||||
i.onChange.Broadcast(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
|
|
||||||
args := []interface{}{
|
|
||||||
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
|
|
||||||
i.cfg.ttl.Milliseconds(), // ttl
|
|
||||||
}
|
|
||||||
for _, svc := range updates {
|
|
||||||
args = append(args, i.getRegistryHashKey(svc))
|
|
||||||
}
|
|
||||||
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if values, ok := res.([]interface{}); ok {
|
|
||||||
var all []*registrypb.Service
|
|
||||||
for _, value := range values {
|
|
||||||
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid service")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
all = append(all, svc)
|
|
||||||
}
|
|
||||||
return all, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
|
|
||||||
idx := strings.Index(key, "|")
|
|
||||||
if idx == -1 {
|
|
||||||
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
svcKindStr := key[:idx]
|
|
||||||
svcEndpointStr := key[idx+1:]
|
|
||||||
|
|
||||||
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
svc := ®istrypb.Service{
|
|
||||||
Kind: registrypb.ServiceKind(svcKind),
|
|
||||||
Endpoint: svcEndpointStr,
|
|
||||||
}
|
|
||||||
return svc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
|
|
||||||
return svc.GetKind().String() + "|" + svc.GetEndpoint()
|
|
||||||
}
|
|
|
@ -1,196 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
tm := time.Now()
|
|
||||||
|
|
||||||
i, err := New(rawURL,
|
|
||||||
WithGetNow(func() time.Time {
|
|
||||||
return tm
|
|
||||||
}),
|
|
||||||
WithTTL(time.Second*10))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = i.Close() }()
|
|
||||||
|
|
||||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// move forward 5 seconds
|
|
||||||
tm = tm.Add(time.Second * 5)
|
|
||||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
lst, err := i.List(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
registrypb.ServiceKind_PROXY,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
}, lst.GetServices(), "should list selected services")
|
|
||||||
|
|
||||||
// move forward 6 seconds
|
|
||||||
tm = tm.Add(time.Second * 6)
|
|
||||||
lst, err = i.List(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
registrypb.ServiceKind_PROXY,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
}, lst.GetServices(), "should expire old services")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWatch(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tm := time.Now()
|
|
||||||
i, err := New(rawURL,
|
|
||||||
WithGetNow(func() time.Time {
|
|
||||||
return tm
|
|
||||||
}),
|
|
||||||
WithTTL(time.Second*10))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer li.Close()
|
|
||||||
|
|
||||||
srv := grpc.NewServer()
|
|
||||||
registrypb.RegisterRegistryServer(srv, i)
|
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
|
||||||
eg.Go(func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
srv.Stop()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
return srv.Serve(li)
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := registrypb.NewRegistryClient(cc)
|
|
||||||
|
|
||||||
// store the initial services
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := client.Watch(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = stream.CloseSend() }()
|
|
||||||
|
|
||||||
lst, err := stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
}, lst.GetServices())
|
|
||||||
|
|
||||||
// update authenticate
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// add an authorize
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
lst, err = stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
|
||||||
}, lst.GetServices())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
require.NoError(t, eg.Wait())
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
|
|
||||||
pomeriumconfig "github.com/pomerium/pomerium/config"
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
||||||
)
|
|
||||||
|
|
||||||
type logger struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
|
||||||
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
redis.SetLogger(logger{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
|
||||||
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
|
||||||
Operation: operation,
|
|
||||||
Error: err,
|
|
||||||
Backend: pomeriumconfig.StorageRedisName,
|
|
||||||
}, time.Since(startTime))
|
|
||||||
}
|
|
|
@ -1,37 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
tls *tls.Config
|
|
||||||
expiry time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option customizes a Backend.
|
|
||||||
type Option func(*config)
|
|
||||||
|
|
||||||
// WithTLSConfig sets the tls.Config which Backend uses.
|
|
||||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.tls = tlsConfig
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithExpiry sets the expiry for changes.
|
|
||||||
func WithExpiry(expiry time.Duration) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.expiry = expiry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
|
||||||
cfg := new(config)
|
|
||||||
WithExpiry(time.Hour * 24)(cfg)
|
|
||||||
for _, o := range options {
|
|
||||||
o(cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
|
@ -1,537 +0,0 @@
|
||||||
// Package redis implements the storage.Backend interface for redis.
|
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/redisutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxTransactionRetries = 100
|
|
||||||
watchPollInterval = 30 * time.Second
|
|
||||||
|
|
||||||
// we rely on transactions in redis, so all redis-cluster keys need to be
|
|
||||||
// on the same node. Using a `hash tag` gives us this capability.
|
|
||||||
serverVersionKey = redisutil.KeyPrefix + "server_version"
|
|
||||||
lastVersionKey = redisutil.KeyPrefix + "last_version"
|
|
||||||
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
|
|
||||||
recordHashKey = redisutil.KeyPrefix + "records"
|
|
||||||
recordTypesSetKey = redisutil.KeyPrefix + "record_types"
|
|
||||||
changesSetKey = redisutil.KeyPrefix + "changes"
|
|
||||||
optionsKey = redisutil.KeyPrefix + "options"
|
|
||||||
|
|
||||||
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
|
|
||||||
leaseKeyTpl = redisutil.KeyPrefix + "lease.%s"
|
|
||||||
)
|
|
||||||
|
|
||||||
// custom errors
|
|
||||||
var (
|
|
||||||
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Backend implements the storage.Backend on top of redis.
|
|
||||||
//
|
|
||||||
// What's stored:
|
|
||||||
//
|
|
||||||
// - last_version: an integer recordVersion number
|
|
||||||
// - last_version_ch: a PubSub channel for recordVersion number updates
|
|
||||||
// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record.
|
|
||||||
// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record.
|
|
||||||
// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options.
|
|
||||||
// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time,
|
|
||||||
// the value the record id.
|
|
||||||
//
|
|
||||||
// Records stored in these keys are typically encrypted.
|
|
||||||
type Backend struct {
|
|
||||||
cfg *config
|
|
||||||
|
|
||||||
client redis.UniversalClient
|
|
||||||
onChange *signal.Signal
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new redis storage backend.
|
|
||||||
func New(rawURL string, options ...Option) (*Backend, error) {
|
|
||||||
ctx := context.TODO()
|
|
||||||
cfg := getConfig(options...)
|
|
||||||
backend := &Backend{
|
|
||||||
cfg: cfg,
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
onChange: signal.New(),
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
metrics.AddRedisMetrics(backend.client.PoolStats)
|
|
||||||
go backend.listenForVersionChanges(ctx)
|
|
||||||
if cfg.expiry != 0 {
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-backend.closed:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
|
|
||||||
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return backend, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the underlying redis connection and any watchers.
|
|
||||||
func (backend *Backend) Close() error {
|
|
||||||
var err error
|
|
||||||
backend.closeOnce.Do(func() {
|
|
||||||
err = backend.client.Close()
|
|
||||||
close(backend.closed)
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get gets a record from redis.
|
|
||||||
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
|
|
||||||
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
|
||||||
|
|
||||||
key, field := getHashKey(recordType, id)
|
|
||||||
cmd := backend.client.HGet(ctx, key, field)
|
|
||||||
raw, err := cmd.Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(raw), &record)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOptions gets the options for the given record type.
|
|
||||||
func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
|
||||||
raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// treat no options as an empty set of options
|
|
||||||
return new(databroker.Options), nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var options databroker.Options
|
|
||||||
err = proto.Unmarshal([]byte(raw), &options)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &options, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lease acquires or renews a lease.
|
|
||||||
func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
|
||||||
acquired := false
|
|
||||||
key := getLeaseKey(leaseName)
|
|
||||||
err := backend.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
||||||
currentID, err := tx.Get(ctx, key).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// lease hasn't been set yet
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
} else if leaseID != currentID {
|
|
||||||
// lease has already been taken
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error {
|
|
||||||
if ttl <= 0 {
|
|
||||||
p.Del(ctx, key)
|
|
||||||
} else {
|
|
||||||
p.Set(ctx, key, leaseID, ttl)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
acquired = ttl > 0
|
|
||||||
return nil
|
|
||||||
}, key)
|
|
||||||
// if the transaction failed someone else must've acquired the lease
|
|
||||||
if errors.Is(err, redis.TxFailedErr) {
|
|
||||||
acquired = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
return acquired, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTypes lists all the known record types.
|
|
||||||
func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now())
|
|
||||||
|
|
||||||
cmd := backend.client.SMembers(ctx, recordTypesSetKey)
|
|
||||||
types, err = cmd.Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sort.Strings(types)
|
|
||||||
return types, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts a record into redis.
|
|
||||||
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
|
||||||
|
|
||||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = backend.put(ctx, records)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
recordTypes := map[string]struct{}{}
|
|
||||||
for _, record := range records {
|
|
||||||
recordTypes[record.GetType()] = struct{}{}
|
|
||||||
}
|
|
||||||
for recordType := range recordTypes {
|
|
||||||
err = backend.enforceOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return serverVersion, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOptions sets the options for the given record type.
|
|
||||||
func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
bs, err := proto.Marshal(options)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// update the options in the hash set
|
|
||||||
err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// possibly re-enforce options
|
|
||||||
err = backend.enforceOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync returns a record stream of any records changed after the specified recordVersion.
|
|
||||||
func (backend *Backend) Sync(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
serverVersion, recordVersion uint64,
|
|
||||||
) (storage.RecordStream, error) {
|
|
||||||
return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the
|
|
||||||
// stream is streaming.
|
|
||||||
func (backend *Backend) SyncLatest(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
expr storage.FilterExpression,
|
|
||||||
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
|
||||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, recordVersion, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// this happens if there are no records
|
|
||||||
} else if err != nil {
|
|
||||||
return serverVersion, recordVersion, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
|
|
||||||
return serverVersion, recordVersion, stream, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error {
|
|
||||||
return backend.incrementVersion(ctx,
|
|
||||||
func(tx *redis.Tx, version uint64) error {
|
|
||||||
for i, record := range records {
|
|
||||||
record.ModifiedAt = timestamppb.Now()
|
|
||||||
record.Version = version + uint64(i)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
func(p redis.Pipeliner, version uint64) error {
|
|
||||||
for i, record := range records {
|
|
||||||
bs, err := proto.Marshal(record)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key, field := getHashKey(record.GetType(), record.GetId())
|
|
||||||
if record.DeletedAt != nil {
|
|
||||||
p.HDel(ctx, key, field)
|
|
||||||
} else {
|
|
||||||
p.HSet(ctx, key, field, bs)
|
|
||||||
p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{
|
|
||||||
Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)),
|
|
||||||
Member: record.GetId(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
p.ZAdd(ctx, changesSetKey, &redis.Z{
|
|
||||||
Score: float64(version) + float64(i),
|
|
||||||
Member: bs,
|
|
||||||
})
|
|
||||||
p.SAdd(ctx, recordTypesSetKey, record.GetType())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// enforceOptions enforces the options for the given record type.
|
|
||||||
func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
options, err := backend.GetOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// nothing to do if capacity isn't set
|
|
||||||
if options.Capacity == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key := getRecordTypeChangesKey(recordType)
|
|
||||||
|
|
||||||
// find oldest records that exceed the capacity
|
|
||||||
recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// for each record, delete it
|
|
||||||
for _, recordID := range recordIDs {
|
|
||||||
record, err := backend.Get(ctx, recordType, recordID)
|
|
||||||
if err == nil {
|
|
||||||
// mark the record as deleted and re-submit
|
|
||||||
record.DeletedAt = timestamppb.Now()
|
|
||||||
err = backend.put(ctx, []*databroker.Record{record})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if errors.Is(err, storage.ErrNotFound) {
|
|
||||||
// ignore
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the member from the collection
|
|
||||||
_, err = backend.client.ZRem(ctx, key, recordID).Result()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in
|
|
||||||
// `commit`. If the last recordVersion changes in the interim, we will retry the transaction.
|
|
||||||
func (backend *Backend) incrementVersion(ctx context.Context,
|
|
||||||
query func(tx *redis.Tx, recordVersion uint64) error,
|
|
||||||
commit func(p redis.Pipeliner, recordVersion uint64) error,
|
|
||||||
) error {
|
|
||||||
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
|
||||||
txf := func(tx *redis.Tx) error {
|
|
||||||
version, err := tx.Get(ctx, lastVersionKey).Uint64()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
version = 0
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
version++
|
|
||||||
|
|
||||||
err = query(tx, version)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
|
|
||||||
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
|
|
||||||
err := commit(p, version)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.Set(ctx, lastVersionKey, version, 0)
|
|
||||||
p.Publish(ctx, lastVersionChKey, version)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
for i := 0; i < maxTransactionRetries; i++ {
|
|
||||||
err := backend.client.Watch(ctx, txf, lastVersionKey)
|
|
||||||
if errors.Is(err, redis.TxFailedErr) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue // retry
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil // tx was successful
|
|
||||||
}
|
|
||||||
|
|
||||||
return ErrExceededMaxRetries
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
go func() {
|
|
||||||
<-backend.closed
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
|
|
||||||
outer:
|
|
||||||
for {
|
|
||||||
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
|
|
||||||
for {
|
|
||||||
msg, err := pubsub.Receive(ctx)
|
|
||||||
if err != nil {
|
|
||||||
_ = pubsub.Close()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue outer
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
switch msg.(type) {
|
|
||||||
case *redis.Message:
|
|
||||||
backend.onChange.Broadcast(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
|
|
||||||
for {
|
|
||||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
||||||
Min: "-inf",
|
|
||||||
Max: "+inf",
|
|
||||||
Offset: 0,
|
|
||||||
Count: 1,
|
|
||||||
})
|
|
||||||
results, err := cmd.Result()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nothing left to do
|
|
||||||
if len(results) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(results[0]), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the record's modified timestamp is after the cutoff, we're all done, so break
|
|
||||||
if record.GetModifiedAt().AsTime().After(cutoff) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the record
|
|
||||||
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("redis: error removing member")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) {
|
|
||||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
||||||
// if the server version hasn't been set yet, set it to a random value and immediately retrieve it
|
|
||||||
// this should properly handle a data race by only setting the key if it doesn't already exist
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
_, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result()
|
|
||||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("redis: error retrieving server version: %w", err)
|
|
||||||
}
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLeaseKey(leaseName string) string {
|
|
||||||
return fmt.Sprintf(leaseKeyTpl, leaseName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRecordTypeChangesKey(recordType string) string {
|
|
||||||
return fmt.Sprintf(recordTypeChangesKeyTpl, recordType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHashKey(recordType, id string) (key, field string) {
|
|
||||||
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
|
|
||||||
}
|
|
|
@ -1,312 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBackend(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := func(t *testing.T, useTLS bool, rawURL string) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
var opts []Option
|
|
||||||
if useTLS {
|
|
||||||
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
|
|
||||||
}
|
|
||||||
backend, err := New(rawURL, opts...)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
t.Run("get missing record", func(t *testing.T) {
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, record)
|
|
||||||
})
|
|
||||||
t.Run("get record", func(t *testing.T) {
|
|
||||||
data := new(anypb.Any)
|
|
||||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "abcd",
|
|
||||||
Data: data,
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, serverVersion, sv)
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
if assert.NotNil(t, record) {
|
|
||||||
assert.Equal(t, data, record.Data)
|
|
||||||
assert.Nil(t, record.DeletedAt)
|
|
||||||
assert.Equal(t, "abcd", record.Id)
|
|
||||||
assert.NotNil(t, record.ModifiedAt)
|
|
||||||
assert.Equal(t, "TYPE", record.Type)
|
|
||||||
assert.Equal(t, uint64(1), record.Version)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("delete record", func(t *testing.T) {
|
|
||||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "abcd",
|
|
||||||
DeletedAt: timestamppb.Now(),
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, serverVersion, sv)
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, record)
|
|
||||||
})
|
|
||||||
t.Run("list types", func(t *testing.T) {
|
|
||||||
types, err := backend.ListTypes(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"TYPE"}, types)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("no-tls", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tls", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(true, func(rawURL string) error {
|
|
||||||
return handler(t, true, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
t.Run("cluster", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedisCluster(func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("sentinel", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedisSentinel(func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChangeSignal(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
var eg errgroup.Group
|
|
||||||
eg.Go(func() error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
ch := backend.onChange.Bind()
|
|
||||||
defer backend.onChange.Unbind(ch)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// signal the second backend that we've received the change
|
|
||||||
close(done)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
// put a new value to trigger a change
|
|
||||||
for {
|
|
||||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "ID",
|
|
||||||
}})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-done:
|
|
||||||
return nil
|
|
||||||
case <-time.After(time.Millisecond * 100):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
assert.NoError(t, eg.Wait(), "expected signal to be fired when another backend triggers a change")
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExpiry(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL, WithExpiry(0))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
_, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
stream, err := backend.Sync(ctx, "TYPE", serverVersion, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
var records []*databroker.Record
|
|
||||||
for stream.Next(false) {
|
|
||||||
records = append(records, stream.Record())
|
|
||||||
}
|
|
||||||
_ = stream.Close()
|
|
||||||
require.Len(t, records, 1000)
|
|
||||||
|
|
||||||
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
|
|
||||||
|
|
||||||
stream, err = backend.Sync(ctx, "TYPE", serverVersion, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
records = nil
|
|
||||||
for stream.Next(false) {
|
|
||||||
records = append(records, stream.Record())
|
|
||||||
}
|
|
||||||
_ = stream.Close()
|
|
||||||
require.Len(t, records, 0)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCapacity(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL, WithExpiry(0))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
err = backend.SetOptions(ctx, "EXAMPLE", &databroker.Options{
|
|
||||||
Capacity: proto.Uint64(3),
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "EXAMPLE",
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
}})
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
records, err := storage.RecordStreamToList(stream)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, records, 3)
|
|
||||||
|
|
||||||
var ids []string
|
|
||||||
for _, r := range records {
|
|
||||||
ids = append(ids, r.GetId())
|
|
||||||
}
|
|
||||||
assert.Equal(t, []string{"7", "8", "9"}, ids, "should contain recent records")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLease(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "a", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, ok, "expected a to acquire the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, ok, "expected b to fail to acquire the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "a", 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, ok, "expected a to clear the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, ok, "expected b to to acquire the lease")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newSyncRecordStream(
|
|
||||||
ctx context.Context,
|
|
||||||
backend *Backend,
|
|
||||||
recordType string,
|
|
||||||
serverVersion uint64,
|
|
||||||
recordVersion uint64,
|
|
||||||
) storage.RecordStream {
|
|
||||||
changed := backend.onChange.Bind()
|
|
||||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
||||||
// 1. stream all record changes
|
|
||||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
||||||
ticker := time.NewTicker(watchPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
currentServerVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if serverVersion != currentServerVersion {
|
|
||||||
return nil, storage.ErrInvalidServerVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion)
|
|
||||||
if err == nil {
|
|
||||||
return record, nil
|
|
||||||
} else if !errors.Is(err, storage.ErrStreamDone) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !block {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
case <-changed:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}, func() {
|
|
||||||
backend.onChange.Unbind(changed)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSyncLatestRecordStream(
|
|
||||||
ctx context.Context,
|
|
||||||
backend *Backend,
|
|
||||||
recordType string,
|
|
||||||
expr storage.FilterExpression,
|
|
||||||
) (storage.RecordStream, error) {
|
|
||||||
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if recordType != "" {
|
|
||||||
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
|
||||||
return record.GetType() == recordType
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var cursor uint64
|
|
||||||
scannedOnce := false
|
|
||||||
var scannedRecords []*databroker.Record
|
|
||||||
generator := storage.FilteredRecordStreamGenerator(
|
|
||||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
||||||
for {
|
|
||||||
if len(scannedRecords) > 0 {
|
|
||||||
record := scannedRecords[0]
|
|
||||||
scannedRecords = scannedRecords[1:]
|
|
||||||
return record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// the cursor is reset to 0 after iteration is complete
|
|
||||||
if scannedOnce && cursor == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
scannedRecords, err = nextScannedRecords(ctx, backend, &cursor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
scannedOnce = true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
filter,
|
|
||||||
)
|
|
||||||
|
|
||||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
||||||
generator,
|
|
||||||
}, nil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) {
|
|
||||||
var values []string
|
|
||||||
var err error
|
|
||||||
values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if len(values) == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
var records []*databroker.Record
|
|
||||||
for i := 1; i < len(values); i += 2 {
|
|
||||||
var record databroker.Record
|
|
||||||
err := proto.Unmarshal([]byte(values[i]), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
records = append(records, &record)
|
|
||||||
}
|
|
||||||
return records, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) {
|
|
||||||
for {
|
|
||||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
||||||
Min: fmt.Sprintf("(%d", *recordVersion),
|
|
||||||
Max: "+inf",
|
|
||||||
Offset: 0,
|
|
||||||
Count: 1,
|
|
||||||
})
|
|
||||||
results, err := cmd.Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if len(results) == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
result := results[0]
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(result), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
*recordVersion++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
*recordVersion = record.GetVersion()
|
|
||||||
if recordType != "" && record.GetType() != recordType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue