databroker: remove redis storage backend (#4699)

Remove the Redis databroker backend. According to
https://www.pomerium.com/docs/internals/data-storage#redis it has been
discouraged since Pomerium v0.18.

Update the config options validation to return an error if "redis" is 
set as the databroker storage backend type.
This commit is contained in:
Kenneth Jenkins 2023-11-02 11:53:25 -07:00 committed by GitHub
parent 47890e9ee1
commit 4f648e9ac1
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 12 additions and 1964 deletions

View file

@ -245,7 +245,7 @@ type Options struct {
DataBrokerURLStrings []string `mapstructure:"databroker_service_urls" yaml:"databroker_service_urls,omitempty"`
DataBrokerInternalURLString string `mapstructure:"databroker_internal_service_url" yaml:"databroker_internal_service_url,omitempty"`
// DataBrokerStorageType is the storage backend type that databroker will use.
// Supported type: memory, redis
// Supported type: memory, postgres
DataBrokerStorageType string `mapstructure:"databroker_storage_type" yaml:"databroker_storage_type,omitempty"`
// DataBrokerStorageConnectionString is the data source name for storage backend.
DataBrokerStorageConnectionString string `mapstructure:"databroker_storage_connection_string" yaml:"databroker_storage_connection_string,omitempty"`
@ -584,7 +584,9 @@ func (o *Options) Validate() error {
switch o.DataBrokerStorageType {
case StorageInMemoryName:
case StorageRedisName, StoragePostgresName:
case StorageRedisName:
return errors.New("config: redis databroker storage backend is no longer supported")
case StoragePostgresName:
if o.DataBrokerStorageConnectionString == "" {
return errors.New("config: missing databroker storage backend dsn")
}

View file

@ -58,8 +58,10 @@ func Test_Validate(t *testing.T) {
badPolicyFile.PolicyFile = "file"
invalidStorageType := testOptions()
invalidStorageType.DataBrokerStorageType = "foo"
redisStorageType := testOptions()
redisStorageType.DataBrokerStorageType = "redis"
missingStorageDSN := testOptions()
missingStorageDSN.DataBrokerStorageType = "redis"
missingStorageDSN.DataBrokerStorageType = "postgres"
badSignoutRedirectURL := testOptions()
badSignoutRedirectURL.SignOutRedirectURLString = "--"
badCookieSettings := testOptions()
@ -77,6 +79,7 @@ func Test_Validate(t *testing.T) {
{"missing shared secret but all service", badSecretAllServices, false},
{"policy file specified", badPolicyFile, true},
{"invalid databroker storage type", invalidStorageType, true},
{"redis databroker storage type", redisStorageType, true},
{"missing databroker storage dsn", missingStorageDSN, true},
{"invalid signout redirect url", badSignoutRedirectURL, true},
{"CookieSameSite none with CookieSecure fale", badCookieSettings, true},

View file

@ -9,7 +9,6 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/registry/inmemory"
"github.com/pomerium/pomerium/internal/registry/redis"
"github.com/pomerium/pomerium/internal/telemetry/trace"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
"github.com/pomerium/pomerium/pkg/storage"
@ -110,16 +109,6 @@ func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interfac
case config.StorageInMemoryName:
log.Info(ctx).Msg("using in-memory registry")
return inmemory.New(ctx, srv.cfg.registryTTL), nil
case config.StorageRedisName:
log.Info(ctx).Msg("using redis registry")
r, err := redis.New(
srv.cfg.storageConnectionString,
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
)
if err != nil {
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
}
return r, nil
}
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)

View file

@ -3,7 +3,6 @@ package databroker
import (
"context"
"crypto/tls"
"errors"
"fmt"
"strings"
@ -19,12 +18,10 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
"github.com/pomerium/pomerium/pkg/storage/inmemory"
"github.com/pomerium/pomerium/pkg/storage/postgres"
"github.com/pomerium/pomerium/pkg/storage/redis"
)
// Server implements the databroker service using an in memory database.
@ -426,39 +423,8 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
case config.StoragePostgresName:
log.Info(ctx).Msg("using postgres store")
backend = postgres.New(srv.cfg.storageConnectionString)
case config.StorageRedisName:
log.Info(ctx).Msg("using redis store")
backend, err = redis.New(
srv.cfg.storageConnectionString,
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
)
if err != nil {
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
}
if srv.cfg.secret != nil {
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
if err != nil {
return nil, err
}
}
default:
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
}
return backend, nil
}
func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config {
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
if err != nil {
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
//nolint: gosec
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
}
if srv.cfg.storageCertificate != nil {
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
}
return tlsConfig
}

View file

@ -22,7 +22,6 @@ import (
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/protoutil"
@ -287,12 +286,11 @@ func TestServerInvalidStorage(t *testing.T) {
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
}
func TestServerRedis(t *testing.T) {
testutil.WithTestRedis(false, func(rawURL string) error {
func TestServerPostgres(t *testing.T) {
testutil.WithTestPostgres(func(dsn string) error {
srv := newServer(&serverConfig{
storageType: "redis",
storageConnectionString: rawURL,
secret: cryptutil.NewKey(),
storageType: "postgres",
storageConnectionString: dsn,
})
s := new(session.Session)

View file

@ -1,20 +0,0 @@
// Package lua contains lua source code.
package lua
import (
"embed"
)
//go:embed registry.lua
var fs embed.FS
// Registry is the registry lua script
var Registry string
func init() {
bs, err := fs.ReadFile("registry.lua")
if err != nil {
panic(err)
}
Registry = string(bs)
}

View file

@ -1,28 +0,0 @@
-- ARGV = [current time in seconds, ttl in seconds, services ...]
local current_time = ARGV[1]
local ttl = ARGV[2]
local changed = false
-- update the service list
for i = 3, #ARGV, 1 do
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
changed = true
end
-- retrieve all the services, removing any that have expired
local svcs = {}
local kvs = redis.call('HGETALL', KEYS[1])
for i = 1, #kvs, 2 do
if kvs[i + 1] < current_time then
redis.call('HDEL', KEYS[1], kvs[i])
changed = true
else
table.insert(svcs, kvs[i])
end
end
if changed then
redis.call('PUBLISH', KEYS[2], current_time)
end
return svcs

View file

@ -1,48 +0,0 @@
package redis
import (
"crypto/tls"
"time"
)
const defaultTTL = time.Second * 30
type config struct {
tls *tls.Config
ttl time.Duration
getNow func() time.Time
}
// An Option modifies the config..
type Option func(*config)
// WithGetNow sets the time.Now function in the config.
func WithGetNow(getNow func() time.Time) Option {
return func(cfg *config) {
cfg.getNow = getNow
}
}
// WithTLSConfig sets the tls.Config in the config.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tls = tlsConfig
}
}
// WithTTL sets the ttl in the config.
func WithTTL(ttl time.Duration) Option {
return func(cfg *config) {
cfg.ttl = ttl
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithGetNow(time.Now)(cfg)
WithTTL(defaultTTL)(cfg)
for _, o := range options {
o(cfg)
}
return cfg
}

View file

@ -1,250 +0,0 @@
// Package redis implements a registry in redis.
package redis
import (
"context"
"fmt"
"sort"
"strings"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-redis/redis/v8"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/redisutil"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/registry/redis/lua"
"github.com/pomerium/pomerium/internal/signal"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
const (
registryKey = redisutil.KeyPrefix + "registry"
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
pollInterval = time.Second * 30
)
type impl struct {
cfg *config
client redis.UniversalClient
onChange *signal.Signal
closeOnce sync.Once
closed chan struct{}
}
// New creates a new registry implementation backend by redis.
func New(rawURL string, options ...Option) (registry.Interface, error) {
cfg := getConfig(options...)
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
if err != nil {
return nil, err
}
i := &impl{
cfg: cfg,
client: client,
onChange: signal.New(),
closed: make(chan struct{}),
}
go i.listenForChanges(context.Background())
return i, nil
}
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
_, err := i.runReport(ctx, req.GetServices())
if err != nil {
return nil, err
}
return &registrypb.RegisterResponse{
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
}, nil
}
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
all, err := i.runReport(ctx, nil)
if err != nil {
return nil, err
}
include := map[registrypb.ServiceKind]struct{}{}
for _, kind := range req.GetKinds() {
include[kind] = struct{}{}
}
filtered := make([]*registrypb.Service, 0, len(all))
for _, svc := range all {
if _, ok := include[svc.GetKind()]; !ok {
continue
}
filtered = append(filtered, svc)
}
sort.Slice(filtered, func(i, j int) bool {
{
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
switch {
case iv < jv:
return true
case jv < iv:
return false
}
}
{
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
switch {
case iv < jv:
return true
case jv < iv:
return false
}
}
return false
})
return &registrypb.ServiceList{
Services: filtered,
}, nil
}
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
// listen for changes
ch := i.onChange.Bind()
defer i.onChange.Unbind(ch)
// force a check periodically
poll := time.NewTicker(pollInterval)
defer poll.Stop()
var prev *registrypb.ServiceList
for {
// retrieve the most recent list of services
lst, err := i.List(stream.Context(), req)
if err != nil {
return err
}
// only send a new list if something changed
if !proto.Equal(prev, lst) {
err = stream.Send(lst)
if err != nil {
return err
}
}
prev = lst
// wait for an update
select {
case <-i.closed:
return nil
case <-stream.Context().Done():
return stream.Context().Err()
case <-ch:
case <-poll.C:
}
}
}
func (i *impl) Close() error {
var err error
i.closeOnce.Do(func() {
err = i.client.Close()
close(i.closed)
})
return err
}
func (i *impl) listenForChanges(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
go func() {
<-i.closed
cancel()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
outer:
for {
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
for {
msg, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
select {
case <-ctx.Done():
return
case <-time.After(bo.NextBackOff()):
}
continue outer
}
bo.Reset()
switch msg.(type) {
case *redis.Message:
i.onChange.Broadcast(ctx)
}
}
}
}
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
args := []interface{}{
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
i.cfg.ttl.Milliseconds(), // ttl
}
for _, svc := range updates {
args = append(args, i.getRegistryHashKey(svc))
}
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result()
if err != nil {
return nil, err
}
if values, ok := res.([]interface{}); ok {
var all []*registrypb.Service
for _, value := range values {
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
if err != nil {
log.Warn(ctx).Err(err).Msg("redis: invalid service")
continue
}
all = append(all, svc)
}
return all, nil
}
return nil, nil
}
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
idx := strings.Index(key, "|")
if idx == -1 {
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
}
svcKindStr := key[:idx]
svcEndpointStr := key[idx+1:]
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
if !ok {
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
}
svc := &registrypb.Service{
Kind: registrypb.ServiceKind(svcKind),
Endpoint: svcEndpointStr,
}
return svc, nil
}
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
return svc.GetKind().String() + "|" + svc.GetEndpoint()
}

View file

@ -1,196 +0,0 @@
package redis
import (
"context"
"net"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/internal/testutil"
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
)
func TestReport(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
ctx := context.Background()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
tm := time.Now()
i, err := New(rawURL,
WithGetNow(func() time.Time {
return tm
}),
WithTTL(time.Second*10))
require.NoError(t, err)
defer func() { _ = i.Close() }()
_, err = i.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
},
})
require.NoError(t, err)
// move forward 5 seconds
tm = tm.Add(time.Second * 5)
_, err = i.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
},
})
require.NoError(t, err)
lst, err := i.List(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
registrypb.ServiceKind_PROXY,
},
})
require.NoError(t, err)
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
}, lst.GetServices(), "should list selected services")
// move forward 6 seconds
tm = tm.Add(time.Second * 6)
lst, err = i.List(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
registrypb.ServiceKind_PROXY,
},
})
require.NoError(t, err)
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
}, lst.GetServices(), "should expire old services")
return nil
}))
}
func TestWatch(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
defer clearTimeout()
ctx, cancel := context.WithCancel(ctx)
defer cancel()
tm := time.Now()
i, err := New(rawURL,
WithGetNow(func() time.Time {
return tm
}),
WithTTL(time.Second*10))
require.NoError(t, err)
li, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
defer li.Close()
srv := grpc.NewServer()
registrypb.RegisterRegistryServer(srv, i)
eg, ctx := errgroup.WithContext(ctx)
eg.Go(func() error {
<-ctx.Done()
srv.Stop()
return nil
})
eg.Go(func() error {
return srv.Serve(li)
})
eg.Go(func() error {
defer cancel()
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
if err != nil {
return err
}
client := registrypb.NewRegistryClient(cc)
// store the initial services
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
},
})
if err != nil {
return err
}
stream, err := client.Watch(ctx, &registrypb.ListRequest{
Kinds: []registrypb.ServiceKind{
registrypb.ServiceKind_AUTHORIZE,
},
})
if err != nil {
return err
}
defer func() { _ = stream.CloseSend() }()
lst, err := stream.Recv()
if err != nil {
return err
}
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
}, lst.GetServices())
// update authenticate
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
},
})
if err != nil {
return err
}
// add an authorize
_, err = client.Report(ctx, &registrypb.RegisterRequest{
Services: []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
},
})
if err != nil {
return err
}
lst, err = stream.Recv()
if err != nil {
return err
}
assert.Equal(t, []*registrypb.Service{
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
}, lst.GetServices())
return nil
})
require.NoError(t, eg.Wait())
return nil
}))
}

