mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-21 21:17:13 +02:00
databroker: remove redis storage backend (#4699)
Remove the Redis databroker backend. According to https://www.pomerium.com/docs/internals/data-storage#redis it has been discouraged since Pomerium v0.18. Update the config options validation to return an error if "redis" is set as the databroker storage backend type.
This commit is contained in:
parent
47890e9ee1
commit
4f648e9ac1
17 changed files with 12 additions and 1964 deletions
|
@ -245,7 +245,7 @@ type Options struct {
|
||||||
DataBrokerURLStrings []string `mapstructure:"databroker_service_urls" yaml:"databroker_service_urls,omitempty"`
|
DataBrokerURLStrings []string `mapstructure:"databroker_service_urls" yaml:"databroker_service_urls,omitempty"`
|
||||||
DataBrokerInternalURLString string `mapstructure:"databroker_internal_service_url" yaml:"databroker_internal_service_url,omitempty"`
|
DataBrokerInternalURLString string `mapstructure:"databroker_internal_service_url" yaml:"databroker_internal_service_url,omitempty"`
|
||||||
// DataBrokerStorageType is the storage backend type that databroker will use.
|
// DataBrokerStorageType is the storage backend type that databroker will use.
|
||||||
// Supported type: memory, redis
|
// Supported type: memory, postgres
|
||||||
DataBrokerStorageType string `mapstructure:"databroker_storage_type" yaml:"databroker_storage_type,omitempty"`
|
DataBrokerStorageType string `mapstructure:"databroker_storage_type" yaml:"databroker_storage_type,omitempty"`
|
||||||
// DataBrokerStorageConnectionString is the data source name for storage backend.
|
// DataBrokerStorageConnectionString is the data source name for storage backend.
|
||||||
DataBrokerStorageConnectionString string `mapstructure:"databroker_storage_connection_string" yaml:"databroker_storage_connection_string,omitempty"`
|
DataBrokerStorageConnectionString string `mapstructure:"databroker_storage_connection_string" yaml:"databroker_storage_connection_string,omitempty"`
|
||||||
|
@ -584,7 +584,9 @@ func (o *Options) Validate() error {
|
||||||
|
|
||||||
switch o.DataBrokerStorageType {
|
switch o.DataBrokerStorageType {
|
||||||
case StorageInMemoryName:
|
case StorageInMemoryName:
|
||||||
case StorageRedisName, StoragePostgresName:
|
case StorageRedisName:
|
||||||
|
return errors.New("config: redis databroker storage backend is no longer supported")
|
||||||
|
case StoragePostgresName:
|
||||||
if o.DataBrokerStorageConnectionString == "" {
|
if o.DataBrokerStorageConnectionString == "" {
|
||||||
return errors.New("config: missing databroker storage backend dsn")
|
return errors.New("config: missing databroker storage backend dsn")
|
||||||
}
|
}
|
||||||
|
|
|
@ -58,8 +58,10 @@ func Test_Validate(t *testing.T) {
|
||||||
badPolicyFile.PolicyFile = "file"
|
badPolicyFile.PolicyFile = "file"
|
||||||
invalidStorageType := testOptions()
|
invalidStorageType := testOptions()
|
||||||
invalidStorageType.DataBrokerStorageType = "foo"
|
invalidStorageType.DataBrokerStorageType = "foo"
|
||||||
|
redisStorageType := testOptions()
|
||||||
|
redisStorageType.DataBrokerStorageType = "redis"
|
||||||
missingStorageDSN := testOptions()
|
missingStorageDSN := testOptions()
|
||||||
missingStorageDSN.DataBrokerStorageType = "redis"
|
missingStorageDSN.DataBrokerStorageType = "postgres"
|
||||||
badSignoutRedirectURL := testOptions()
|
badSignoutRedirectURL := testOptions()
|
||||||
badSignoutRedirectURL.SignOutRedirectURLString = "--"
|
badSignoutRedirectURL.SignOutRedirectURLString = "--"
|
||||||
badCookieSettings := testOptions()
|
badCookieSettings := testOptions()
|
||||||
|
@ -77,6 +79,7 @@ func Test_Validate(t *testing.T) {
|
||||||
{"missing shared secret but all service", badSecretAllServices, false},
|
{"missing shared secret but all service", badSecretAllServices, false},
|
||||||
{"policy file specified", badPolicyFile, true},
|
{"policy file specified", badPolicyFile, true},
|
||||||
{"invalid databroker storage type", invalidStorageType, true},
|
{"invalid databroker storage type", invalidStorageType, true},
|
||||||
|
{"redis databroker storage type", redisStorageType, true},
|
||||||
{"missing databroker storage dsn", missingStorageDSN, true},
|
{"missing databroker storage dsn", missingStorageDSN, true},
|
||||||
{"invalid signout redirect url", badSignoutRedirectURL, true},
|
{"invalid signout redirect url", badSignoutRedirectURL, true},
|
||||||
{"CookieSameSite none with CookieSecure fale", badCookieSettings, true},
|
{"CookieSameSite none with CookieSecure fale", badCookieSettings, true},
|
||||||
|
|
|
@ -9,7 +9,6 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
||||||
"github.com/pomerium/pomerium/internal/registry/redis"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
|
@ -110,16 +109,6 @@ func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interfac
|
||||||
case config.StorageInMemoryName:
|
case config.StorageInMemoryName:
|
||||||
log.Info(ctx).Msg("using in-memory registry")
|
log.Info(ctx).Msg("using in-memory registry")
|
||||||
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
||||||
case config.StorageRedisName:
|
|
||||||
log.Info(ctx).Msg("using redis registry")
|
|
||||||
r, err := redis.New(
|
|
||||||
srv.cfg.storageConnectionString,
|
|
||||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
|
|
||||||
}
|
|
||||||
return r, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
||||||
|
|
|
@ -3,7 +3,6 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -19,12 +18,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
"github.com/pomerium/pomerium/pkg/storage"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
||||||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Server implements the databroker service using an in memory database.
|
// Server implements the databroker service using an in memory database.
|
||||||
|
@ -426,39 +423,8 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
||||||
case config.StoragePostgresName:
|
case config.StoragePostgresName:
|
||||||
log.Info(ctx).Msg("using postgres store")
|
log.Info(ctx).Msg("using postgres store")
|
||||||
backend = postgres.New(srv.cfg.storageConnectionString)
|
backend = postgres.New(srv.cfg.storageConnectionString)
|
||||||
case config.StorageRedisName:
|
|
||||||
log.Info(ctx).Msg("using redis store")
|
|
||||||
backend, err = redis.New(
|
|
||||||
srv.cfg.storageConnectionString,
|
|
||||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
|
||||||
}
|
|
||||||
if srv.cfg.secret != nil {
|
|
||||||
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||||
}
|
}
|
||||||
return backend, nil
|
return backend, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config {
|
|
||||||
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
|
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
|
||||||
RootCAs: caCertPool,
|
|
||||||
//nolint: gosec
|
|
||||||
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
|
||||||
}
|
|
||||||
if srv.cfg.storageCertificate != nil {
|
|
||||||
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
|
||||||
}
|
|
||||||
return tlsConfig
|
|
||||||
}
|
|
||||||
|
|
|
@ -22,7 +22,6 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
"google.golang.org/protobuf/types/known/timestamppb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
"github.com/pomerium/pomerium/internal/testutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
|
@ -287,12 +286,11 @@ func TestServerInvalidStorage(t *testing.T) {
|
||||||
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
|
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestServerRedis(t *testing.T) {
|
func TestServerPostgres(t *testing.T) {
|
||||||
testutil.WithTestRedis(false, func(rawURL string) error {
|
testutil.WithTestPostgres(func(dsn string) error {
|
||||||
srv := newServer(&serverConfig{
|
srv := newServer(&serverConfig{
|
||||||
storageType: "redis",
|
storageType: "postgres",
|
||||||
storageConnectionString: rawURL,
|
storageConnectionString: dsn,
|
||||||
secret: cryptutil.NewKey(),
|
|
||||||
})
|
})
|
||||||
|
|
||||||
s := new(session.Session)
|
s := new(session.Session)
|
||||||
|
|
|
@ -1,20 +0,0 @@
|
||||||
// Package lua contains lua source code.
|
|
||||||
package lua
|
|
||||||
|
|
||||||
import (
|
|
||||||
"embed"
|
|
||||||
)
|
|
||||||
|
|
||||||
//go:embed registry.lua
|
|
||||||
var fs embed.FS
|
|
||||||
|
|
||||||
// Registry is the registry lua script
|
|
||||||
var Registry string
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
bs, err := fs.ReadFile("registry.lua")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
Registry = string(bs)
|
|
||||||
}
|
|
|
@ -1,28 +0,0 @@
|
||||||
-- ARGV = [current time in seconds, ttl in seconds, services ...]
|
|
||||||
local current_time = ARGV[1]
|
|
||||||
local ttl = ARGV[2]
|
|
||||||
local changed = false
|
|
||||||
|
|
||||||
-- update the service list
|
|
||||||
for i = 3, #ARGV, 1 do
|
|
||||||
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
|
|
||||||
changed = true
|
|
||||||
end
|
|
||||||
|
|
||||||
-- retrieve all the services, removing any that have expired
|
|
||||||
local svcs = {}
|
|
||||||
local kvs = redis.call('HGETALL', KEYS[1])
|
|
||||||
for i = 1, #kvs, 2 do
|
|
||||||
if kvs[i + 1] < current_time then
|
|
||||||
redis.call('HDEL', KEYS[1], kvs[i])
|
|
||||||
changed = true
|
|
||||||
else
|
|
||||||
table.insert(svcs, kvs[i])
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
if changed then
|
|
||||||
redis.call('PUBLISH', KEYS[2], current_time)
|
|
||||||
end
|
|
||||||
|
|
||||||
return svcs
|
|
|
@ -1,48 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
const defaultTTL = time.Second * 30
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
tls *tls.Config
|
|
||||||
ttl time.Duration
|
|
||||||
getNow func() time.Time
|
|
||||||
}
|
|
||||||
|
|
||||||
// An Option modifies the config..
|
|
||||||
type Option func(*config)
|
|
||||||
|
|
||||||
// WithGetNow sets the time.Now function in the config.
|
|
||||||
func WithGetNow(getNow func() time.Time) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.getNow = getNow
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTLSConfig sets the tls.Config in the config.
|
|
||||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.tls = tlsConfig
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithTTL sets the ttl in the config.
|
|
||||||
func WithTTL(ttl time.Duration) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.ttl = ttl
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
|
||||||
cfg := new(config)
|
|
||||||
WithGetNow(time.Now)(cfg)
|
|
||||||
WithTTL(defaultTTL)(cfg)
|
|
||||||
for _, o := range options {
|
|
||||||
o(cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
|
@ -1,250 +0,0 @@
|
||||||
// Package redis implements a registry in redis.
|
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/redisutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
|
||||||
"github.com/pomerium/pomerium/internal/registry/redis/lua"
|
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
registryKey = redisutil.KeyPrefix + "registry"
|
|
||||||
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
|
|
||||||
|
|
||||||
pollInterval = time.Second * 30
|
|
||||||
)
|
|
||||||
|
|
||||||
type impl struct {
|
|
||||||
cfg *config
|
|
||||||
|
|
||||||
client redis.UniversalClient
|
|
||||||
onChange *signal.Signal
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new registry implementation backend by redis.
|
|
||||||
func New(rawURL string, options ...Option) (registry.Interface, error) {
|
|
||||||
cfg := getConfig(options...)
|
|
||||||
|
|
||||||
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
i := &impl{
|
|
||||||
cfg: cfg,
|
|
||||||
client: client,
|
|
||||||
onChange: signal.New(),
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
}
|
|
||||||
go i.listenForChanges(context.Background())
|
|
||||||
return i, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
|
||||||
_, err := i.runReport(ctx, req.GetServices())
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return ®istrypb.RegisterResponse{
|
|
||||||
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
|
||||||
all, err := i.runReport(ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
include := map[registrypb.ServiceKind]struct{}{}
|
|
||||||
for _, kind := range req.GetKinds() {
|
|
||||||
include[kind] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
filtered := make([]*registrypb.Service, 0, len(all))
|
|
||||||
for _, svc := range all {
|
|
||||||
if _, ok := include[svc.GetKind()]; !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
filtered = append(filtered, svc)
|
|
||||||
}
|
|
||||||
|
|
||||||
sort.Slice(filtered, func(i, j int) bool {
|
|
||||||
{
|
|
||||||
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
|
|
||||||
switch {
|
|
||||||
case iv < jv:
|
|
||||||
return true
|
|
||||||
case jv < iv:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
|
|
||||||
switch {
|
|
||||||
case iv < jv:
|
|
||||||
return true
|
|
||||||
case jv < iv:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
})
|
|
||||||
|
|
||||||
return ®istrypb.ServiceList{
|
|
||||||
Services: filtered,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
|
||||||
// listen for changes
|
|
||||||
ch := i.onChange.Bind()
|
|
||||||
defer i.onChange.Unbind(ch)
|
|
||||||
|
|
||||||
// force a check periodically
|
|
||||||
poll := time.NewTicker(pollInterval)
|
|
||||||
defer poll.Stop()
|
|
||||||
|
|
||||||
var prev *registrypb.ServiceList
|
|
||||||
for {
|
|
||||||
// retrieve the most recent list of services
|
|
||||||
lst, err := i.List(stream.Context(), req)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// only send a new list if something changed
|
|
||||||
if !proto.Equal(prev, lst) {
|
|
||||||
err = stream.Send(lst)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
prev = lst
|
|
||||||
|
|
||||||
// wait for an update
|
|
||||||
select {
|
|
||||||
case <-i.closed:
|
|
||||||
return nil
|
|
||||||
case <-stream.Context().Done():
|
|
||||||
return stream.Context().Err()
|
|
||||||
case <-ch:
|
|
||||||
case <-poll.C:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) Close() error {
|
|
||||||
var err error
|
|
||||||
i.closeOnce.Do(func() {
|
|
||||||
err = i.client.Close()
|
|
||||||
close(i.closed)
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) listenForChanges(ctx context.Context) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
go func() {
|
|
||||||
<-i.closed
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
|
|
||||||
outer:
|
|
||||||
for {
|
|
||||||
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
|
|
||||||
for {
|
|
||||||
msg, err := pubsub.Receive(ctx)
|
|
||||||
if err != nil {
|
|
||||||
_ = pubsub.Close()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue outer
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
switch msg.(type) {
|
|
||||||
case *redis.Message:
|
|
||||||
i.onChange.Broadcast(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
|
|
||||||
args := []interface{}{
|
|
||||||
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
|
|
||||||
i.cfg.ttl.Milliseconds(), // ttl
|
|
||||||
}
|
|
||||||
for _, svc := range updates {
|
|
||||||
args = append(args, i.getRegistryHashKey(svc))
|
|
||||||
}
|
|
||||||
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if values, ok := res.([]interface{}); ok {
|
|
||||||
var all []*registrypb.Service
|
|
||||||
for _, value := range values {
|
|
||||||
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid service")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
all = append(all, svc)
|
|
||||||
}
|
|
||||||
return all, nil
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
|
|
||||||
idx := strings.Index(key, "|")
|
|
||||||
if idx == -1 {
|
|
||||||
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
|
|
||||||
}
|
|
||||||
|
|
||||||
svcKindStr := key[:idx]
|
|
||||||
svcEndpointStr := key[idx+1:]
|
|
||||||
|
|
||||||
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
|
|
||||||
}
|
|
||||||
|
|
||||||
svc := ®istrypb.Service{
|
|
||||||
Kind: registrypb.ServiceKind(svcKind),
|
|
||||||
Endpoint: svcEndpointStr,
|
|
||||||
}
|
|
||||||
return svc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
|
|
||||||
return svc.GetKind().String() + "|" + svc.GetEndpoint()
|
|
||||||
}
|
|
|
@ -1,196 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"google.golang.org/grpc"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
|
||||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReport(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
tm := time.Now()
|
|
||||||
|
|
||||||
i, err := New(rawURL,
|
|
||||||
WithGetNow(func() time.Time {
|
|
||||||
return tm
|
|
||||||
}),
|
|
||||||
WithTTL(time.Second*10))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = i.Close() }()
|
|
||||||
|
|
||||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// move forward 5 seconds
|
|
||||||
tm = tm.Add(time.Second * 5)
|
|
||||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
lst, err := i.List(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
registrypb.ServiceKind_PROXY,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
}, lst.GetServices(), "should list selected services")
|
|
||||||
|
|
||||||
// move forward 6 seconds
|
|
||||||
tm = tm.Add(time.Second * 6)
|
|
||||||
lst, err = i.List(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
registrypb.ServiceKind_PROXY,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
|
||||||
}, lst.GetServices(), "should expire old services")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWatch(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
tm := time.Now()
|
|
||||||
i, err := New(rawURL,
|
|
||||||
WithGetNow(func() time.Time {
|
|
||||||
return tm
|
|
||||||
}),
|
|
||||||
WithTTL(time.Second*10))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer li.Close()
|
|
||||||
|
|
||||||
srv := grpc.NewServer()
|
|
||||||
registrypb.RegisterRegistryServer(srv, i)
|
|
||||||
eg, ctx := errgroup.WithContext(ctx)
|
|
||||||
eg.Go(func() error {
|
|
||||||
<-ctx.Done()
|
|
||||||
srv.Stop()
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
return srv.Serve(li)
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
client := registrypb.NewRegistryClient(cc)
|
|
||||||
|
|
||||||
// store the initial services
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err := client.Watch(ctx, ®istrypb.ListRequest{
|
|
||||||
Kinds: []registrypb.ServiceKind{
|
|
||||||
registrypb.ServiceKind_AUTHORIZE,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = stream.CloseSend() }()
|
|
||||||
|
|
||||||
lst, err := stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
}, lst.GetServices())
|
|
||||||
|
|
||||||
// update authenticate
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// add an authorize
|
|
||||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
|
||||||
Services: []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
lst, err = stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
assert.Equal(t, []*registrypb.Service{
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
|
||||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
|
||||||
}, lst.GetServices())
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
require.NoError(t, eg.Wait())
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
|
@ -1,204 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"crypto/cipher"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
type encryptedRecordStream struct {
|
|
||||||
underlying RecordStream
|
|
||||||
backend *encryptedBackend
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedRecordStream) Close() error {
|
|
||||||
return e.underlying.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedRecordStream) Next(wait bool) bool {
|
|
||||||
return e.underlying.Next(wait)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedRecordStream) Record() *databroker.Record {
|
|
||||||
r := e.underlying.Record()
|
|
||||||
if r != nil {
|
|
||||||
var err error
|
|
||||||
r, err = e.backend.decryptRecord(r)
|
|
||||||
if err != nil {
|
|
||||||
e.err = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedRecordStream) Err() error {
|
|
||||||
if e.err == nil {
|
|
||||||
e.err = e.underlying.Err()
|
|
||||||
}
|
|
||||||
return e.err
|
|
||||||
}
|
|
||||||
|
|
||||||
type encryptedBackend struct {
|
|
||||||
underlying Backend
|
|
||||||
cipher cipher.AEAD
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewEncryptedBackend creates a new encrypted backend.
|
|
||||||
func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
|
|
||||||
c, err := cryptutil.NewAEADCipher(secret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &encryptedBackend{
|
|
||||||
underlying: underlying,
|
|
||||||
cipher: c,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) Close() error {
|
|
||||||
return e.underlying.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
|
||||||
record, err := e.underlying.Get(ctx, recordType, id)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
record, err = e.decryptRecord(record)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
|
||||||
return e.underlying.GetOptions(ctx, recordType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
|
||||||
return e.underlying.Lease(ctx, leaseName, leaseID, ttl)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) ListTypes(ctx context.Context) ([]string, error) {
|
|
||||||
return e.underlying.ListTypes(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) Put(ctx context.Context, records []*databroker.Record) (uint64, error) {
|
|
||||||
encryptedRecords := make([]*databroker.Record, len(records))
|
|
||||||
for i, record := range records {
|
|
||||||
encrypted, err := e.encrypt(record.GetData())
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
newRecord := proto.Clone(record).(*databroker.Record)
|
|
||||||
newRecord.Data = encrypted
|
|
||||||
encryptedRecords[i] = newRecord
|
|
||||||
}
|
|
||||||
|
|
||||||
serverVersion, err := e.underlying.Put(ctx, encryptedRecords)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, record := range records {
|
|
||||||
record.ModifiedAt = encryptedRecords[i].ModifiedAt
|
|
||||||
record.Version = encryptedRecords[i].Version
|
|
||||||
}
|
|
||||||
|
|
||||||
return serverVersion, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
|
||||||
return e.underlying.SetOptions(ctx, recordType, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (RecordStream, error) {
|
|
||||||
stream, err := e.underlying.Sync(ctx, recordType, serverVersion, recordVersion)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &encryptedRecordStream{
|
|
||||||
underlying: stream,
|
|
||||||
backend: e,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) SyncLatest(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
filter FilterExpression,
|
|
||||||
) (serverVersion, recordVersion uint64, stream RecordStream, err error) {
|
|
||||||
serverVersion, recordVersion, stream, err = e.underlying.SyncLatest(ctx, recordType, filter)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, recordVersion, nil, err
|
|
||||||
}
|
|
||||||
return serverVersion, recordVersion, &encryptedRecordStream{
|
|
||||||
underlying: stream,
|
|
||||||
backend: e,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
|
|
||||||
data, err := e.decrypt(in.Data)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
// Create a new record so that we don't re-use any internal state
|
|
||||||
return &databroker.Record{
|
|
||||||
Version: in.Version,
|
|
||||||
Type: in.Type,
|
|
||||||
Id: in.Id,
|
|
||||||
Data: data,
|
|
||||||
ModifiedAt: in.ModifiedAt,
|
|
||||||
DeletedAt: in.DeletedAt,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
|
||||||
if in == nil {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var encrypted wrapperspb.BytesValue
|
|
||||||
err = in.UnmarshalTo(&encrypted)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
out = new(anypb.Any)
|
|
||||||
err = proto.Unmarshal(plaintext, out)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
|
||||||
plaintext, err := proto.Marshal(in)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil)
|
|
||||||
out = protoutil.NewAny(&wrapperspb.BytesValue{
|
|
||||||
Value: encrypted,
|
|
||||||
})
|
|
||||||
return out, nil
|
|
||||||
}
|
|
|
@ -1,75 +0,0 @@
|
||||||
package storage
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestEncryptedBackend(t *testing.T) {
|
|
||||||
ctx := context.Background()
|
|
||||||
|
|
||||||
m := map[string]*anypb.Any{}
|
|
||||||
backend := &mockBackend{
|
|
||||||
put: func(ctx context.Context, records []*databroker.Record) (uint64, error) {
|
|
||||||
for _, record := range records {
|
|
||||||
record.ModifiedAt = timestamppb.Now()
|
|
||||||
record.Version++
|
|
||||||
m[record.GetId()] = record.GetData()
|
|
||||||
}
|
|
||||||
return 0, nil
|
|
||||||
},
|
|
||||||
get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
|
||||||
data, ok := m[id]
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.New("not found")
|
|
||||||
}
|
|
||||||
return &databroker.Record{
|
|
||||||
Id: id,
|
|
||||||
Data: data,
|
|
||||||
Version: 1,
|
|
||||||
ModifiedAt: timestamppb.Now(),
|
|
||||||
}, nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
e, err := NewEncryptedBackend(cryptutil.NewKey(), backend)
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
data := protoutil.NewAny(wrapperspb.String("HELLO WORLD"))
|
|
||||||
|
|
||||||
rec := &databroker.Record{
|
|
||||||
Type: "",
|
|
||||||
Id: "TEST-1",
|
|
||||||
Data: data,
|
|
||||||
}
|
|
||||||
_, err = e.Put(ctx, []*databroker.Record{rec})
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if assert.NotNil(t, m["TEST-1"], "key should be set") {
|
|
||||||
assert.NotEqual(t, data.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type")
|
|
||||||
assert.NotEqual(t, data.Value, m["TEST-1"].Value, "value should be encrypted")
|
|
||||||
assert.NotNil(t, rec.ModifiedAt)
|
|
||||||
assert.NotZero(t, rec.Version)
|
|
||||||
}
|
|
||||||
|
|
||||||
record, err := e.Get(ctx, "", "TEST-1")
|
|
||||||
if !assert.NoError(t, err) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
assert.Equal(t, data.TypeUrl, record.Data.TypeUrl, "type should be preserved")
|
|
||||||
assert.Equal(t, data.Value, record.Data.Value, "value should be preserved")
|
|
||||||
assert.NotEqual(t, data.TypeUrl, record.Type, "record type should be preserved")
|
|
||||||
}
|
|
|
@ -1,31 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
|
|
||||||
pomeriumconfig "github.com/pomerium/pomerium/config"
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
||||||
)
|
|
||||||
|
|
||||||
type logger struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
|
||||||
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
redis.SetLogger(logger{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
|
||||||
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
|
||||||
Operation: operation,
|
|
||||||
Error: err,
|
|
||||||
Backend: pomeriumconfig.StorageRedisName,
|
|
||||||
}, time.Since(startTime))
|
|
||||||
}
|
|
|
@ -1,37 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/tls"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type config struct {
|
|
||||||
tls *tls.Config
|
|
||||||
expiry time.Duration
|
|
||||||
}
|
|
||||||
|
|
||||||
// Option customizes a Backend.
|
|
||||||
type Option func(*config)
|
|
||||||
|
|
||||||
// WithTLSConfig sets the tls.Config which Backend uses.
|
|
||||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.tls = tlsConfig
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// WithExpiry sets the expiry for changes.
|
|
||||||
func WithExpiry(expiry time.Duration) Option {
|
|
||||||
return func(cfg *config) {
|
|
||||||
cfg.expiry = expiry
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getConfig(options ...Option) *config {
|
|
||||||
cfg := new(config)
|
|
||||||
WithExpiry(time.Hour * 24)(cfg)
|
|
||||||
for _, o := range options {
|
|
||||||
o(cfg)
|
|
||||||
}
|
|
||||||
return cfg
|
|
||||||
}
|
|
|
@ -1,537 +0,0 @@
|
||||||
// Package redis implements the storage.Backend interface for redis.
|
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"sort"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/internal/redisutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/signal"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxTransactionRetries = 100
|
|
||||||
watchPollInterval = 30 * time.Second
|
|
||||||
|
|
||||||
// we rely on transactions in redis, so all redis-cluster keys need to be
|
|
||||||
// on the same node. Using a `hash tag` gives us this capability.
|
|
||||||
serverVersionKey = redisutil.KeyPrefix + "server_version"
|
|
||||||
lastVersionKey = redisutil.KeyPrefix + "last_version"
|
|
||||||
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
|
|
||||||
recordHashKey = redisutil.KeyPrefix + "records"
|
|
||||||
recordTypesSetKey = redisutil.KeyPrefix + "record_types"
|
|
||||||
changesSetKey = redisutil.KeyPrefix + "changes"
|
|
||||||
optionsKey = redisutil.KeyPrefix + "options"
|
|
||||||
|
|
||||||
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
|
|
||||||
leaseKeyTpl = redisutil.KeyPrefix + "lease.%s"
|
|
||||||
)
|
|
||||||
|
|
||||||
// custom errors
|
|
||||||
var (
|
|
||||||
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
|
||||||
)
|
|
||||||
|
|
||||||
// Backend implements the storage.Backend on top of redis.
|
|
||||||
//
|
|
||||||
// What's stored:
|
|
||||||
//
|
|
||||||
// - last_version: an integer recordVersion number
|
|
||||||
// - last_version_ch: a PubSub channel for recordVersion number updates
|
|
||||||
// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record.
|
|
||||||
// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record.
|
|
||||||
// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options.
|
|
||||||
// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time,
|
|
||||||
// the value the record id.
|
|
||||||
//
|
|
||||||
// Records stored in these keys are typically encrypted.
|
|
||||||
type Backend struct {
|
|
||||||
cfg *config
|
|
||||||
|
|
||||||
client redis.UniversalClient
|
|
||||||
onChange *signal.Signal
|
|
||||||
|
|
||||||
closeOnce sync.Once
|
|
||||||
closed chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new redis storage backend.
|
|
||||||
func New(rawURL string, options ...Option) (*Backend, error) {
|
|
||||||
ctx := context.TODO()
|
|
||||||
cfg := getConfig(options...)
|
|
||||||
backend := &Backend{
|
|
||||||
cfg: cfg,
|
|
||||||
closed: make(chan struct{}),
|
|
||||||
onChange: signal.New(),
|
|
||||||
}
|
|
||||||
var err error
|
|
||||||
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
metrics.AddRedisMetrics(backend.client.PoolStats)
|
|
||||||
go backend.listenForVersionChanges(ctx)
|
|
||||||
if cfg.expiry != 0 {
|
|
||||||
go func() {
|
|
||||||
ticker := time.NewTicker(time.Minute)
|
|
||||||
defer ticker.Stop()
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-backend.closed:
|
|
||||||
return
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
|
|
||||||
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
return backend, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Close closes the underlying redis connection and any watchers.
|
|
||||||
func (backend *Backend) Close() error {
|
|
||||||
var err error
|
|
||||||
backend.closeOnce.Do(func() {
|
|
||||||
err = backend.client.Close()
|
|
||||||
close(backend.closed)
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get gets a record from redis.
|
|
||||||
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
|
|
||||||
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
|
||||||
|
|
||||||
key, field := getHashKey(recordType, id)
|
|
||||||
cmd := backend.client.HGet(ctx, key, field)
|
|
||||||
raw, err := cmd.Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrNotFound
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(raw), &record)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetOptions gets the options for the given record type.
|
|
||||||
func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
|
||||||
raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// treat no options as an empty set of options
|
|
||||||
return new(databroker.Options), nil
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var options databroker.Options
|
|
||||||
err = proto.Unmarshal([]byte(raw), &options)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return &options, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lease acquires or renews a lease.
|
|
||||||
func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
|
||||||
acquired := false
|
|
||||||
key := getLeaseKey(leaseName)
|
|
||||||
err := backend.client.Watch(ctx, func(tx *redis.Tx) error {
|
|
||||||
currentID, err := tx.Get(ctx, key).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// lease hasn't been set yet
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
} else if leaseID != currentID {
|
|
||||||
// lease has already been taken
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
_, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error {
|
|
||||||
if ttl <= 0 {
|
|
||||||
p.Del(ctx, key)
|
|
||||||
} else {
|
|
||||||
p.Set(ctx, key, leaseID, ttl)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
acquired = ttl > 0
|
|
||||||
return nil
|
|
||||||
}, key)
|
|
||||||
// if the transaction failed someone else must've acquired the lease
|
|
||||||
if errors.Is(err, redis.TxFailedErr) {
|
|
||||||
acquired = false
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
return acquired, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ListTypes lists all the known record types.
|
|
||||||
func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now())
|
|
||||||
|
|
||||||
cmd := backend.client.SMembers(ctx, recordTypesSetKey)
|
|
||||||
types, err = cmd.Result()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
sort.Strings(types)
|
|
||||||
return types, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put puts a record into redis.
|
|
||||||
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
|
||||||
defer span.End()
|
|
||||||
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
|
||||||
|
|
||||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
err = backend.put(ctx, records)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
recordTypes := map[string]struct{}{}
|
|
||||||
for _, record := range records {
|
|
||||||
recordTypes[record.GetType()] = struct{}{}
|
|
||||||
}
|
|
||||||
for recordType := range recordTypes {
|
|
||||||
err = backend.enforceOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return serverVersion, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetOptions sets the options for the given record type.
|
|
||||||
func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
bs, err := proto.Marshal(options)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// update the options in the hash set
|
|
||||||
err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// possibly re-enforce options
|
|
||||||
err = backend.enforceOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Sync returns a record stream of any records changed after the specified recordVersion.
|
|
||||||
func (backend *Backend) Sync(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
serverVersion, recordVersion uint64,
|
|
||||||
) (storage.RecordStream, error) {
|
|
||||||
return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the
|
|
||||||
// stream is streaming.
|
|
||||||
func (backend *Backend) SyncLatest(
|
|
||||||
ctx context.Context,
|
|
||||||
recordType string,
|
|
||||||
expr storage.FilterExpression,
|
|
||||||
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
|
||||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return serverVersion, recordVersion, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
// this happens if there are no records
|
|
||||||
} else if err != nil {
|
|
||||||
return serverVersion, recordVersion, nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
|
|
||||||
return serverVersion, recordVersion, stream, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error {
|
|
||||||
return backend.incrementVersion(ctx,
|
|
||||||
func(tx *redis.Tx, version uint64) error {
|
|
||||||
for i, record := range records {
|
|
||||||
record.ModifiedAt = timestamppb.Now()
|
|
||||||
record.Version = version + uint64(i)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
func(p redis.Pipeliner, version uint64) error {
|
|
||||||
for i, record := range records {
|
|
||||||
bs, err := proto.Marshal(record)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
key, field := getHashKey(record.GetType(), record.GetId())
|
|
||||||
if record.DeletedAt != nil {
|
|
||||||
p.HDel(ctx, key, field)
|
|
||||||
} else {
|
|
||||||
p.HSet(ctx, key, field, bs)
|
|
||||||
p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{
|
|
||||||
Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)),
|
|
||||||
Member: record.GetId(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
p.ZAdd(ctx, changesSetKey, &redis.Z{
|
|
||||||
Score: float64(version) + float64(i),
|
|
||||||
Member: bs,
|
|
||||||
})
|
|
||||||
p.SAdd(ctx, recordTypesSetKey, record.GetType())
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// enforceOptions enforces the options for the given record type.
|
|
||||||
func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error {
|
|
||||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions")
|
|
||||||
defer span.End()
|
|
||||||
|
|
||||||
options, err := backend.GetOptions(ctx, recordType)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// nothing to do if capacity isn't set
|
|
||||||
if options.Capacity == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
key := getRecordTypeChangesKey(recordType)
|
|
||||||
|
|
||||||
// find oldest records that exceed the capacity
|
|
||||||
recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// for each record, delete it
|
|
||||||
for _, recordID := range recordIDs {
|
|
||||||
record, err := backend.Get(ctx, recordType, recordID)
|
|
||||||
if err == nil {
|
|
||||||
// mark the record as deleted and re-submit
|
|
||||||
record.DeletedAt = timestamppb.Now()
|
|
||||||
err = backend.put(ctx, []*databroker.Record{record})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
} else if errors.Is(err, storage.ErrNotFound) {
|
|
||||||
// ignore
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the member from the collection
|
|
||||||
_, err = backend.client.ZRem(ctx, key, recordID).Result()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in
|
|
||||||
// `commit`. If the last recordVersion changes in the interim, we will retry the transaction.
|
|
||||||
func (backend *Backend) incrementVersion(ctx context.Context,
|
|
||||||
query func(tx *redis.Tx, recordVersion uint64) error,
|
|
||||||
commit func(p redis.Pipeliner, recordVersion uint64) error,
|
|
||||||
) error {
|
|
||||||
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
|
||||||
txf := func(tx *redis.Tx) error {
|
|
||||||
version, err := tx.Get(ctx, lastVersionKey).Uint64()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
version = 0
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
version++
|
|
||||||
|
|
||||||
err = query(tx, version)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
|
|
||||||
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
|
|
||||||
err := commit(p, version)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
p.Set(ctx, lastVersionKey, version, 0)
|
|
||||||
p.Publish(ctx, lastVersionChKey, version)
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
for i := 0; i < maxTransactionRetries; i++ {
|
|
||||||
err := backend.client.Watch(ctx, txf, lastVersionKey)
|
|
||||||
if errors.Is(err, redis.TxFailedErr) {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue // retry
|
|
||||||
} else if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil // tx was successful
|
|
||||||
}
|
|
||||||
|
|
||||||
return ErrExceededMaxRetries
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
|
||||||
go func() {
|
|
||||||
<-backend.closed
|
|
||||||
cancel()
|
|
||||||
}()
|
|
||||||
|
|
||||||
bo := backoff.NewExponentialBackOff()
|
|
||||||
bo.MaxElapsedTime = 0
|
|
||||||
|
|
||||||
outer:
|
|
||||||
for {
|
|
||||||
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
|
|
||||||
for {
|
|
||||||
msg, err := pubsub.Receive(ctx)
|
|
||||||
if err != nil {
|
|
||||||
_ = pubsub.Close()
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return
|
|
||||||
case <-time.After(bo.NextBackOff()):
|
|
||||||
}
|
|
||||||
continue outer
|
|
||||||
}
|
|
||||||
bo.Reset()
|
|
||||||
|
|
||||||
switch msg.(type) {
|
|
||||||
case *redis.Message:
|
|
||||||
backend.onChange.Broadcast(ctx)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
|
|
||||||
for {
|
|
||||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
||||||
Min: "-inf",
|
|
||||||
Max: "+inf",
|
|
||||||
Offset: 0,
|
|
||||||
Count: 1,
|
|
||||||
})
|
|
||||||
results, err := cmd.Result()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// nothing left to do
|
|
||||||
if len(results) == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(results[0]), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the record's modified timestamp is after the cutoff, we're all done, so break
|
|
||||||
if record.GetModifiedAt().AsTime().After(cutoff) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
|
|
||||||
// remove the record
|
|
||||||
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
|
||||||
if err != nil {
|
|
||||||
log.Error(ctx).Err(err).Msg("redis: error removing member")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) {
|
|
||||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
||||||
// if the server version hasn't been set yet, set it to a random value and immediately retrieve it
|
|
||||||
// this should properly handle a data race by only setting the key if it doesn't already exist
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
_, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result()
|
|
||||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return 0, fmt.Errorf("redis: error retrieving server version: %w", err)
|
|
||||||
}
|
|
||||||
return serverVersion, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func getLeaseKey(leaseName string) string {
|
|
||||||
return fmt.Sprintf(leaseKeyTpl, leaseName)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getRecordTypeChangesKey(recordType string) string {
|
|
||||||
return fmt.Sprintf(recordTypeChangesKeyTpl, recordType)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getHashKey(recordType, id string) (key, field string) {
|
|
||||||
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
|
|
||||||
}
|
|
|
@ -1,312 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"golang.org/x/sync/errgroup"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/testutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestBackend(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
handler := func(t *testing.T, useTLS bool, rawURL string) error {
|
|
||||||
ctx := context.Background()
|
|
||||||
var opts []Option
|
|
||||||
if useTLS {
|
|
||||||
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
|
|
||||||
}
|
|
||||||
backend, err := New(rawURL, opts...)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
t.Run("get missing record", func(t *testing.T) {
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
require.Error(t, err)
|
|
||||||
assert.Nil(t, record)
|
|
||||||
})
|
|
||||||
t.Run("get record", func(t *testing.T) {
|
|
||||||
data := new(anypb.Any)
|
|
||||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "abcd",
|
|
||||||
Data: data,
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, serverVersion, sv)
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
require.NoError(t, err)
|
|
||||||
if assert.NotNil(t, record) {
|
|
||||||
assert.Equal(t, data, record.Data)
|
|
||||||
assert.Nil(t, record.DeletedAt)
|
|
||||||
assert.Equal(t, "abcd", record.Id)
|
|
||||||
assert.NotNil(t, record.ModifiedAt)
|
|
||||||
assert.Equal(t, "TYPE", record.Type)
|
|
||||||
assert.Equal(t, uint64(1), record.Version)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
t.Run("delete record", func(t *testing.T) {
|
|
||||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "abcd",
|
|
||||||
DeletedAt: timestamppb.Now(),
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, serverVersion, sv)
|
|
||||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Nil(t, record)
|
|
||||||
})
|
|
||||||
t.Run("list types", func(t *testing.T) {
|
|
||||||
types, err := backend.ListTypes(ctx)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []string{"TYPE"}, types)
|
|
||||||
})
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("no-tls", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("tls", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(true, func(rawURL string) error {
|
|
||||||
return handler(t, true, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
if runtime.GOOS == "linux" {
|
|
||||||
t.Run("cluster", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedisCluster(func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("sentinel", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
require.NoError(t, testutil.WithTestRedisSentinel(func(rawURL string) error {
|
|
||||||
return handler(t, false, rawURL)
|
|
||||||
}))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestChangeSignal(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
done := make(chan struct{})
|
|
||||||
var eg errgroup.Group
|
|
||||||
eg.Go(func() error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
ch := backend.onChange.Bind()
|
|
||||||
defer backend.onChange.Unbind(ch)
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ch:
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// signal the second backend that we've received the change
|
|
||||||
close(done)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
})
|
|
||||||
eg.Go(func() error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
// put a new value to trigger a change
|
|
||||||
for {
|
|
||||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: "ID",
|
|
||||||
}})
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return ctx.Err()
|
|
||||||
case <-done:
|
|
||||||
return nil
|
|
||||||
case <-time.After(time.Millisecond * 100):
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
assert.NoError(t, eg.Wait(), "expected signal to be fired when another backend triggers a change")
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestExpiry(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL, WithExpiry(0))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
_, err := backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "TYPE",
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
}})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
stream, err := backend.Sync(ctx, "TYPE", serverVersion, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
var records []*databroker.Record
|
|
||||||
for stream.Next(false) {
|
|
||||||
records = append(records, stream.Record())
|
|
||||||
}
|
|
||||||
_ = stream.Close()
|
|
||||||
require.Len(t, records, 1000)
|
|
||||||
|
|
||||||
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
|
|
||||||
|
|
||||||
stream, err = backend.Sync(ctx, "TYPE", serverVersion, 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
records = nil
|
|
||||||
for stream.Next(false) {
|
|
||||||
records = append(records, stream.Record())
|
|
||||||
}
|
|
||||||
_ = stream.Close()
|
|
||||||
require.Len(t, records, 0)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestCapacity(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
|
||||||
defer clearTimeout()
|
|
||||||
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL, WithExpiry(0))
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
err = backend.SetOptions(ctx, "EXAMPLE", &databroker.Options{
|
|
||||||
Capacity: proto.Uint64(3),
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
|
||||||
Type: "EXAMPLE",
|
|
||||||
Id: fmt.Sprint(i),
|
|
||||||
}})
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer stream.Close()
|
|
||||||
|
|
||||||
records, err := storage.RecordStreamToList(stream)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Len(t, records, 3)
|
|
||||||
|
|
||||||
var ids []string
|
|
||||||
for _, r := range records {
|
|
||||||
ids = append(ids, r.GetId())
|
|
||||||
}
|
|
||||||
assert.Equal(t, []string{"7", "8", "9"}, ids, "should contain recent records")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestLease(t *testing.T) {
|
|
||||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
|
||||||
t.Skip("Github action can not run docker on MacOS")
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
ctx := context.Background()
|
|
||||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
|
||||||
backend, err := New(rawURL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { _ = backend.Close() }()
|
|
||||||
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "a", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, ok, "expected a to acquire the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, ok, "expected b to fail to acquire the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "a", 0)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.False(t, ok, "expected a to clear the lease")
|
|
||||||
}
|
|
||||||
{
|
|
||||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.True(t, ok, "expected b to to acquire the lease")
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}))
|
|
||||||
}
|
|
|
@ -1,172 +0,0 @@
|
||||||
package redis
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/go-redis/redis/v8"
|
|
||||||
"google.golang.org/protobuf/proto"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"github.com/pomerium/pomerium/pkg/storage"
|
|
||||||
)
|
|
||||||
|
|
||||||
func newSyncRecordStream(
|
|
||||||
ctx context.Context,
|
|
||||||
backend *Backend,
|
|
||||||
recordType string,
|
|
||||||
serverVersion uint64,
|
|
||||||
recordVersion uint64,
|
|
||||||
) storage.RecordStream {
|
|
||||||
changed := backend.onChange.Bind()
|
|
||||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
||||||
// 1. stream all record changes
|
|
||||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
||||||
ticker := time.NewTicker(watchPollInterval)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
currentServerVersion, err := backend.getOrCreateServerVersion(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if serverVersion != currentServerVersion {
|
|
||||||
return nil, storage.ErrInvalidServerVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion)
|
|
||||||
if err == nil {
|
|
||||||
return record, nil
|
|
||||||
} else if !errors.Is(err, storage.ErrStreamDone) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !block {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil, ctx.Err()
|
|
||||||
case <-ticker.C:
|
|
||||||
case <-changed:
|
|
||||||
}
|
|
||||||
}
|
|
||||||
},
|
|
||||||
}, func() {
|
|
||||||
backend.onChange.Unbind(changed)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func newSyncLatestRecordStream(
|
|
||||||
ctx context.Context,
|
|
||||||
backend *Backend,
|
|
||||||
recordType string,
|
|
||||||
expr storage.FilterExpression,
|
|
||||||
) (storage.RecordStream, error) {
|
|
||||||
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if recordType != "" {
|
|
||||||
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
|
||||||
return record.GetType() == recordType
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
var cursor uint64
|
|
||||||
scannedOnce := false
|
|
||||||
var scannedRecords []*databroker.Record
|
|
||||||
generator := storage.FilteredRecordStreamGenerator(
|
|
||||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
|
||||||
for {
|
|
||||||
if len(scannedRecords) > 0 {
|
|
||||||
record := scannedRecords[0]
|
|
||||||
scannedRecords = scannedRecords[1:]
|
|
||||||
return record, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// the cursor is reset to 0 after iteration is complete
|
|
||||||
if scannedOnce && cursor == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
scannedRecords, err = nextScannedRecords(ctx, backend, &cursor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
scannedOnce = true
|
|
||||||
}
|
|
||||||
},
|
|
||||||
filter,
|
|
||||||
)
|
|
||||||
|
|
||||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
|
||||||
generator,
|
|
||||||
}, nil), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) {
|
|
||||||
var values []string
|
|
||||||
var err error
|
|
||||||
values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if len(values) == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
var records []*databroker.Record
|
|
||||||
for i := 1; i < len(values); i += 2 {
|
|
||||||
var record databroker.Record
|
|
||||||
err := proto.Unmarshal([]byte(values[i]), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
records = append(records, &record)
|
|
||||||
}
|
|
||||||
return records, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) {
|
|
||||||
for {
|
|
||||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
|
||||||
Min: fmt.Sprintf("(%d", *recordVersion),
|
|
||||||
Max: "+inf",
|
|
||||||
Offset: 0,
|
|
||||||
Count: 1,
|
|
||||||
})
|
|
||||||
results, err := cmd.Result()
|
|
||||||
if errors.Is(err, redis.Nil) {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
} else if err != nil {
|
|
||||||
return nil, err
|
|
||||||
} else if len(results) == 0 {
|
|
||||||
return nil, storage.ErrStreamDone
|
|
||||||
}
|
|
||||||
|
|
||||||
result := results[0]
|
|
||||||
var record databroker.Record
|
|
||||||
err = proto.Unmarshal([]byte(result), &record)
|
|
||||||
if err != nil {
|
|
||||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
|
||||||
*recordVersion++
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
*recordVersion = record.GetVersion()
|
|
||||||
if recordType != "" && record.GetType() != recordType {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return &record, nil
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Add table
Add a link
Reference in a new issue