From 4f648e9ac1d10dfb6b0a1d0e20f6764ae94bc156 Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Thu, 2 Nov 2023 11:53:25 -0700 Subject: [PATCH] databroker: remove redis storage backend (#4699) Remove the Redis databroker backend. According to https://www.pomerium.com/docs/internals/data-storage#redis it has been discouraged since Pomerium v0.18. Update the config options validation to return an error if "redis" is set as the databroker storage backend type. --- config/options.go | 6 +- config/options_test.go | 5 +- internal/databroker/registry.go | 11 - internal/databroker/server.go | 34 -- internal/databroker/server_test.go | 10 +- internal/registry/redis/lua/lua.go | 20 - internal/registry/redis/lua/registry.lua | 28 -- internal/registry/redis/option.go | 48 -- internal/registry/redis/redis.go | 250 ----------- internal/registry/redis/redis_test.go | 196 --------- pkg/storage/encrypted.go | 204 --------- pkg/storage/encrypted_test.go | 75 ---- pkg/storage/redis/observe.go | 31 -- pkg/storage/redis/option.go | 37 -- pkg/storage/redis/redis.go | 537 ----------------------- pkg/storage/redis/redis_test.go | 312 ------------- pkg/storage/redis/stream.go | 172 -------- 17 files changed, 12 insertions(+), 1964 deletions(-) delete mode 100644 internal/registry/redis/lua/lua.go delete mode 100644 internal/registry/redis/lua/registry.lua delete mode 100644 internal/registry/redis/option.go delete mode 100644 internal/registry/redis/redis.go delete mode 100644 internal/registry/redis/redis_test.go delete mode 100644 pkg/storage/encrypted.go delete mode 100644 pkg/storage/encrypted_test.go delete mode 100644 pkg/storage/redis/observe.go delete mode 100644 pkg/storage/redis/option.go delete mode 100644 pkg/storage/redis/redis.go delete mode 100644 pkg/storage/redis/redis_test.go delete mode 100644 pkg/storage/redis/stream.go diff --git a/config/options.go b/config/options.go index d651c5b1f..ac64303cd 100644 --- a/config/options.go +++ b/config/options.go @@ -245,7 +245,7 @@ type Options struct { DataBrokerURLStrings []string `mapstructure:"databroker_service_urls" yaml:"databroker_service_urls,omitempty"` DataBrokerInternalURLString string `mapstructure:"databroker_internal_service_url" yaml:"databroker_internal_service_url,omitempty"` // DataBrokerStorageType is the storage backend type that databroker will use. - // Supported type: memory, redis + // Supported type: memory, postgres DataBrokerStorageType string `mapstructure:"databroker_storage_type" yaml:"databroker_storage_type,omitempty"` // DataBrokerStorageConnectionString is the data source name for storage backend. DataBrokerStorageConnectionString string `mapstructure:"databroker_storage_connection_string" yaml:"databroker_storage_connection_string,omitempty"` @@ -584,7 +584,9 @@ func (o *Options) Validate() error { switch o.DataBrokerStorageType { case StorageInMemoryName: - case StorageRedisName, StoragePostgresName: + case StorageRedisName: + return errors.New("config: redis databroker storage backend is no longer supported") + case StoragePostgresName: if o.DataBrokerStorageConnectionString == "" { return errors.New("config: missing databroker storage backend dsn") } diff --git a/config/options_test.go b/config/options_test.go index e770cdce6..bcd7021f3 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -58,8 +58,10 @@ func Test_Validate(t *testing.T) { badPolicyFile.PolicyFile = "file" invalidStorageType := testOptions() invalidStorageType.DataBrokerStorageType = "foo" + redisStorageType := testOptions() + redisStorageType.DataBrokerStorageType = "redis" missingStorageDSN := testOptions() - missingStorageDSN.DataBrokerStorageType = "redis" + missingStorageDSN.DataBrokerStorageType = "postgres" badSignoutRedirectURL := testOptions() badSignoutRedirectURL.SignOutRedirectURLString = "--" badCookieSettings := testOptions() @@ -77,6 +79,7 @@ func Test_Validate(t *testing.T) { {"missing shared secret but all service", badSecretAllServices, false}, {"policy file specified", badPolicyFile, true}, {"invalid databroker storage type", invalidStorageType, true}, + {"redis databroker storage type", redisStorageType, true}, {"missing databroker storage dsn", missingStorageDSN, true}, {"invalid signout redirect url", badSignoutRedirectURL, true}, {"CookieSameSite none with CookieSecure fale", badCookieSettings, true}, diff --git a/internal/databroker/registry.go b/internal/databroker/registry.go index e2d3d4997..d259f6714 100644 --- a/internal/databroker/registry.go +++ b/internal/databroker/registry.go @@ -9,7 +9,6 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/registry/inmemory" - "github.com/pomerium/pomerium/internal/registry/redis" "github.com/pomerium/pomerium/internal/telemetry/trace" registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" "github.com/pomerium/pomerium/pkg/storage" @@ -110,16 +109,6 @@ func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interfac case config.StorageInMemoryName: log.Info(ctx).Msg("using in-memory registry") return inmemory.New(ctx, srv.cfg.registryTTL), nil - case config.StorageRedisName: - log.Info(ctx).Msg("using redis registry") - r, err := redis.New( - srv.cfg.storageConnectionString, - redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)), - ) - if err != nil { - return nil, fmt.Errorf("failed to create new redis registry: %w", err) - } - return r, nil } return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType) diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 324e96cd1..98eda6b3d 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -3,7 +3,6 @@ package databroker import ( "context" - "crypto/tls" "errors" "fmt" "strings" @@ -19,12 +18,10 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/storage" "github.com/pomerium/pomerium/pkg/storage/inmemory" "github.com/pomerium/pomerium/pkg/storage/postgres" - "github.com/pomerium/pomerium/pkg/storage/redis" ) // Server implements the databroker service using an in memory database. @@ -426,39 +423,8 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) { case config.StoragePostgresName: log.Info(ctx).Msg("using postgres store") backend = postgres.New(srv.cfg.storageConnectionString) - case config.StorageRedisName: - log.Info(ctx).Msg("using redis store") - backend, err = redis.New( - srv.cfg.storageConnectionString, - redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)), - ) - if err != nil { - return nil, fmt.Errorf("failed to create new redis storage: %w", err) - } - if srv.cfg.secret != nil { - backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend) - if err != nil { - return nil, err - } - } default: return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType) } return backend, nil } - -func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config { - caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile) - if err != nil { - log.Warn(ctx).Err(err).Msg("failed to read databroker CA file") - } - tlsConfig := &tls.Config{ - RootCAs: caCertPool, - //nolint: gosec - InsecureSkipVerify: srv.cfg.storageCertSkipVerify, - } - if srv.cfg.storageCertificate != nil { - tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate} - } - return tlsConfig -} diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index c73b18c9e..19ea8c3ef 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -22,7 +22,6 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/pkg/cryptutil" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/protoutil" @@ -287,12 +286,11 @@ func TestServerInvalidStorage(t *testing.T) { _ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type") } -func TestServerRedis(t *testing.T) { - testutil.WithTestRedis(false, func(rawURL string) error { +func TestServerPostgres(t *testing.T) { + testutil.WithTestPostgres(func(dsn string) error { srv := newServer(&serverConfig{ - storageType: "redis", - storageConnectionString: rawURL, - secret: cryptutil.NewKey(), + storageType: "postgres", + storageConnectionString: dsn, }) s := new(session.Session) diff --git a/internal/registry/redis/lua/lua.go b/internal/registry/redis/lua/lua.go deleted file mode 100644 index 1ec7d5d40..000000000 --- a/internal/registry/redis/lua/lua.go +++ /dev/null @@ -1,20 +0,0 @@ -// Package lua contains lua source code. -package lua - -import ( - "embed" -) - -//go:embed registry.lua -var fs embed.FS - -// Registry is the registry lua script -var Registry string - -func init() { - bs, err := fs.ReadFile("registry.lua") - if err != nil { - panic(err) - } - Registry = string(bs) -} diff --git a/internal/registry/redis/lua/registry.lua b/internal/registry/redis/lua/registry.lua deleted file mode 100644 index a5527ae5b..000000000 --- a/internal/registry/redis/lua/registry.lua +++ /dev/null @@ -1,28 +0,0 @@ --- ARGV = [current time in seconds, ttl in seconds, services ...] -local current_time = ARGV[1] -local ttl = ARGV[2] -local changed = false - --- update the service list -for i = 3, #ARGV, 1 do - redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl) - changed = true -end - --- retrieve all the services, removing any that have expired -local svcs = {} -local kvs = redis.call('HGETALL', KEYS[1]) -for i = 1, #kvs, 2 do - if kvs[i + 1] < current_time then - redis.call('HDEL', KEYS[1], kvs[i]) - changed = true - else - table.insert(svcs, kvs[i]) - end -end - -if changed then - redis.call('PUBLISH', KEYS[2], current_time) -end - -return svcs diff --git a/internal/registry/redis/option.go b/internal/registry/redis/option.go deleted file mode 100644 index 6559a8165..000000000 --- a/internal/registry/redis/option.go +++ /dev/null @@ -1,48 +0,0 @@ -package redis - -import ( - "crypto/tls" - "time" -) - -const defaultTTL = time.Second * 30 - -type config struct { - tls *tls.Config - ttl time.Duration - getNow func() time.Time -} - -// An Option modifies the config.. -type Option func(*config) - -// WithGetNow sets the time.Now function in the config. -func WithGetNow(getNow func() time.Time) Option { - return func(cfg *config) { - cfg.getNow = getNow - } -} - -// WithTLSConfig sets the tls.Config in the config. -func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cfg *config) { - cfg.tls = tlsConfig - } -} - -// WithTTL sets the ttl in the config. -func WithTTL(ttl time.Duration) Option { - return func(cfg *config) { - cfg.ttl = ttl - } -} - -func getConfig(options ...Option) *config { - cfg := new(config) - WithGetNow(time.Now)(cfg) - WithTTL(defaultTTL)(cfg) - for _, o := range options { - o(cfg) - } - return cfg -} diff --git a/internal/registry/redis/redis.go b/internal/registry/redis/redis.go deleted file mode 100644 index 9fee74aa7..000000000 --- a/internal/registry/redis/redis.go +++ /dev/null @@ -1,250 +0,0 @@ -// Package redis implements a registry in redis. -package redis - -import ( - "context" - "fmt" - "sort" - "strings" - "sync" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/go-redis/redis/v8" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/durationpb" - - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/redisutil" - "github.com/pomerium/pomerium/internal/registry" - "github.com/pomerium/pomerium/internal/registry/redis/lua" - "github.com/pomerium/pomerium/internal/signal" - registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" -) - -const ( - registryKey = redisutil.KeyPrefix + "registry" - registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch" - - pollInterval = time.Second * 30 -) - -type impl struct { - cfg *config - - client redis.UniversalClient - onChange *signal.Signal - - closeOnce sync.Once - closed chan struct{} -} - -// New creates a new registry implementation backend by redis. -func New(rawURL string, options ...Option) (registry.Interface, error) { - cfg := getConfig(options...) - - client, err := redisutil.NewClientFromURL(rawURL, cfg.tls) - if err != nil { - return nil, err - } - - i := &impl{ - cfg: cfg, - client: client, - onChange: signal.New(), - closed: make(chan struct{}), - } - go i.listenForChanges(context.Background()) - return i, nil -} - -func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) { - _, err := i.runReport(ctx, req.GetServices()) - if err != nil { - return nil, err - } - return ®istrypb.RegisterResponse{ - CallBackAfter: durationpb.New(i.cfg.ttl / 2), - }, nil -} - -func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) { - all, err := i.runReport(ctx, nil) - if err != nil { - return nil, err - } - - include := map[registrypb.ServiceKind]struct{}{} - for _, kind := range req.GetKinds() { - include[kind] = struct{}{} - } - - filtered := make([]*registrypb.Service, 0, len(all)) - for _, svc := range all { - if _, ok := include[svc.GetKind()]; !ok { - continue - } - filtered = append(filtered, svc) - } - - sort.Slice(filtered, func(i, j int) bool { - { - iv, jv := filtered[i].GetKind(), filtered[j].GetKind() - switch { - case iv < jv: - return true - case jv < iv: - return false - } - } - - { - iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint() - switch { - case iv < jv: - return true - case jv < iv: - return false - } - } - - return false - }) - - return ®istrypb.ServiceList{ - Services: filtered, - }, nil -} - -func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error { - // listen for changes - ch := i.onChange.Bind() - defer i.onChange.Unbind(ch) - - // force a check periodically - poll := time.NewTicker(pollInterval) - defer poll.Stop() - - var prev *registrypb.ServiceList - for { - // retrieve the most recent list of services - lst, err := i.List(stream.Context(), req) - if err != nil { - return err - } - - // only send a new list if something changed - if !proto.Equal(prev, lst) { - err = stream.Send(lst) - if err != nil { - return err - } - } - prev = lst - - // wait for an update - select { - case <-i.closed: - return nil - case <-stream.Context().Done(): - return stream.Context().Err() - case <-ch: - case <-poll.C: - } - } -} - -func (i *impl) Close() error { - var err error - i.closeOnce.Do(func() { - err = i.client.Close() - close(i.closed) - }) - return err -} - -func (i *impl) listenForChanges(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) - go func() { - <-i.closed - cancel() - }() - - bo := backoff.NewExponentialBackOff() - bo.MaxElapsedTime = 0 - -outer: - for { - pubsub := i.client.Subscribe(ctx, registryUpdateKey) - for { - msg, err := pubsub.Receive(ctx) - if err != nil { - _ = pubsub.Close() - select { - case <-ctx.Done(): - return - case <-time.After(bo.NextBackOff()): - } - continue outer - } - bo.Reset() - - switch msg.(type) { - case *redis.Message: - i.onChange.Broadcast(ctx) - } - } - } -} - -func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) { - args := []interface{}{ - i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time - i.cfg.ttl.Milliseconds(), // ttl - } - for _, svc := range updates { - args = append(args, i.getRegistryHashKey(svc)) - } - res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result() - if err != nil { - return nil, err - } - if values, ok := res.([]interface{}); ok { - var all []*registrypb.Service - for _, value := range values { - svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value)) - if err != nil { - log.Warn(ctx).Err(err).Msg("redis: invalid service") - continue - } - all = append(all, svc) - } - return all, nil - } - return nil, nil -} - -func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) { - idx := strings.Index(key, "|") - if idx == -1 { - return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key) - } - - svcKindStr := key[:idx] - svcEndpointStr := key[idx+1:] - - svcKind, ok := registrypb.ServiceKind_value[svcKindStr] - if !ok { - return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr) - } - - svc := ®istrypb.Service{ - Kind: registrypb.ServiceKind(svcKind), - Endpoint: svcEndpointStr, - } - return svc, nil -} - -func (i *impl) getRegistryHashKey(svc *registrypb.Service) string { - return svc.GetKind().String() + "|" + svc.GetEndpoint() -} diff --git a/internal/registry/redis/redis_test.go b/internal/registry/redis/redis_test.go deleted file mode 100644 index 4092278ab..000000000 --- a/internal/registry/redis/redis_test.go +++ /dev/null @@ -1,196 +0,0 @@ -package redis - -import ( - "context" - "net" - "os" - "runtime" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" - "google.golang.org/grpc" - - "github.com/pomerium/pomerium/internal/testutil" - registrypb "github.com/pomerium/pomerium/pkg/grpc/registry" -) - -func TestReport(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - ctx := context.Background() - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - tm := time.Now() - - i, err := New(rawURL, - WithGetNow(func() time.Time { - return tm - }), - WithTTL(time.Second*10)) - require.NoError(t, err) - defer func() { _ = i.Close() }() - - _, err = i.Report(ctx, ®istrypb.RegisterRequest{ - Services: []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"}, - {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"}, - {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, - }, - }) - require.NoError(t, err) - - // move forward 5 seconds - tm = tm.Add(time.Second * 5) - _, err = i.Report(ctx, ®istrypb.RegisterRequest{ - Services: []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"}, - {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, - }, - }) - require.NoError(t, err) - - lst, err := i.List(ctx, ®istrypb.ListRequest{ - Kinds: []registrypb.ServiceKind{ - registrypb.ServiceKind_AUTHORIZE, - registrypb.ServiceKind_PROXY, - }, - }) - require.NoError(t, err) - assert.Equal(t, []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"}, - {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, - }, lst.GetServices(), "should list selected services") - - // move forward 6 seconds - tm = tm.Add(time.Second * 6) - lst, err = i.List(ctx, ®istrypb.ListRequest{ - Kinds: []registrypb.ServiceKind{ - registrypb.ServiceKind_AUTHORIZE, - registrypb.ServiceKind_PROXY, - }, - }) - require.NoError(t, err) - assert.Equal(t, []*registrypb.Service{ - {Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"}, - }, lst.GetServices(), "should expire old services") - - return nil - })) -} - -func TestWatch(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15) - defer clearTimeout() - - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - tm := time.Now() - i, err := New(rawURL, - WithGetNow(func() time.Time { - return tm - }), - WithTTL(time.Second*10)) - require.NoError(t, err) - - li, err := net.Listen("tcp", "127.0.0.1:0") - require.NoError(t, err) - defer li.Close() - - srv := grpc.NewServer() - registrypb.RegisterRegistryServer(srv, i) - eg, ctx := errgroup.WithContext(ctx) - eg.Go(func() error { - <-ctx.Done() - srv.Stop() - return nil - }) - eg.Go(func() error { - return srv.Serve(li) - }) - eg.Go(func() error { - defer cancel() - - cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure()) - if err != nil { - return err - } - - client := registrypb.NewRegistryClient(cc) - - // store the initial services - _, err = client.Report(ctx, ®istrypb.RegisterRequest{ - Services: []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"}, - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, - }, - }) - if err != nil { - return err - } - - stream, err := client.Watch(ctx, ®istrypb.ListRequest{ - Kinds: []registrypb.ServiceKind{ - registrypb.ServiceKind_AUTHORIZE, - }, - }) - if err != nil { - return err - } - defer func() { _ = stream.CloseSend() }() - - lst, err := stream.Recv() - if err != nil { - return err - } - assert.Equal(t, []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, - }, lst.GetServices()) - - // update authenticate - _, err = client.Report(ctx, ®istrypb.RegisterRequest{ - Services: []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"}, - }, - }) - if err != nil { - return err - } - - // add an authorize - _, err = client.Report(ctx, ®istrypb.RegisterRequest{ - Services: []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"}, - }, - }) - if err != nil { - return err - } - - lst, err = stream.Recv() - if err != nil { - return err - } - assert.Equal(t, []*registrypb.Service{ - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"}, - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"}, - {Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"}, - }, lst.GetServices()) - - return nil - }) - require.NoError(t, eg.Wait()) - return nil - })) -} diff --git a/pkg/storage/encrypted.go b/pkg/storage/encrypted.go deleted file mode 100644 index 6ab9e8235..000000000 --- a/pkg/storage/encrypted.go +++ /dev/null @@ -1,204 +0,0 @@ -package storage - -import ( - "context" - "crypto/cipher" - "time" - - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/protoutil" -) - -type encryptedRecordStream struct { - underlying RecordStream - backend *encryptedBackend - err error -} - -func (e *encryptedRecordStream) Close() error { - return e.underlying.Close() -} - -func (e *encryptedRecordStream) Next(wait bool) bool { - return e.underlying.Next(wait) -} - -func (e *encryptedRecordStream) Record() *databroker.Record { - r := e.underlying.Record() - if r != nil { - var err error - r, err = e.backend.decryptRecord(r) - if err != nil { - e.err = err - } - } - return r -} - -func (e *encryptedRecordStream) Err() error { - if e.err == nil { - e.err = e.underlying.Err() - } - return e.err -} - -type encryptedBackend struct { - underlying Backend - cipher cipher.AEAD -} - -// NewEncryptedBackend creates a new encrypted backend. -func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) { - c, err := cryptutil.NewAEADCipher(secret) - if err != nil { - return nil, err - } - - return &encryptedBackend{ - underlying: underlying, - cipher: c, - }, nil -} - -func (e *encryptedBackend) Close() error { - return e.underlying.Close() -} - -func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) { - record, err := e.underlying.Get(ctx, recordType, id) - if err != nil { - return nil, err - } - record, err = e.decryptRecord(record) - if err != nil { - return nil, err - } - return record, nil -} - -func (e *encryptedBackend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) { - return e.underlying.GetOptions(ctx, recordType) -} - -func (e *encryptedBackend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) { - return e.underlying.Lease(ctx, leaseName, leaseID, ttl) -} - -func (e *encryptedBackend) ListTypes(ctx context.Context) ([]string, error) { - return e.underlying.ListTypes(ctx) -} - -func (e *encryptedBackend) Put(ctx context.Context, records []*databroker.Record) (uint64, error) { - encryptedRecords := make([]*databroker.Record, len(records)) - for i, record := range records { - encrypted, err := e.encrypt(record.GetData()) - if err != nil { - return 0, err - } - - newRecord := proto.Clone(record).(*databroker.Record) - newRecord.Data = encrypted - encryptedRecords[i] = newRecord - } - - serverVersion, err := e.underlying.Put(ctx, encryptedRecords) - if err != nil { - return 0, err - } - - for i, record := range records { - record.ModifiedAt = encryptedRecords[i].ModifiedAt - record.Version = encryptedRecords[i].Version - } - - return serverVersion, nil -} - -func (e *encryptedBackend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error { - return e.underlying.SetOptions(ctx, recordType, options) -} - -func (e *encryptedBackend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (RecordStream, error) { - stream, err := e.underlying.Sync(ctx, recordType, serverVersion, recordVersion) - if err != nil { - return nil, err - } - return &encryptedRecordStream{ - underlying: stream, - backend: e, - }, nil -} - -func (e *encryptedBackend) SyncLatest( - ctx context.Context, - recordType string, - filter FilterExpression, -) (serverVersion, recordVersion uint64, stream RecordStream, err error) { - serverVersion, recordVersion, stream, err = e.underlying.SyncLatest(ctx, recordType, filter) - if err != nil { - return serverVersion, recordVersion, nil, err - } - return serverVersion, recordVersion, &encryptedRecordStream{ - underlying: stream, - backend: e, - }, nil -} - -func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) { - data, err := e.decrypt(in.Data) - if err != nil { - return nil, err - } - // Create a new record so that we don't re-use any internal state - return &databroker.Record{ - Version: in.Version, - Type: in.Type, - Id: in.Id, - Data: data, - ModifiedAt: in.ModifiedAt, - DeletedAt: in.DeletedAt, - }, nil -} - -func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) { - if in == nil { - return nil, nil - } - - var encrypted wrapperspb.BytesValue - err = in.UnmarshalTo(&encrypted) - if err != nil { - return nil, err - } - - plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil) - if err != nil { - return nil, err - } - - out = new(anypb.Any) - err = proto.Unmarshal(plaintext, out) - if err != nil { - return nil, err - } - - return out, nil -} - -func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) { - plaintext, err := proto.Marshal(in) - if err != nil { - return nil, err - } - - encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil) - out = protoutil.NewAny(&wrapperspb.BytesValue{ - Value: encrypted, - }) - return out, nil -} diff --git a/pkg/storage/encrypted_test.go b/pkg/storage/encrypted_test.go deleted file mode 100644 index 6fb285a25..000000000 --- a/pkg/storage/encrypted_test.go +++ /dev/null @@ -1,75 +0,0 @@ -package storage - -import ( - "context" - "errors" - "testing" - - "github.com/stretchr/testify/assert" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/timestamppb" - "google.golang.org/protobuf/types/known/wrapperspb" - - "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/protoutil" -) - -func TestEncryptedBackend(t *testing.T) { - ctx := context.Background() - - m := map[string]*anypb.Any{} - backend := &mockBackend{ - put: func(ctx context.Context, records []*databroker.Record) (uint64, error) { - for _, record := range records { - record.ModifiedAt = timestamppb.Now() - record.Version++ - m[record.GetId()] = record.GetData() - } - return 0, nil - }, - get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) { - data, ok := m[id] - if !ok { - return nil, errors.New("not found") - } - return &databroker.Record{ - Id: id, - Data: data, - Version: 1, - ModifiedAt: timestamppb.Now(), - }, nil - }, - } - - e, err := NewEncryptedBackend(cryptutil.NewKey(), backend) - if !assert.NoError(t, err) { - return - } - - data := protoutil.NewAny(wrapperspb.String("HELLO WORLD")) - - rec := &databroker.Record{ - Type: "", - Id: "TEST-1", - Data: data, - } - _, err = e.Put(ctx, []*databroker.Record{rec}) - if !assert.NoError(t, err) { - return - } - if assert.NotNil(t, m["TEST-1"], "key should be set") { - assert.NotEqual(t, data.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type") - assert.NotEqual(t, data.Value, m["TEST-1"].Value, "value should be encrypted") - assert.NotNil(t, rec.ModifiedAt) - assert.NotZero(t, rec.Version) - } - - record, err := e.Get(ctx, "", "TEST-1") - if !assert.NoError(t, err) { - return - } - assert.Equal(t, data.TypeUrl, record.Data.TypeUrl, "type should be preserved") - assert.Equal(t, data.Value, record.Data.Value, "value should be preserved") - assert.NotEqual(t, data.TypeUrl, record.Type, "record type should be preserved") -} diff --git a/pkg/storage/redis/observe.go b/pkg/storage/redis/observe.go deleted file mode 100644 index ad1495c89..000000000 --- a/pkg/storage/redis/observe.go +++ /dev/null @@ -1,31 +0,0 @@ -package redis - -import ( - "context" - "time" - - "github.com/go-redis/redis/v8" - - pomeriumconfig "github.com/pomerium/pomerium/config" - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/telemetry/metrics" -) - -type logger struct { -} - -func (l logger) Printf(ctx context.Context, format string, v ...interface{}) { - log.Info(ctx).Str("service", "redis").Msgf(format, v...) -} - -func init() { - redis.SetLogger(logger{}) -} - -func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) { - metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{ - Operation: operation, - Error: err, - Backend: pomeriumconfig.StorageRedisName, - }, time.Since(startTime)) -} diff --git a/pkg/storage/redis/option.go b/pkg/storage/redis/option.go deleted file mode 100644 index 5197bc072..000000000 --- a/pkg/storage/redis/option.go +++ /dev/null @@ -1,37 +0,0 @@ -package redis - -import ( - "crypto/tls" - "time" -) - -type config struct { - tls *tls.Config - expiry time.Duration -} - -// Option customizes a Backend. -type Option func(*config) - -// WithTLSConfig sets the tls.Config which Backend uses. -func WithTLSConfig(tlsConfig *tls.Config) Option { - return func(cfg *config) { - cfg.tls = tlsConfig - } -} - -// WithExpiry sets the expiry for changes. -func WithExpiry(expiry time.Duration) Option { - return func(cfg *config) { - cfg.expiry = expiry - } -} - -func getConfig(options ...Option) *config { - cfg := new(config) - WithExpiry(time.Hour * 24)(cfg) - for _, o := range options { - o(cfg) - } - return cfg -} diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go deleted file mode 100644 index 2a4554b67..000000000 --- a/pkg/storage/redis/redis.go +++ /dev/null @@ -1,537 +0,0 @@ -// Package redis implements the storage.Backend interface for redis. -package redis - -import ( - "context" - "errors" - "fmt" - "sort" - "sync" - "time" - - "github.com/cenkalti/backoff/v4" - "github.com/go-redis/redis/v8" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/internal/redisutil" - "github.com/pomerium/pomerium/internal/signal" - "github.com/pomerium/pomerium/internal/telemetry/metrics" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/storage" -) - -const ( - maxTransactionRetries = 100 - watchPollInterval = 30 * time.Second - - // we rely on transactions in redis, so all redis-cluster keys need to be - // on the same node. Using a `hash tag` gives us this capability. - serverVersionKey = redisutil.KeyPrefix + "server_version" - lastVersionKey = redisutil.KeyPrefix + "last_version" - lastVersionChKey = redisutil.KeyPrefix + "last_version_ch" - recordHashKey = redisutil.KeyPrefix + "records" - recordTypesSetKey = redisutil.KeyPrefix + "record_types" - changesSetKey = redisutil.KeyPrefix + "changes" - optionsKey = redisutil.KeyPrefix + "options" - - recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s" - leaseKeyTpl = redisutil.KeyPrefix + "lease.%s" -) - -// custom errors -var ( - ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries") -) - -// Backend implements the storage.Backend on top of redis. -// -// What's stored: -// -// - last_version: an integer recordVersion number -// - last_version_ch: a PubSub channel for recordVersion number updates -// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record. -// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record. -// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options. -// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time, -// the value the record id. -// -// Records stored in these keys are typically encrypted. -type Backend struct { - cfg *config - - client redis.UniversalClient - onChange *signal.Signal - - closeOnce sync.Once - closed chan struct{} -} - -// New creates a new redis storage backend. -func New(rawURL string, options ...Option) (*Backend, error) { - ctx := context.TODO() - cfg := getConfig(options...) - backend := &Backend{ - cfg: cfg, - closed: make(chan struct{}), - onChange: signal.New(), - } - var err error - backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls) - if err != nil { - return nil, err - } - metrics.AddRedisMetrics(backend.client.PoolStats) - go backend.listenForVersionChanges(ctx) - if cfg.expiry != 0 { - go func() { - ticker := time.NewTicker(time.Minute) - defer ticker.Stop() - for { - select { - case <-backend.closed: - return - case <-ticker.C: - } - - backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry)) - } - }() - } - return backend, nil -} - -// Close closes the underlying redis connection and any watchers. -func (backend *Backend) Close() error { - var err error - backend.closeOnce.Do(func() { - err = backend.client.Close() - close(backend.closed) - }) - return err -} - -// Get gets a record from redis. -func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) { - _, span := trace.StartSpan(ctx, "databroker.redis.Get") - defer span.End() - defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now()) - - key, field := getHashKey(recordType, id) - cmd := backend.client.HGet(ctx, key, field) - raw, err := cmd.Result() - if errors.Is(err, redis.Nil) { - return nil, storage.ErrNotFound - } else if err != nil { - return nil, err - } - - var record databroker.Record - err = proto.Unmarshal([]byte(raw), &record) - if err != nil { - return nil, err - } - - return &record, nil -} - -// GetOptions gets the options for the given record type. -func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) { - raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result() - if errors.Is(err, redis.Nil) { - // treat no options as an empty set of options - return new(databroker.Options), nil - } else if err != nil { - return nil, err - } - - var options databroker.Options - err = proto.Unmarshal([]byte(raw), &options) - if err != nil { - return nil, err - } - - return &options, nil -} - -// Lease acquires or renews a lease. -func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) { - acquired := false - key := getLeaseKey(leaseName) - err := backend.client.Watch(ctx, func(tx *redis.Tx) error { - currentID, err := tx.Get(ctx, key).Result() - if errors.Is(err, redis.Nil) { - // lease hasn't been set yet - } else if err != nil { - return err - } else if leaseID != currentID { - // lease has already been taken - return nil - } - - _, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error { - if ttl <= 0 { - p.Del(ctx, key) - } else { - p.Set(ctx, key, leaseID, ttl) - } - return nil - }) - if err != nil { - return err - } - acquired = ttl > 0 - return nil - }, key) - // if the transaction failed someone else must've acquired the lease - if errors.Is(err, redis.TxFailedErr) { - acquired = false - err = nil - } - return acquired, err -} - -// ListTypes lists all the known record types. -func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) { - ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes") - defer span.End() - defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now()) - - cmd := backend.client.SMembers(ctx, recordTypesSetKey) - types, err = cmd.Result() - if err != nil { - return nil, err - } - sort.Strings(types) - return types, nil -} - -// Put puts a record into redis. -func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) { - ctx, span := trace.StartSpan(ctx, "databroker.redis.Put") - defer span.End() - defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now()) - - serverVersion, err = backend.getOrCreateServerVersion(ctx) - if err != nil { - return serverVersion, err - } - - err = backend.put(ctx, records) - if err != nil { - return serverVersion, err - } - - recordTypes := map[string]struct{}{} - for _, record := range records { - recordTypes[record.GetType()] = struct{}{} - } - for recordType := range recordTypes { - err = backend.enforceOptions(ctx, recordType) - if err != nil { - return serverVersion, err - } - } - - return serverVersion, nil -} - -// SetOptions sets the options for the given record type. -func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error { - ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions") - defer span.End() - - bs, err := proto.Marshal(options) - if err != nil { - return err - } - - // update the options in the hash set - err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err() - if err != nil { - return err - } - - // possibly re-enforce options - err = backend.enforceOptions(ctx, recordType) - if err != nil { - return err - } - - return nil -} - -// Sync returns a record stream of any records changed after the specified recordVersion. -func (backend *Backend) Sync( - ctx context.Context, - recordType string, - serverVersion, recordVersion uint64, -) (storage.RecordStream, error) { - return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil -} - -// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the -// stream is streaming. -func (backend *Backend) SyncLatest( - ctx context.Context, - recordType string, - expr storage.FilterExpression, -) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) { - serverVersion, err = backend.getOrCreateServerVersion(ctx) - if err != nil { - return serverVersion, recordVersion, nil, err - } - - recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64() - if errors.Is(err, redis.Nil) { - // this happens if there are no records - } else if err != nil { - return serverVersion, recordVersion, nil, err - } - - stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr) - return serverVersion, recordVersion, stream, err -} - -func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error { - return backend.incrementVersion(ctx, - func(tx *redis.Tx, version uint64) error { - for i, record := range records { - record.ModifiedAt = timestamppb.Now() - record.Version = version + uint64(i) - } - return nil - }, - func(p redis.Pipeliner, version uint64) error { - for i, record := range records { - bs, err := proto.Marshal(record) - if err != nil { - return err - } - - key, field := getHashKey(record.GetType(), record.GetId()) - if record.DeletedAt != nil { - p.HDel(ctx, key, field) - } else { - p.HSet(ctx, key, field, bs) - p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{ - Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)), - Member: record.GetId(), - }) - } - p.ZAdd(ctx, changesSetKey, &redis.Z{ - Score: float64(version) + float64(i), - Member: bs, - }) - p.SAdd(ctx, recordTypesSetKey, record.GetType()) - } - return nil - }) -} - -// enforceOptions enforces the options for the given record type. -func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error { - ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions") - defer span.End() - - options, err := backend.GetOptions(ctx, recordType) - if err != nil { - return err - } - - // nothing to do if capacity isn't set - if options.Capacity == nil { - return nil - } - - key := getRecordTypeChangesKey(recordType) - - // find oldest records that exceed the capacity - recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result() - if err != nil { - return err - } - - // for each record, delete it - for _, recordID := range recordIDs { - record, err := backend.Get(ctx, recordType, recordID) - if err == nil { - // mark the record as deleted and re-submit - record.DeletedAt = timestamppb.Now() - err = backend.put(ctx, []*databroker.Record{record}) - if err != nil { - return err - } - } else if errors.Is(err, storage.ErrNotFound) { - // ignore - } else if err != nil { - return err - } - - // remove the member from the collection - _, err = backend.client.ZRem(ctx, key, recordID).Result() - if err != nil { - return err - } - } - - return nil -} - -// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in -// `commit`. If the last recordVersion changes in the interim, we will retry the transaction. -func (backend *Backend) incrementVersion(ctx context.Context, - query func(tx *redis.Tx, recordVersion uint64) error, - commit func(p redis.Pipeliner, recordVersion uint64) error, -) error { - // code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch - txf := func(tx *redis.Tx) error { - version, err := tx.Get(ctx, lastVersionKey).Uint64() - if errors.Is(err, redis.Nil) { - version = 0 - } else if err != nil { - return err - } - version++ - - err = query(tx, version) - if err != nil { - return err - } - - // the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch - _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { - err := commit(p, version) - if err != nil { - return err - } - p.Set(ctx, lastVersionKey, version, 0) - p.Publish(ctx, lastVersionChKey, version) - return nil - }) - return err - } - - bo := backoff.NewExponentialBackOff() - bo.MaxElapsedTime = 0 - for i := 0; i < maxTransactionRetries; i++ { - err := backend.client.Watch(ctx, txf, lastVersionKey) - if errors.Is(err, redis.TxFailedErr) { - select { - case <-ctx.Done(): - return ctx.Err() - case <-time.After(bo.NextBackOff()): - } - continue // retry - } else if err != nil { - return err - } - - return nil // tx was successful - } - - return ErrExceededMaxRetries -} - -func (backend *Backend) listenForVersionChanges(ctx context.Context) { - ctx, cancel := context.WithCancel(ctx) - go func() { - <-backend.closed - cancel() - }() - - bo := backoff.NewExponentialBackOff() - bo.MaxElapsedTime = 0 - -outer: - for { - pubsub := backend.client.Subscribe(ctx, lastVersionChKey) - for { - msg, err := pubsub.Receive(ctx) - if err != nil { - _ = pubsub.Close() - select { - case <-ctx.Done(): - return - case <-time.After(bo.NextBackOff()): - } - continue outer - } - bo.Reset() - - switch msg.(type) { - case *redis.Message: - backend.onChange.Broadcast(ctx) - } - } - } -} - -func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) { - for { - cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{ - Min: "-inf", - Max: "+inf", - Offset: 0, - Count: 1, - }) - results, err := cmd.Result() - if err != nil { - log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration") - return - } - - // nothing left to do - if len(results) == 0 { - return - } - - var record databroker.Record - err = proto.Unmarshal([]byte(results[0]), &record) - if err != nil { - log.Warn(ctx).Err(err).Msg("redis: invalid record detected") - record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it - } - - // if the record's modified timestamp is after the cutoff, we're all done, so break - if record.GetModifiedAt().AsTime().After(cutoff) { - break - } - - // remove the record - err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err() - if err != nil { - log.Error(ctx).Err(err).Msg("redis: error removing member") - return - } - } -} - -func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) { - serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64() - // if the server version hasn't been set yet, set it to a random value and immediately retrieve it - // this should properly handle a data race by only setting the key if it doesn't already exist - if errors.Is(err, redis.Nil) { - _, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result() - serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64() - } - if err != nil { - return 0, fmt.Errorf("redis: error retrieving server version: %w", err) - } - return serverVersion, err -} - -func getLeaseKey(leaseName string) string { - return fmt.Sprintf(leaseKeyTpl, leaseName) -} - -func getRecordTypeChangesKey(recordType string) string { - return fmt.Sprintf(recordTypeChangesKeyTpl, recordType) -} - -func getHashKey(recordType, id string) (key, field string) { - return recordHashKey, fmt.Sprintf("%s/%s", recordType, id) -} diff --git a/pkg/storage/redis/redis_test.go b/pkg/storage/redis/redis_test.go deleted file mode 100644 index fb7ca2d36..000000000 --- a/pkg/storage/redis/redis_test.go +++ /dev/null @@ -1,312 +0,0 @@ -package redis - -import ( - "context" - "fmt" - "os" - "runtime" - "testing" - "time" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" - "google.golang.org/protobuf/proto" - "google.golang.org/protobuf/types/known/anypb" - "google.golang.org/protobuf/types/known/timestamppb" - - "github.com/pomerium/pomerium/internal/testutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/storage" -) - -func TestBackend(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - handler := func(t *testing.T, useTLS bool, rawURL string) error { - ctx := context.Background() - var opts []Option - if useTLS { - opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig())) - } - backend, err := New(rawURL, opts...) - require.NoError(t, err) - defer func() { _ = backend.Close() }() - - serverVersion, err := backend.getOrCreateServerVersion(ctx) - require.NoError(t, err) - - t.Run("get missing record", func(t *testing.T) { - record, err := backend.Get(ctx, "TYPE", "abcd") - require.Error(t, err) - assert.Nil(t, record) - }) - t.Run("get record", func(t *testing.T) { - data := new(anypb.Any) - sv, err := backend.Put(ctx, []*databroker.Record{{ - Type: "TYPE", - Id: "abcd", - Data: data, - }}) - assert.NoError(t, err) - assert.Equal(t, serverVersion, sv) - record, err := backend.Get(ctx, "TYPE", "abcd") - require.NoError(t, err) - if assert.NotNil(t, record) { - assert.Equal(t, data, record.Data) - assert.Nil(t, record.DeletedAt) - assert.Equal(t, "abcd", record.Id) - assert.NotNil(t, record.ModifiedAt) - assert.Equal(t, "TYPE", record.Type) - assert.Equal(t, uint64(1), record.Version) - } - }) - t.Run("delete record", func(t *testing.T) { - sv, err := backend.Put(ctx, []*databroker.Record{{ - Type: "TYPE", - Id: "abcd", - DeletedAt: timestamppb.Now(), - }}) - assert.NoError(t, err) - assert.Equal(t, serverVersion, sv) - record, err := backend.Get(ctx, "TYPE", "abcd") - assert.Error(t, err) - assert.Nil(t, record) - }) - t.Run("list types", func(t *testing.T) { - types, err := backend.ListTypes(ctx) - assert.NoError(t, err) - assert.Equal(t, []string{"TYPE"}, types) - }) - return nil - } - - t.Run("no-tls", func(t *testing.T) { - t.Parallel() - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - return handler(t, false, rawURL) - })) - }) - - t.Run("tls", func(t *testing.T) { - t.Parallel() - require.NoError(t, testutil.WithTestRedis(true, func(rawURL string) error { - return handler(t, true, rawURL) - })) - }) - - if runtime.GOOS == "linux" { - t.Run("cluster", func(t *testing.T) { - t.Parallel() - require.NoError(t, testutil.WithTestRedisCluster(func(rawURL string) error { - return handler(t, false, rawURL) - })) - }) - - t.Run("sentinel", func(t *testing.T) { - t.Parallel() - require.NoError(t, testutil.WithTestRedisSentinel(func(rawURL string) error { - return handler(t, false, rawURL) - })) - }) - } -} - -func TestChangeSignal(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - t.Parallel() - - ctx := context.Background() - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30) - defer clearTimeout() - - done := make(chan struct{}) - var eg errgroup.Group - eg.Go(func() error { - backend, err := New(rawURL) - if err != nil { - return err - } - defer func() { _ = backend.Close() }() - - ch := backend.onChange.Bind() - defer backend.onChange.Unbind(ch) - - select { - case <-ch: - case <-ctx.Done(): - return ctx.Err() - } - - // signal the second backend that we've received the change - close(done) - - return nil - }) - eg.Go(func() error { - backend, err := New(rawURL) - if err != nil { - return err - } - defer func() { _ = backend.Close() }() - - // put a new value to trigger a change - for { - _, err = backend.Put(ctx, []*databroker.Record{{ - Type: "TYPE", - Id: "ID", - }}) - if err != nil { - return err - } - - select { - case <-ctx.Done(): - return ctx.Err() - case <-done: - return nil - case <-time.After(time.Millisecond * 100): - } - } - }) - assert.NoError(t, eg.Wait(), "expected signal to be fired when another backend triggers a change") - return nil - })) -} - -func TestExpiry(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - t.Parallel() - - ctx := context.Background() - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - backend, err := New(rawURL, WithExpiry(0)) - require.NoError(t, err) - defer func() { _ = backend.Close() }() - - serverVersion, err := backend.getOrCreateServerVersion(ctx) - require.NoError(t, err) - - for i := 0; i < 1000; i++ { - _, err := backend.Put(ctx, []*databroker.Record{{ - Type: "TYPE", - Id: fmt.Sprint(i), - }}) - assert.NoError(t, err) - } - stream, err := backend.Sync(ctx, "TYPE", serverVersion, 0) - require.NoError(t, err) - var records []*databroker.Record - for stream.Next(false) { - records = append(records, stream.Record()) - } - _ = stream.Close() - require.Len(t, records, 1000) - - backend.removeChangesBefore(ctx, time.Now().Add(time.Second)) - - stream, err = backend.Sync(ctx, "TYPE", serverVersion, 0) - require.NoError(t, err) - records = nil - for stream.Next(false) { - records = append(records, stream.Record()) - } - _ = stream.Close() - require.Len(t, records, 0) - - return nil - })) -} - -func TestCapacity(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - t.Parallel() - - ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) - defer clearTimeout() - - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - backend, err := New(rawURL, WithExpiry(0)) - require.NoError(t, err) - defer func() { _ = backend.Close() }() - - err = backend.SetOptions(ctx, "EXAMPLE", &databroker.Options{ - Capacity: proto.Uint64(3), - }) - require.NoError(t, err) - - for i := 0; i < 10; i++ { - _, err = backend.Put(ctx, []*databroker.Record{{ - Type: "EXAMPLE", - Id: fmt.Sprint(i), - }}) - require.NoError(t, err) - } - - _, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil) - require.NoError(t, err) - defer stream.Close() - - records, err := storage.RecordStreamToList(stream) - require.NoError(t, err) - assert.Len(t, records, 3) - - var ids []string - for _, r := range records { - ids = append(ids, r.GetId()) - } - assert.Equal(t, []string{"7", "8", "9"}, ids, "should contain recent records") - - return nil - })) -} - -func TestLease(t *testing.T) { - if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" { - t.Skip("Github action can not run docker on MacOS") - } - - t.Parallel() - - ctx := context.Background() - require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error { - backend, err := New(rawURL) - require.NoError(t, err) - defer func() { _ = backend.Close() }() - - { - ok, err := backend.Lease(ctx, "test", "a", time.Second*30) - require.NoError(t, err) - assert.True(t, ok, "expected a to acquire the lease") - } - { - ok, err := backend.Lease(ctx, "test", "b", time.Second*30) - require.NoError(t, err) - assert.False(t, ok, "expected b to fail to acquire the lease") - } - { - ok, err := backend.Lease(ctx, "test", "a", 0) - require.NoError(t, err) - assert.False(t, ok, "expected a to clear the lease") - } - { - ok, err := backend.Lease(ctx, "test", "b", time.Second*30) - require.NoError(t, err) - assert.True(t, ok, "expected b to to acquire the lease") - } - - return nil - })) -} diff --git a/pkg/storage/redis/stream.go b/pkg/storage/redis/stream.go deleted file mode 100644 index c5a79d95c..000000000 --- a/pkg/storage/redis/stream.go +++ /dev/null @@ -1,172 +0,0 @@ -package redis - -import ( - "context" - "errors" - "fmt" - "time" - - "github.com/go-redis/redis/v8" - "google.golang.org/protobuf/proto" - - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/pkg/grpc/databroker" - "github.com/pomerium/pomerium/pkg/storage" -) - -func newSyncRecordStream( - ctx context.Context, - backend *Backend, - recordType string, - serverVersion uint64, - recordVersion uint64, -) storage.RecordStream { - changed := backend.onChange.Bind() - return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ - // 1. stream all record changes - func(ctx context.Context, block bool) (*databroker.Record, error) { - ticker := time.NewTicker(watchPollInterval) - defer ticker.Stop() - - for { - currentServerVersion, err := backend.getOrCreateServerVersion(ctx) - if err != nil { - return nil, err - } - if serverVersion != currentServerVersion { - return nil, storage.ErrInvalidServerVersion - } - - record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion) - if err == nil { - return record, nil - } else if !errors.Is(err, storage.ErrStreamDone) { - return nil, err - } - - if !block { - return nil, storage.ErrStreamDone - } - - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-ticker.C: - case <-changed: - } - } - }, - }, func() { - backend.onChange.Unbind(changed) - }) -} - -func newSyncLatestRecordStream( - ctx context.Context, - backend *Backend, - recordType string, - expr storage.FilterExpression, -) (storage.RecordStream, error) { - filter, err := storage.RecordStreamFilterFromFilterExpression(expr) - if err != nil { - return nil, err - } - if recordType != "" { - filter = filter.And(func(record *databroker.Record) (keep bool) { - return record.GetType() == recordType - }) - } - - var cursor uint64 - scannedOnce := false - var scannedRecords []*databroker.Record - generator := storage.FilteredRecordStreamGenerator( - func(ctx context.Context, block bool) (*databroker.Record, error) { - for { - if len(scannedRecords) > 0 { - record := scannedRecords[0] - scannedRecords = scannedRecords[1:] - return record, nil - } - - // the cursor is reset to 0 after iteration is complete - if scannedOnce && cursor == 0 { - return nil, storage.ErrStreamDone - } - - var err error - scannedRecords, err = nextScannedRecords(ctx, backend, &cursor) - if err != nil { - return nil, err - } - - scannedOnce = true - } - }, - filter, - ) - - return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{ - generator, - }, nil), nil -} - -func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) { - var values []string - var err error - values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result() - if errors.Is(err, redis.Nil) { - return nil, storage.ErrStreamDone - } else if err != nil { - return nil, err - } else if len(values) == 0 { - return nil, storage.ErrStreamDone - } - - var records []*databroker.Record - for i := 1; i < len(values); i += 2 { - var record databroker.Record - err := proto.Unmarshal([]byte(values[i]), &record) - if err != nil { - log.Warn(ctx).Err(err).Msg("redis: invalid record detected") - continue - } - records = append(records, &record) - } - return records, nil -} - -func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) { - for { - cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{ - Min: fmt.Sprintf("(%d", *recordVersion), - Max: "+inf", - Offset: 0, - Count: 1, - }) - results, err := cmd.Result() - if errors.Is(err, redis.Nil) { - return nil, storage.ErrStreamDone - } else if err != nil { - return nil, err - } else if len(results) == 0 { - return nil, storage.ErrStreamDone - } - - result := results[0] - var record databroker.Record - err = proto.Unmarshal([]byte(result), &record) - if err != nil { - log.Warn(ctx).Err(err).Msg("redis: invalid record detected") - *recordVersion++ - continue - } - - *recordVersion = record.GetVersion() - if recordType != "" && record.GetType() != recordType { - continue - } - - return &record, nil - } -}