View file

@ -1,204 +0,0 @@
package storage
import (
"context"
"crypto/cipher"
"time"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
type encryptedRecordStream struct {
underlying RecordStream
backend *encryptedBackend
err error
}
func (e *encryptedRecordStream) Close() error {
return e.underlying.Close()
}
func (e *encryptedRecordStream) Next(wait bool) bool {
return e.underlying.Next(wait)
}
func (e *encryptedRecordStream) Record() *databroker.Record {
r := e.underlying.Record()
if r != nil {
var err error
r, err = e.backend.decryptRecord(r)
if err != nil {
e.err = err
}
}
return r
}
func (e *encryptedRecordStream) Err() error {
if e.err == nil {
e.err = e.underlying.Err()
}
return e.err
}
type encryptedBackend struct {
underlying Backend
cipher cipher.AEAD
}
// NewEncryptedBackend creates a new encrypted backend.
func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
c, err := cryptutil.NewAEADCipher(secret)
if err != nil {
return nil, err
}
return &encryptedBackend{
underlying: underlying,
cipher: c,
}, nil
}
func (e *encryptedBackend) Close() error {
return e.underlying.Close()
}
func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
record, err := e.underlying.Get(ctx, recordType, id)
if err != nil {
return nil, err
}
record, err = e.decryptRecord(record)
if err != nil {
return nil, err
}
return record, nil
}
func (e *encryptedBackend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
return e.underlying.GetOptions(ctx, recordType)
}
func (e *encryptedBackend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
return e.underlying.Lease(ctx, leaseName, leaseID, ttl)
}
func (e *encryptedBackend) ListTypes(ctx context.Context) ([]string, error) {
return e.underlying.ListTypes(ctx)
}
func (e *encryptedBackend) Put(ctx context.Context, records []*databroker.Record) (uint64, error) {
encryptedRecords := make([]*databroker.Record, len(records))
for i, record := range records {
encrypted, err := e.encrypt(record.GetData())
if err != nil {
return 0, err
}
newRecord := proto.Clone(record).(*databroker.Record)
newRecord.Data = encrypted
encryptedRecords[i] = newRecord
}
serverVersion, err := e.underlying.Put(ctx, encryptedRecords)
if err != nil {
return 0, err
}
for i, record := range records {
record.ModifiedAt = encryptedRecords[i].ModifiedAt
record.Version = encryptedRecords[i].Version
}
return serverVersion, nil
}
func (e *encryptedBackend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
return e.underlying.SetOptions(ctx, recordType, options)
}
func (e *encryptedBackend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (RecordStream, error) {
stream, err := e.underlying.Sync(ctx, recordType, serverVersion, recordVersion)
if err != nil {
return nil, err
}
return &encryptedRecordStream{
underlying: stream,
backend: e,
}, nil
}
func (e *encryptedBackend) SyncLatest(
ctx context.Context,
recordType string,
filter FilterExpression,
) (serverVersion, recordVersion uint64, stream RecordStream, err error) {
serverVersion, recordVersion, stream, err = e.underlying.SyncLatest(ctx, recordType, filter)
if err != nil {
return serverVersion, recordVersion, nil, err
}
return serverVersion, recordVersion, &encryptedRecordStream{
underlying: stream,
backend: e,
}, nil
}
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
data, err := e.decrypt(in.Data)
if err != nil {
return nil, err
}
// Create a new record so that we don't re-use any internal state
return &databroker.Record{
Version: in.Version,
Type: in.Type,
Id: in.Id,
Data: data,
ModifiedAt: in.ModifiedAt,
DeletedAt: in.DeletedAt,
}, nil
}
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
if in == nil {
return nil, nil
}
var encrypted wrapperspb.BytesValue
err = in.UnmarshalTo(&encrypted)
if err != nil {
return nil, err
}
plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil)
if err != nil {
return nil, err
}
out = new(anypb.Any)
err = proto.Unmarshal(plaintext, out)
if err != nil {
return nil, err
}
return out, nil
}
func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) {
plaintext, err := proto.Marshal(in)
if err != nil {
return nil, err
}
encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil)
out = protoutil.NewAny(&wrapperspb.BytesValue{
Value: encrypted,
})
return out, nil
}

View file

@ -1,75 +0,0 @@
package storage
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/protoutil"
)
func TestEncryptedBackend(t *testing.T) {
ctx := context.Background()
m := map[string]*anypb.Any{}
backend := &mockBackend{
put: func(ctx context.Context, records []*databroker.Record) (uint64, error) {
for _, record := range records {
record.ModifiedAt = timestamppb.Now()
record.Version++
m[record.GetId()] = record.GetData()
}
return 0, nil
},
get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) {
data, ok := m[id]
if !ok {
return nil, errors.New("not found")
}
return &databroker.Record{
Id: id,
Data: data,
Version: 1,
ModifiedAt: timestamppb.Now(),
}, nil
},
}
e, err := NewEncryptedBackend(cryptutil.NewKey(), backend)
if !assert.NoError(t, err) {
return
}
data := protoutil.NewAny(wrapperspb.String("HELLO WORLD"))
rec := &databroker.Record{
Type: "",
Id: "TEST-1",
Data: data,
}
_, err = e.Put(ctx, []*databroker.Record{rec})
if !assert.NoError(t, err) {
return
}
if assert.NotNil(t, m["TEST-1"], "key should be set") {
assert.NotEqual(t, data.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type")
assert.NotEqual(t, data.Value, m["TEST-1"].Value, "value should be encrypted")
assert.NotNil(t, rec.ModifiedAt)
assert.NotZero(t, rec.Version)
}
record, err := e.Get(ctx, "", "TEST-1")
if !assert.NoError(t, err) {
return
}
assert.Equal(t, data.TypeUrl, record.Data.TypeUrl, "type should be preserved")
assert.Equal(t, data.Value, record.Data.Value, "value should be preserved")
assert.NotEqual(t, data.TypeUrl, record.Type, "record type should be preserved")
}

View file

@ -1,31 +0,0 @@
package redis
import (
"context"
"time"
"github.com/go-redis/redis/v8"
pomeriumconfig "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
)
type logger struct {
}
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
}
func init() {
redis.SetLogger(logger{})
}
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
Operation: operation,
Error: err,
Backend: pomeriumconfig.StorageRedisName,
}, time.Since(startTime))
}

View file

@ -1,37 +0,0 @@
package redis
import (
"crypto/tls"
"time"
)
type config struct {
tls *tls.Config
expiry time.Duration
}
// Option customizes a Backend.
type Option func(*config)
// WithTLSConfig sets the tls.Config which Backend uses.
func WithTLSConfig(tlsConfig *tls.Config) Option {
return func(cfg *config) {
cfg.tls = tlsConfig
}
}
// WithExpiry sets the expiry for changes.
func WithExpiry(expiry time.Duration) Option {
return func(cfg *config) {
cfg.expiry = expiry
}
}
func getConfig(options ...Option) *config {
cfg := new(config)
WithExpiry(time.Hour * 24)(cfg)
for _, o := range options {
o(cfg)
}
return cfg
}

View file

@ -1,537 +0,0 @@
// Package redis implements the storage.Backend interface for redis.
package redis
import (
"context"
"errors"
"fmt"
"sort"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/go-redis/redis/v8"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/redisutil"
"github.com/pomerium/pomerium/internal/signal"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
const (
maxTransactionRetries = 100
watchPollInterval = 30 * time.Second
// we rely on transactions in redis, so all redis-cluster keys need to be
// on the same node. Using a `hash tag` gives us this capability.
serverVersionKey = redisutil.KeyPrefix + "server_version"
lastVersionKey = redisutil.KeyPrefix + "last_version"
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
recordHashKey = redisutil.KeyPrefix + "records"
recordTypesSetKey = redisutil.KeyPrefix + "record_types"
changesSetKey = redisutil.KeyPrefix + "changes"
optionsKey = redisutil.KeyPrefix + "options"
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
leaseKeyTpl = redisutil.KeyPrefix + "lease.%s"
)
// custom errors
var (
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
)
// Backend implements the storage.Backend on top of redis.
//
// What's stored:
//
// - last_version: an integer recordVersion number
// - last_version_ch: a PubSub channel for recordVersion number updates
// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record.
// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record.
// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options.
// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time,
// the value the record id.
//
// Records stored in these keys are typically encrypted.
type Backend struct {
cfg *config
client redis.UniversalClient
onChange *signal.Signal
closeOnce sync.Once
closed chan struct{}
}
// New creates a new redis storage backend.
func New(rawURL string, options ...Option) (*Backend, error) {
ctx := context.TODO()
cfg := getConfig(options...)
backend := &Backend{
cfg: cfg,
closed: make(chan struct{}),
onChange: signal.New(),
}
var err error
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
if err != nil {
return nil, err
}
metrics.AddRedisMetrics(backend.client.PoolStats)
go backend.listenForVersionChanges(ctx)
if cfg.expiry != 0 {
go func() {
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for {
select {
case <-backend.closed:
return
case <-ticker.C:
}
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
}
}()
}
return backend, nil
}
// Close closes the underlying redis connection and any watchers.
func (backend *Backend) Close() error {
var err error
backend.closeOnce.Do(func() {
err = backend.client.Close()
close(backend.closed)
})
return err
}
// Get gets a record from redis.
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
key, field := getHashKey(recordType, id)
cmd := backend.client.HGet(ctx, key, field)
raw, err := cmd.Result()
if errors.Is(err, redis.Nil) {
return nil, storage.ErrNotFound
} else if err != nil {
return nil, err
}
var record databroker.Record
err = proto.Unmarshal([]byte(raw), &record)
if err != nil {
return nil, err
}
return &record, nil
}
// GetOptions gets the options for the given record type.
func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result()
if errors.Is(err, redis.Nil) {
// treat no options as an empty set of options
return new(databroker.Options), nil
} else if err != nil {
return nil, err
}
var options databroker.Options
err = proto.Unmarshal([]byte(raw), &options)
if err != nil {
return nil, err
}
return &options, nil
}
// Lease acquires or renews a lease.
func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
acquired := false
key := getLeaseKey(leaseName)
err := backend.client.Watch(ctx, func(tx *redis.Tx) error {
currentID, err := tx.Get(ctx, key).Result()
if errors.Is(err, redis.Nil) {
// lease hasn't been set yet
} else if err != nil {
return err
} else if leaseID != currentID {
// lease has already been taken
return nil
}
_, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error {
if ttl <= 0 {
p.Del(ctx, key)
} else {
p.Set(ctx, key, leaseID, ttl)
}
return nil
})
if err != nil {
return err
}
acquired = ttl > 0
return nil
}, key)
// if the transaction failed someone else must've acquired the lease
if errors.Is(err, redis.TxFailedErr) {
acquired = false
err = nil
}
return acquired, err
}
// ListTypes lists all the known record types.
func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) {
ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now())
cmd := backend.client.SMembers(ctx, recordTypesSetKey)
types, err = cmd.Result()
if err != nil {
return nil, err
}
sort.Strings(types)
return types, nil
}
// Put puts a record into redis.
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
defer span.End()
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
serverVersion, err = backend.getOrCreateServerVersion(ctx)
if err != nil {
return serverVersion, err
}
err = backend.put(ctx, records)
if err != nil {
return serverVersion, err
}
recordTypes := map[string]struct{}{}
for _, record := range records {
recordTypes[record.GetType()] = struct{}{}
}
for recordType := range recordTypes {
err = backend.enforceOptions(ctx, recordType)
if err != nil {
return serverVersion, err
}
}
return serverVersion, nil
}
// SetOptions sets the options for the given record type.
func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions")
defer span.End()
bs, err := proto.Marshal(options)
if err != nil {
return err
}
// update the options in the hash set
err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err()
if err != nil {
return err
}
// possibly re-enforce options
err = backend.enforceOptions(ctx, recordType)
if err != nil {
return err
}
return nil
}
// Sync returns a record stream of any records changed after the specified recordVersion.
func (backend *Backend) Sync(
ctx context.Context,
recordType string,
serverVersion, recordVersion uint64,
) (storage.RecordStream, error) {
return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil
}
// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the
// stream is streaming.
func (backend *Backend) SyncLatest(
ctx context.Context,
recordType string,
expr storage.FilterExpression,
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
serverVersion, err = backend.getOrCreateServerVersion(ctx)
if err != nil {
return serverVersion, recordVersion, nil, err
}
recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64()
if errors.Is(err, redis.Nil) {
// this happens if there are no records
} else if err != nil {
return serverVersion, recordVersion, nil, err
}
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
return serverVersion, recordVersion, stream, err
}
func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error {
return backend.incrementVersion(ctx,
func(tx *redis.Tx, version uint64) error {
for i, record := range records {
record.ModifiedAt = timestamppb.Now()
record.Version = version + uint64(i)
}
return nil
},
func(p redis.Pipeliner, version uint64) error {
for i, record := range records {
bs, err := proto.Marshal(record)
if err != nil {
return err
}
key, field := getHashKey(record.GetType(), record.GetId())
if record.DeletedAt != nil {
p.HDel(ctx, key, field)
} else {
p.HSet(ctx, key, field, bs)
p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{
Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)),
Member: record.GetId(),
})
}
p.ZAdd(ctx, changesSetKey, &redis.Z{
Score: float64(version) + float64(i),
Member: bs,
})
p.SAdd(ctx, recordTypesSetKey, record.GetType())
}
return nil
})
}
// enforceOptions enforces the options for the given record type.
func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error {
ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions")
defer span.End()
options, err := backend.GetOptions(ctx, recordType)
if err != nil {
return err
}
// nothing to do if capacity isn't set
if options.Capacity == nil {
return nil
}
key := getRecordTypeChangesKey(recordType)
// find oldest records that exceed the capacity
recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result()
if err != nil {
return err
}
// for each record, delete it
for _, recordID := range recordIDs {
record, err := backend.Get(ctx, recordType, recordID)
if err == nil {
// mark the record as deleted and re-submit
record.DeletedAt = timestamppb.Now()
err = backend.put(ctx, []*databroker.Record{record})
if err != nil {
return err
}
} else if errors.Is(err, storage.ErrNotFound) {
// ignore
} else if err != nil {
return err
}
// remove the member from the collection
_, err = backend.client.ZRem(ctx, key, recordID).Result()
if err != nil {
return err
}
}
return nil
}
// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in
// `commit`. If the last recordVersion changes in the interim, we will retry the transaction.
func (backend *Backend) incrementVersion(ctx context.Context,
query func(tx *redis.Tx, recordVersion uint64) error,
commit func(p redis.Pipeliner, recordVersion uint64) error,
) error {
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
txf := func(tx *redis.Tx) error {
version, err := tx.Get(ctx, lastVersionKey).Uint64()
if errors.Is(err, redis.Nil) {
version = 0
} else if err != nil {
return err
}
version++
err = query(tx, version)
if err != nil {
return err
}
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
err := commit(p, version)
if err != nil {
return err
}
p.Set(ctx, lastVersionKey, version, 0)
p.Publish(ctx, lastVersionChKey, version)
return nil
})
return err
}
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
for i := 0; i < maxTransactionRetries; i++ {
err := backend.client.Watch(ctx, txf, lastVersionKey)
if errors.Is(err, redis.TxFailedErr) {
select {
case <-ctx.Done():
return ctx.Err()
case <-time.After(bo.NextBackOff()):
}
continue // retry
} else if err != nil {
return err
}
return nil // tx was successful
}
return ErrExceededMaxRetries
}
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
ctx, cancel := context.WithCancel(ctx)
go func() {
<-backend.closed
cancel()
}()
bo := backoff.NewExponentialBackOff()
bo.MaxElapsedTime = 0
outer:
for {
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
for {
msg, err := pubsub.Receive(ctx)
if err != nil {
_ = pubsub.Close()
select {
case <-ctx.Done():
return
case <-time.After(bo.NextBackOff()):
}
continue outer
}
bo.Reset()
switch msg.(type) {
case *redis.Message:
backend.onChange.Broadcast(ctx)
}
}
}
}
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
for {
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
Min: "-inf",
Max: "+inf",
Offset: 0,
Count: 1,
})
results, err := cmd.Result()
if err != nil {
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
return
}
// nothing left to do
if len(results) == 0 {
return
}
var record databroker.Record
err = proto.Unmarshal([]byte(results[0]), &record)
if err != nil {
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
}
// if the record's modified timestamp is after the cutoff, we're all done, so break
if record.GetModifiedAt().AsTime().After(cutoff) {
break
}
// remove the record
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
if err != nil {
log.Error(ctx).Err(err).Msg("redis: error removing member")
return
}
}
}
func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) {
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
// if the server version hasn't been set yet, set it to a random value and immediately retrieve it
// this should properly handle a data race by only setting the key if it doesn't already exist
if errors.Is(err, redis.Nil) {
_, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result()
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
}
if err != nil {
return 0, fmt.Errorf("redis: error retrieving server version: %w", err)
}
return serverVersion, err
}
func getLeaseKey(leaseName string) string {
return fmt.Sprintf(leaseKeyTpl, leaseName)
}
func getRecordTypeChangesKey(recordType string) string {
return fmt.Sprintf(recordTypeChangesKeyTpl, recordType)
}
func getHashKey(recordType, id string) (key, field string) {
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
}

View file

@ -1,312 +0,0 @@
package redis
import (
"context"
"fmt"
"os"
"runtime"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sync/errgroup"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/anypb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
func TestBackend(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
handler := func(t *testing.T, useTLS bool, rawURL string) error {
ctx := context.Background()
var opts []Option
if useTLS {
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
}
backend, err := New(rawURL, opts...)
require.NoError(t, err)
defer func() { _ = backend.Close() }()
serverVersion, err := backend.getOrCreateServerVersion(ctx)
require.NoError(t, err)
t.Run("get missing record", func(t *testing.T) {
record, err := backend.Get(ctx, "TYPE", "abcd")
require.Error(t, err)
assert.Nil(t, record)
})
t.Run("get record", func(t *testing.T) {
data := new(anypb.Any)
sv, err := backend.Put(ctx, []*databroker.Record{{
Type: "TYPE",
Id: "abcd",
Data: data,
}})
assert.NoError(t, err)
assert.Equal(t, serverVersion, sv)
record, err := backend.Get(ctx, "TYPE", "abcd")
require.NoError(t, err)
if assert.NotNil(t, record) {
assert.Equal(t, data, record.Data)
assert.Nil(t, record.DeletedAt)
assert.Equal(t, "abcd", record.Id)
assert.NotNil(t, record.ModifiedAt)
assert.Equal(t, "TYPE", record.Type)
assert.Equal(t, uint64(1), record.Version)
}
})
t.Run("delete record", func(t *testing.T) {
sv, err := backend.Put(ctx, []*databroker.Record{{
Type: "TYPE",
Id: "abcd",
DeletedAt: timestamppb.Now(),
}})
assert.NoError(t, err)
assert.Equal(t, serverVersion, sv)
record, err := backend.Get(ctx, "TYPE", "abcd")
assert.Error(t, err)
assert.Nil(t, record)
})
t.Run("list types", func(t *testing.T) {
types, err := backend.ListTypes(ctx)
assert.NoError(t, err)
assert.Equal(t, []string{"TYPE"}, types)
})
return nil
}
t.Run("no-tls", func(t *testing.T) {
t.Parallel()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
return handler(t, false, rawURL)
}))
})
t.Run("tls", func(t *testing.T) {
t.Parallel()
require.NoError(t, testutil.WithTestRedis(true, func(rawURL string) error {
return handler(t, true, rawURL)
}))
})
if runtime.GOOS == "linux" {
t.Run("cluster", func(t *testing.T) {
t.Parallel()
require.NoError(t, testutil.WithTestRedisCluster(func(rawURL string) error {
return handler(t, false, rawURL)
}))
})
t.Run("sentinel", func(t *testing.T) {
t.Parallel()
require.NoError(t, testutil.WithTestRedisSentinel(func(rawURL string) error {
return handler(t, false, rawURL)
}))
})
}
}
func TestChangeSignal(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
t.Parallel()
ctx := context.Background()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
defer clearTimeout()
done := make(chan struct{})
var eg errgroup.Group
eg.Go(func() error {
backend, err := New(rawURL)
if err != nil {
return err
}
defer func() { _ = backend.Close() }()
ch := backend.onChange.Bind()
defer backend.onChange.Unbind(ch)
select {
case <-ch:
case <-ctx.Done():
return ctx.Err()
}
// signal the second backend that we've received the change
close(done)
return nil
})
eg.Go(func() error {
backend, err := New(rawURL)
if err != nil {
return err
}
defer func() { _ = backend.Close() }()
// put a new value to trigger a change
for {
_, err = backend.Put(ctx, []*databroker.Record{{
Type: "TYPE",
Id: "ID",
}})
if err != nil {
return err
}
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
return nil
case <-time.After(time.Millisecond * 100):
}
}
})
assert.NoError(t, eg.Wait(), "expected signal to be fired when another backend triggers a change")
return nil
}))
}
func TestExpiry(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
t.Parallel()
ctx := context.Background()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
backend, err := New(rawURL, WithExpiry(0))
require.NoError(t, err)
defer func() { _ = backend.Close() }()
serverVersion, err := backend.getOrCreateServerVersion(ctx)
require.NoError(t, err)
for i := 0; i < 1000; i++ {
_, err := backend.Put(ctx, []*databroker.Record{{
Type: "TYPE",
Id: fmt.Sprint(i),
}})
assert.NoError(t, err)
}
stream, err := backend.Sync(ctx, "TYPE", serverVersion, 0)
require.NoError(t, err)
var records []*databroker.Record
for stream.Next(false) {
records = append(records, stream.Record())
}
_ = stream.Close()
require.Len(t, records, 1000)
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
stream, err = backend.Sync(ctx, "TYPE", serverVersion, 0)
require.NoError(t, err)
records = nil
for stream.Next(false) {
records = append(records, stream.Record())
}
_ = stream.Close()
require.Len(t, records, 0)
return nil
}))
}
func TestCapacity(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
t.Parallel()
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
defer clearTimeout()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
backend, err := New(rawURL, WithExpiry(0))
require.NoError(t, err)
defer func() { _ = backend.Close() }()
err = backend.SetOptions(ctx, "EXAMPLE", &databroker.Options{
Capacity: proto.Uint64(3),
})
require.NoError(t, err)
for i := 0; i < 10; i++ {
_, err = backend.Put(ctx, []*databroker.Record{{
Type: "EXAMPLE",
Id: fmt.Sprint(i),
}})
require.NoError(t, err)
}
_, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil)
require.NoError(t, err)
defer stream.Close()
records, err := storage.RecordStreamToList(stream)
require.NoError(t, err)
assert.Len(t, records, 3)
var ids []string
for _, r := range records {
ids = append(ids, r.GetId())
}
assert.Equal(t, []string{"7", "8", "9"}, ids, "should contain recent records")
return nil
}))
}
func TestLease(t *testing.T) {
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
t.Skip("Github action can not run docker on MacOS")
}
t.Parallel()
ctx := context.Background()
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
backend, err := New(rawURL)
require.NoError(t, err)
defer func() { _ = backend.Close() }()
{
ok, err := backend.Lease(ctx, "test", "a", time.Second*30)
require.NoError(t, err)
assert.True(t, ok, "expected a to acquire the lease")
}
{
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
require.NoError(t, err)
assert.False(t, ok, "expected b to fail to acquire the lease")
}
{
ok, err := backend.Lease(ctx, "test", "a", 0)
require.NoError(t, err)
assert.False(t, ok, "expected a to clear the lease")
}
{
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
require.NoError(t, err)
assert.True(t, ok, "expected b to to acquire the lease")
}
return nil
}))
}

View file

@ -1,172 +0,0 @@
package redis
import (
"context"
"errors"
"fmt"
"time"
"github.com/go-redis/redis/v8"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/pkg/storage"
)
func newSyncRecordStream(
ctx context.Context,
backend *Backend,
recordType string,
serverVersion uint64,
recordVersion uint64,
) storage.RecordStream {
changed := backend.onChange.Bind()
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
// 1. stream all record changes
func(ctx context.Context, block bool) (*databroker.Record, error) {
ticker := time.NewTicker(watchPollInterval)
defer ticker.Stop()
for {
currentServerVersion, err := backend.getOrCreateServerVersion(ctx)
if err != nil {
return nil, err
}
if serverVersion != currentServerVersion {
return nil, storage.ErrInvalidServerVersion
}
record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion)
if err == nil {
return record, nil
} else if !errors.Is(err, storage.ErrStreamDone) {
return nil, err
}
if !block {
return nil, storage.ErrStreamDone
}
select {
case <-ctx.Done():
return nil, ctx.Err()
case <-ticker.C:
case <-changed:
}
}
},
}, func() {
backend.onChange.Unbind(changed)
})
}
func newSyncLatestRecordStream(
ctx context.Context,
backend *Backend,
recordType string,
expr storage.FilterExpression,
) (storage.RecordStream, error) {
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
if err != nil {
return nil, err
}
if recordType != "" {
filter = filter.And(func(record *databroker.Record) (keep bool) {
return record.GetType() == recordType
})
}
var cursor uint64
scannedOnce := false
var scannedRecords []*databroker.Record
generator := storage.FilteredRecordStreamGenerator(
func(ctx context.Context, block bool) (*databroker.Record, error) {
for {
if len(scannedRecords) > 0 {
record := scannedRecords[0]
scannedRecords = scannedRecords[1:]
return record, nil
}
// the cursor is reset to 0 after iteration is complete
if scannedOnce && cursor == 0 {
return nil, storage.ErrStreamDone
}
var err error
scannedRecords, err = nextScannedRecords(ctx, backend, &cursor)
if err != nil {
return nil, err
}
scannedOnce = true
}
},
filter,
)
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
generator,
}, nil), nil
}
func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) {
var values []string
var err error
values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result()
if errors.Is(err, redis.Nil) {
return nil, storage.ErrStreamDone
} else if err != nil {
return nil, err
} else if len(values) == 0 {
return nil, storage.ErrStreamDone
}
var records []*databroker.Record
for i := 1; i < len(values); i += 2 {
var record databroker.Record
err := proto.Unmarshal([]byte(values[i]), &record)
if err != nil {
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
continue
}
records = append(records, &record)
}
return records, nil
}
func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) {
for {
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
Min: fmt.Sprintf("(%d", *recordVersion),
Max: "+inf",
Offset: 0,
Count: 1,
})
results, err := cmd.Result()
if errors.Is(err, redis.Nil) {
return nil, storage.ErrStreamDone
} else if err != nil {
return nil, err
} else if len(results) == 0 {
return nil, storage.ErrStreamDone
}
result := results[0]
var record databroker.Record
err = proto.Unmarshal([]byte(result), &record)
if err != nil {
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
*recordVersion++
continue
}
*recordVersion = record.GetVersion()
if recordType != "" && record.GetType() != recordType {
continue
}
return &record, nil
}
}