mirror of
https://github.com/pomerium/pomerium.git
synced 2025-07-17 08:38:15 +02:00
Merge branch 'main' into wasaga/databroker-recordset
This commit is contained in:
commit
ba077f6b9e
32 changed files with 1137 additions and 2334 deletions
|
@ -245,7 +245,7 @@ type Options struct {
|
|||
DataBrokerURLStrings []string `mapstructure:"databroker_service_urls" yaml:"databroker_service_urls,omitempty"`
|
||||
DataBrokerInternalURLString string `mapstructure:"databroker_internal_service_url" yaml:"databroker_internal_service_url,omitempty"`
|
||||
// DataBrokerStorageType is the storage backend type that databroker will use.
|
||||
// Supported type: memory, redis
|
||||
// Supported type: memory, postgres
|
||||
DataBrokerStorageType string `mapstructure:"databroker_storage_type" yaml:"databroker_storage_type,omitempty"`
|
||||
// DataBrokerStorageConnectionString is the data source name for storage backend.
|
||||
DataBrokerStorageConnectionString string `mapstructure:"databroker_storage_connection_string" yaml:"databroker_storage_connection_string,omitempty"`
|
||||
|
@ -584,7 +584,9 @@ func (o *Options) Validate() error {
|
|||
|
||||
switch o.DataBrokerStorageType {
|
||||
case StorageInMemoryName:
|
||||
case StorageRedisName, StoragePostgresName:
|
||||
case StorageRedisName:
|
||||
return errors.New("config: redis databroker storage backend is no longer supported")
|
||||
case StoragePostgresName:
|
||||
if o.DataBrokerStorageConnectionString == "" {
|
||||
return errors.New("config: missing databroker storage backend dsn")
|
||||
}
|
||||
|
|
|
@ -58,8 +58,10 @@ func Test_Validate(t *testing.T) {
|
|||
badPolicyFile.PolicyFile = "file"
|
||||
invalidStorageType := testOptions()
|
||||
invalidStorageType.DataBrokerStorageType = "foo"
|
||||
redisStorageType := testOptions()
|
||||
redisStorageType.DataBrokerStorageType = "redis"
|
||||
missingStorageDSN := testOptions()
|
||||
missingStorageDSN.DataBrokerStorageType = "redis"
|
||||
missingStorageDSN.DataBrokerStorageType = "postgres"
|
||||
badSignoutRedirectURL := testOptions()
|
||||
badSignoutRedirectURL.SignOutRedirectURLString = "--"
|
||||
badCookieSettings := testOptions()
|
||||
|
@ -77,6 +79,7 @@ func Test_Validate(t *testing.T) {
|
|||
{"missing shared secret but all service", badSecretAllServices, false},
|
||||
{"policy file specified", badPolicyFile, true},
|
||||
{"invalid databroker storage type", invalidStorageType, true},
|
||||
{"redis databroker storage type", redisStorageType, true},
|
||||
{"missing databroker storage dsn", missingStorageDSN, true},
|
||||
{"invalid signout redirect url", badSignoutRedirectURL, true},
|
||||
{"CookieSameSite none with CookieSecure fale", badCookieSettings, true},
|
||||
|
|
|
@ -93,6 +93,13 @@ func (srv *dataBrokerServer) Put(ctx context.Context, req *databrokerpb.PutReque
|
|||
return srv.server.Put(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) Patch(ctx context.Context, req *databrokerpb.PatchRequest) (*databrokerpb.PatchResponse, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return srv.server.Patch(ctx, req)
|
||||
}
|
||||
|
||||
func (srv *dataBrokerServer) ReleaseLease(ctx context.Context, req *databrokerpb.ReleaseLeaseRequest) (*emptypb.Empty, error) {
|
||||
if err := grpcutil.RequireSignedJWT(ctx, srv.sharedKey.Load()); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -9,7 +9,6 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/registry"
|
||||
"github.com/pomerium/pomerium/internal/registry/inmemory"
|
||||
"github.com/pomerium/pomerium/internal/registry/redis"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
|
@ -110,16 +109,6 @@ func (srv *Server) newRegistryLocked(backend storage.Backend) (registry.Interfac
|
|||
case config.StorageInMemoryName:
|
||||
log.Info(ctx).Msg("using in-memory registry")
|
||||
return inmemory.New(ctx, srv.cfg.registryTTL), nil
|
||||
case config.StorageRedisName:
|
||||
log.Info(ctx).Msg("using redis registry")
|
||||
r, err := redis.New(
|
||||
srv.cfg.storageConnectionString,
|
||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new redis registry: %w", err)
|
||||
}
|
||||
return r, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unsupported registry type: %s", srv.cfg.storageType)
|
||||
|
|
|
@ -3,7 +3,6 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
@ -19,12 +18,10 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/registry"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/inmemory"
|
||||
"github.com/pomerium/pomerium/pkg/storage/postgres"
|
||||
"github.com/pomerium/pomerium/pkg/storage/redis"
|
||||
)
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
|
@ -237,6 +234,45 @@ func (srv *Server) Put(ctx context.Context, req *databroker.PutRequest) (*databr
|
|||
return res, nil
|
||||
}
|
||||
|
||||
// Patch updates specific fields of an existing record.
|
||||
func (srv *Server) Patch(ctx context.Context, req *databroker.PatchRequest) (*databroker.PatchResponse, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.Patch")
|
||||
defer span.End()
|
||||
|
||||
records := req.GetRecords()
|
||||
if len(records) == 1 {
|
||||
log.Info(ctx).
|
||||
Str("record-type", records[0].GetType()).
|
||||
Str("record-id", records[0].GetId()).
|
||||
Msg("patch")
|
||||
} else {
|
||||
var recordType string
|
||||
for _, record := range records {
|
||||
recordType = record.GetType()
|
||||
}
|
||||
log.Info(ctx).
|
||||
Int("record-count", len(records)).
|
||||
Str("record-type", recordType).
|
||||
Msg("patch")
|
||||
}
|
||||
|
||||
db, err := srv.getBackend()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
serverVersion, patchedRecords, err := db.Patch(ctx, records, req.GetFieldMask())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res := &databroker.PatchResponse{
|
||||
ServerVersion: serverVersion,
|
||||
Records: patchedRecords,
|
||||
}
|
||||
|
||||
return res, nil
|
||||
}
|
||||
|
||||
// ReleaseLease releases a lease.
|
||||
func (srv *Server) ReleaseLease(ctx context.Context, req *databroker.ReleaseLeaseRequest) (*emptypb.Empty, error) {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.grpc.ReleaseLease")
|
||||
|
@ -426,39 +462,8 @@ func (srv *Server) newBackendLocked() (backend storage.Backend, err error) {
|
|||
case config.StoragePostgresName:
|
||||
log.Info(ctx).Msg("using postgres store")
|
||||
backend = postgres.New(srv.cfg.storageConnectionString)
|
||||
case config.StorageRedisName:
|
||||
log.Info(ctx).Msg("using redis store")
|
||||
backend, err = redis.New(
|
||||
srv.cfg.storageConnectionString,
|
||||
redis.WithTLSConfig(srv.getTLSConfigLocked(ctx)),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||
}
|
||||
if srv.cfg.secret != nil {
|
||||
backend, err = storage.NewEncryptedBackend(srv.cfg.secret, backend)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported storage type: %s", srv.cfg.storageType)
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
func (srv *Server) getTLSConfigLocked(ctx context.Context) *tls.Config {
|
||||
caCertPool, err := cryptutil.GetCertPool("", srv.cfg.storageCAFile)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("failed to read databroker CA file")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
//nolint: gosec
|
||||
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
||||
}
|
||||
if srv.cfg.storageCertificate != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
||||
}
|
||||
return tlsConfig
|
||||
}
|
||||
|
|
|
@ -18,11 +18,11 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
|
@ -85,6 +85,58 @@ func TestServer_Get(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestServer_Patch(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
srv := newServer(cfg)
|
||||
|
||||
s := &session.Session{
|
||||
Id: "1",
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
|
||||
}
|
||||
data := protoutil.NewAny(s)
|
||||
_, err := srv.Put(context.Background(), &databroker.PutRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: data.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: data,
|
||||
}},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
fm, err := fieldmaskpb.New(s, "accessed_at")
|
||||
require.NoError(t, err)
|
||||
|
||||
now := timestamppb.Now()
|
||||
s.AccessedAt = now
|
||||
s.OauthToken.AccessToken = "access-token-field-ignored"
|
||||
data = protoutil.NewAny(s)
|
||||
patchResponse, err := srv.Patch(context.Background(), &databroker.PatchRequest{
|
||||
Records: []*databroker.Record{{
|
||||
Type: data.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: data,
|
||||
}},
|
||||
FieldMask: fm,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoEqual(t, protoutil.NewAny(&session.Session{
|
||||
Id: "1",
|
||||
AccessedAt: now,
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
|
||||
}), patchResponse.GetRecord().GetData())
|
||||
|
||||
getResponse, err := srv.Get(context.Background(), &databroker.GetRequest{
|
||||
Type: data.TypeUrl,
|
||||
Id: s.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoEqual(t, protoutil.NewAny(&session.Session{
|
||||
Id: "1",
|
||||
AccessedAt: now,
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token"},
|
||||
}), getResponse.GetRecord().GetData())
|
||||
}
|
||||
|
||||
func TestServer_Options(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
srv := newServer(cfg)
|
||||
|
@ -287,12 +339,11 @@ func TestServerInvalidStorage(t *testing.T) {
|
|||
_ = assert.Error(t, err) && assert.Contains(t, err.Error(), "unsupported storage type")
|
||||
}
|
||||
|
||||
func TestServerRedis(t *testing.T) {
|
||||
testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
func TestServerPostgres(t *testing.T) {
|
||||
testutil.WithTestPostgres(func(dsn string) error {
|
||||
srv := newServer(&serverConfig{
|
||||
storageType: "redis",
|
||||
storageConnectionString: rawURL,
|
||||
secret: cryptutil.NewKey(),
|
||||
storageType: "postgres",
|
||||
storageConnectionString: dsn,
|
||||
})
|
||||
|
||||
s := new(session.Session)
|
||||
|
|
|
@ -14,13 +14,6 @@ type (
|
|||
// A Handle represents a listener.
|
||||
Handle string
|
||||
|
||||
addListenerEvent[T any] struct {
|
||||
listener Listener[T]
|
||||
handle Handle
|
||||
}
|
||||
removeListenerEvent[T any] struct {
|
||||
handle Handle
|
||||
}
|
||||
dispatchEvent[T any] struct {
|
||||
ctx context.Context
|
||||
event T
|
||||
|
@ -36,131 +29,123 @@ type (
|
|||
//
|
||||
// Target is safe to use in its zero state.
|
||||
//
|
||||
// The first time any method of Target is called a background goroutine is started that handles
|
||||
// any requests and maintains the state of the listeners. Each listener also starts a
|
||||
// separate goroutine so that all listeners can be invoked concurrently.
|
||||
// Each listener is run in its own goroutine.
|
||||
//
|
||||
// The channels to the main goroutine and to the listener goroutines have a size of 1 so typically
|
||||
// methods and dispatches will return immediately. However a slow listener will cause the next event
|
||||
// dispatch to block. This is the opposite behavior from Manager.
|
||||
// A slow listener will cause the next event dispatch to block. This is the
|
||||
// opposite behavior from Manager.
|
||||
//
|
||||
// Close will cancel all the goroutines. Subsequent calls to AddListener, RemoveListener, Close and
|
||||
// Dispatch are no-ops.
|
||||
// Close will remove and cancel all listeners.
|
||||
type Target[T any] struct {
|
||||
initOnce sync.Once
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
addListenerCh chan addListenerEvent[T]
|
||||
removeListenerCh chan removeListenerEvent[T]
|
||||
dispatchCh chan dispatchEvent[T]
|
||||
listeners map[Handle]chan dispatchEvent[T]
|
||||
mu sync.RWMutex
|
||||
listeners map[Handle]targetListener[T]
|
||||
}
|
||||
|
||||
// AddListener adds a listener to the target.
|
||||
func (t *Target[T]) AddListener(listener Listener[T]) Handle {
|
||||
t.init()
|
||||
|
||||
// using a handle is necessary because you can't use a function as a map key.
|
||||
handle := Handle(uuid.NewString())
|
||||
h := Handle(uuid.NewString())
|
||||
tl := newTargetListener(listener)
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.addListenerCh <- addListenerEvent[T]{listener, handle}:
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
if t.listeners == nil {
|
||||
t.listeners = make(map[Handle]targetListener[T])
|
||||
}
|
||||
|
||||
return handle
|
||||
t.listeners[h] = tl
|
||||
return h
|
||||
}
|
||||
|
||||
// Close closes the event target. This can be called multiple times safely.
|
||||
// Once closed the target cannot be used.
|
||||
func (t *Target[T]) Close() {
|
||||
t.init()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
t.cancel(errors.New("target closed"))
|
||||
for _, tl := range t.listeners {
|
||||
tl.close()
|
||||
}
|
||||
t.listeners = nil
|
||||
}
|
||||
|
||||
// Dispatch dispatches an event to all listeners.
|
||||
func (t *Target[T]) Dispatch(ctx context.Context, evt T) {
|
||||
t.init()
|
||||
// store all the listeners in a slice so we don't hold the lock while dispatching
|
||||
var tls []targetListener[T]
|
||||
t.mu.RLock()
|
||||
tls = make([]targetListener[T], 0, len(t.listeners))
|
||||
for _, tl := range t.listeners {
|
||||
tls = append(tls, tl)
|
||||
}
|
||||
t.mu.RUnlock()
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.dispatchCh <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
// Because we're outside of the lock it's possible we may dispatch to a listener
|
||||
// that's been removed if Dispatch and RemoveListener are called from separate
|
||||
// goroutines. There should be no possibility of a deadlock however.
|
||||
|
||||
for _, tl := range tls {
|
||||
tl.dispatch(dispatchEvent[T]{ctx: ctx, event: evt})
|
||||
}
|
||||
}
|
||||
|
||||
// RemoveListener removes a listener from the target.
|
||||
func (t *Target[T]) RemoveListener(handle Handle) {
|
||||
t.init()
|
||||
t.mu.Lock()
|
||||
defer t.mu.Unlock()
|
||||
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case t.removeListenerCh <- removeListenerEvent[T]{handle}:
|
||||
if t.listeners == nil {
|
||||
t.listeners = make(map[Handle]targetListener[T])
|
||||
}
|
||||
}
|
||||
|
||||
func (t *Target[T]) init() {
|
||||
t.initOnce.Do(func() {
|
||||
t.ctx, t.cancel = context.WithCancelCause(context.Background())
|
||||
t.addListenerCh = make(chan addListenerEvent[T], 1)
|
||||
t.removeListenerCh = make(chan removeListenerEvent[T], 1)
|
||||
t.dispatchCh = make(chan dispatchEvent[T], 1)
|
||||
t.listeners = map[Handle]chan dispatchEvent[T]{}
|
||||
go t.run()
|
||||
})
|
||||
}
|
||||
|
||||
func (t *Target[T]) run() {
|
||||
// listen for add/remove/dispatch events and call functions
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
return
|
||||
case evt := <-t.addListenerCh:
|
||||
t.addListener(evt.listener, evt.handle)
|
||||
case evt := <-t.removeListenerCh:
|
||||
t.removeListener(evt.handle)
|
||||
case evt := <-t.dispatchCh:
|
||||
t.dispatch(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// these functions are not thread-safe. They are intended to be called only by "run".
|
||||
|
||||
func (t *Target[T]) addListener(listener Listener[T], handle Handle) {
|
||||
ch := make(chan dispatchEvent[T], 1)
|
||||
t.listeners[handle] = ch
|
||||
// start a goroutine to send events to the listener
|
||||
go func() {
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case evt := <-ch:
|
||||
listener(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *Target[T]) removeListener(handle Handle) {
|
||||
ch, ok := t.listeners[handle]
|
||||
tl, ok := t.listeners[handle]
|
||||
if !ok {
|
||||
// nothing to do since the listener doesn't exist
|
||||
return
|
||||
}
|
||||
// close the channel to kill the goroutine
|
||||
close(ch)
|
||||
|
||||
tl.close()
|
||||
delete(t.listeners, handle)
|
||||
}
|
||||
|
||||
func (t *Target[T]) dispatch(ctx context.Context, evt T) {
|
||||
// loop over all the listeners and send the event to them
|
||||
for _, ch := range t.listeners {
|
||||
// A targetListener starts a goroutine that pulls events from "ch" and
|
||||
// calls the listener for each event.
|
||||
//
|
||||
// The goroutine is stopped when ".close()" is called. We don't rely
|
||||
// on closing "ch" because sending to a closed channel results in a
|
||||
// panic. Instead we signal closing via "ctx.Done()".
|
||||
type targetListener[T any] struct {
|
||||
ctx context.Context
|
||||
cancel context.CancelCauseFunc
|
||||
ch chan dispatchEvent[T]
|
||||
listener Listener[T]
|
||||
}
|
||||
|
||||
func newTargetListener[T any](listener Listener[T]) targetListener[T] {
|
||||
li := targetListener[T]{}
|
||||
li.ctx, li.cancel = context.WithCancelCause(context.Background())
|
||||
li.ch = make(chan dispatchEvent[T])
|
||||
li.listener = listener
|
||||
go li.run()
|
||||
return li
|
||||
}
|
||||
|
||||
func (li targetListener[T]) close() {
|
||||
li.cancel(errors.New("events target listener closed"))
|
||||
}
|
||||
|
||||
func (li targetListener[T]) dispatch(evt dispatchEvent[T]) {
|
||||
select {
|
||||
case <-li.ctx.Done():
|
||||
case li.ch <- evt:
|
||||
}
|
||||
}
|
||||
|
||||
func (li targetListener[T]) run() {
|
||||
for {
|
||||
select {
|
||||
case <-t.ctx.Done():
|
||||
case <-li.ctx.Done():
|
||||
return
|
||||
case ch <- dispatchEvent[T]{ctx: ctx, event: evt}:
|
||||
case evt := <-li.ch:
|
||||
li.listener(evt.ctx, evt.event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -1,20 +0,0 @@
|
|||
// Package lua contains lua source code.
|
||||
package lua
|
||||
|
||||
import (
|
||||
"embed"
|
||||
)
|
||||
|
||||
//go:embed registry.lua
|
||||
var fs embed.FS
|
||||
|
||||
// Registry is the registry lua script
|
||||
var Registry string
|
||||
|
||||
func init() {
|
||||
bs, err := fs.ReadFile("registry.lua")
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
Registry = string(bs)
|
||||
}
|
|
@ -1,28 +0,0 @@
|
|||
-- ARGV = [current time in seconds, ttl in seconds, services ...]
|
||||
local current_time = ARGV[1]
|
||||
local ttl = ARGV[2]
|
||||
local changed = false
|
||||
|
||||
-- update the service list
|
||||
for i = 3, #ARGV, 1 do
|
||||
redis.call('HSET', KEYS[1], ARGV[i], current_time + ttl)
|
||||
changed = true
|
||||
end
|
||||
|
||||
-- retrieve all the services, removing any that have expired
|
||||
local svcs = {}
|
||||
local kvs = redis.call('HGETALL', KEYS[1])
|
||||
for i = 1, #kvs, 2 do
|
||||
if kvs[i + 1] < current_time then
|
||||
redis.call('HDEL', KEYS[1], kvs[i])
|
||||
changed = true
|
||||
else
|
||||
table.insert(svcs, kvs[i])
|
||||
end
|
||||
end
|
||||
|
||||
if changed then
|
||||
redis.call('PUBLISH', KEYS[2], current_time)
|
||||
end
|
||||
|
||||
return svcs
|
|
@ -1,48 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
const defaultTTL = time.Second * 30
|
||||
|
||||
type config struct {
|
||||
tls *tls.Config
|
||||
ttl time.Duration
|
||||
getNow func() time.Time
|
||||
}
|
||||
|
||||
// An Option modifies the config..
|
||||
type Option func(*config)
|
||||
|
||||
// WithGetNow sets the time.Now function in the config.
|
||||
func WithGetNow(getNow func() time.Time) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.getNow = getNow
|
||||
}
|
||||
}
|
||||
|
||||
// WithTLSConfig sets the tls.Config in the config.
|
||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.tls = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// WithTTL sets the ttl in the config.
|
||||
func WithTTL(ttl time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.ttl = ttl
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithGetNow(time.Now)(cfg)
|
||||
WithTTL(defaultTTL)(cfg)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -1,250 +0,0 @@
|
|||
// Package redis implements a registry in redis.
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/redisutil"
|
||||
"github.com/pomerium/pomerium/internal/registry"
|
||||
"github.com/pomerium/pomerium/internal/registry/redis/lua"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
)
|
||||
|
||||
const (
|
||||
registryKey = redisutil.KeyPrefix + "registry"
|
||||
registryUpdateKey = redisutil.KeyPrefix + "registry_changed_ch"
|
||||
|
||||
pollInterval = time.Second * 30
|
||||
)
|
||||
|
||||
type impl struct {
|
||||
cfg *config
|
||||
|
||||
client redis.UniversalClient
|
||||
onChange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new registry implementation backend by redis.
|
||||
func New(rawURL string, options ...Option) (registry.Interface, error) {
|
||||
cfg := getConfig(options...)
|
||||
|
||||
client, err := redisutil.NewClientFromURL(rawURL, cfg.tls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
i := &impl{
|
||||
cfg: cfg,
|
||||
client: client,
|
||||
onChange: signal.New(),
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
go i.listenForChanges(context.Background())
|
||||
return i, nil
|
||||
}
|
||||
|
||||
func (i *impl) Report(ctx context.Context, req *registrypb.RegisterRequest) (*registrypb.RegisterResponse, error) {
|
||||
_, err := i.runReport(ctx, req.GetServices())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ®istrypb.RegisterResponse{
|
||||
CallBackAfter: durationpb.New(i.cfg.ttl / 2),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *impl) List(ctx context.Context, req *registrypb.ListRequest) (*registrypb.ServiceList, error) {
|
||||
all, err := i.runReport(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
include := map[registrypb.ServiceKind]struct{}{}
|
||||
for _, kind := range req.GetKinds() {
|
||||
include[kind] = struct{}{}
|
||||
}
|
||||
|
||||
filtered := make([]*registrypb.Service, 0, len(all))
|
||||
for _, svc := range all {
|
||||
if _, ok := include[svc.GetKind()]; !ok {
|
||||
continue
|
||||
}
|
||||
filtered = append(filtered, svc)
|
||||
}
|
||||
|
||||
sort.Slice(filtered, func(i, j int) bool {
|
||||
{
|
||||
iv, jv := filtered[i].GetKind(), filtered[j].GetKind()
|
||||
switch {
|
||||
case iv < jv:
|
||||
return true
|
||||
case jv < iv:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
iv, jv := filtered[i].GetEndpoint(), filtered[j].GetEndpoint()
|
||||
switch {
|
||||
case iv < jv:
|
||||
return true
|
||||
case jv < iv:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
})
|
||||
|
||||
return ®istrypb.ServiceList{
|
||||
Services: filtered,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (i *impl) Watch(req *registrypb.ListRequest, stream registrypb.Registry_WatchServer) error {
|
||||
// listen for changes
|
||||
ch := i.onChange.Bind()
|
||||
defer i.onChange.Unbind(ch)
|
||||
|
||||
// force a check periodically
|
||||
poll := time.NewTicker(pollInterval)
|
||||
defer poll.Stop()
|
||||
|
||||
var prev *registrypb.ServiceList
|
||||
for {
|
||||
// retrieve the most recent list of services
|
||||
lst, err := i.List(stream.Context(), req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only send a new list if something changed
|
||||
if !proto.Equal(prev, lst) {
|
||||
err = stream.Send(lst)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
prev = lst
|
||||
|
||||
// wait for an update
|
||||
select {
|
||||
case <-i.closed:
|
||||
return nil
|
||||
case <-stream.Context().Done():
|
||||
return stream.Context().Err()
|
||||
case <-ch:
|
||||
case <-poll.C:
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *impl) Close() error {
|
||||
var err error
|
||||
i.closeOnce.Do(func() {
|
||||
err = i.client.Close()
|
||||
close(i.closed)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (i *impl) listenForChanges(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-i.closed
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
outer:
|
||||
for {
|
||||
pubsub := i.client.Subscribe(ctx, registryUpdateKey)
|
||||
for {
|
||||
msg, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue outer
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
switch msg.(type) {
|
||||
case *redis.Message:
|
||||
i.onChange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *impl) runReport(ctx context.Context, updates []*registrypb.Service) ([]*registrypb.Service, error) {
|
||||
args := []interface{}{
|
||||
i.cfg.getNow().UnixNano() / int64(time.Millisecond), // current_time
|
||||
i.cfg.ttl.Milliseconds(), // ttl
|
||||
}
|
||||
for _, svc := range updates {
|
||||
args = append(args, i.getRegistryHashKey(svc))
|
||||
}
|
||||
res, err := i.client.Eval(ctx, lua.Registry, []string{registryKey, registryUpdateKey}, args...).Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if values, ok := res.([]interface{}); ok {
|
||||
var all []*registrypb.Service
|
||||
for _, value := range values {
|
||||
svc, err := i.getServiceFromRegistryHashKey(fmt.Sprint(value))
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid service")
|
||||
continue
|
||||
}
|
||||
all = append(all, svc)
|
||||
}
|
||||
return all, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (i *impl) getServiceFromRegistryHashKey(key string) (*registrypb.Service, error) {
|
||||
idx := strings.Index(key, "|")
|
||||
if idx == -1 {
|
||||
return nil, fmt.Errorf("redis: invalid service entry in hash: %s", key)
|
||||
}
|
||||
|
||||
svcKindStr := key[:idx]
|
||||
svcEndpointStr := key[idx+1:]
|
||||
|
||||
svcKind, ok := registrypb.ServiceKind_value[svcKindStr]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("redis: unknown service kind: %s", svcKindStr)
|
||||
}
|
||||
|
||||
svc := ®istrypb.Service{
|
||||
Kind: registrypb.ServiceKind(svcKind),
|
||||
Endpoint: svcEndpointStr,
|
||||
}
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
func (i *impl) getRegistryHashKey(svc *registrypb.Service) string {
|
||||
return svc.GetKind().String() + "|" + svc.GetEndpoint()
|
||||
}
|
|
@ -1,196 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
registrypb "github.com/pomerium/pomerium/pkg/grpc/registry"
|
||||
)
|
||||
|
||||
func TestReport(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
tm := time.Now()
|
||||
|
||||
i, err := New(rawURL,
|
||||
WithGetNow(func() time.Time {
|
||||
return tm
|
||||
}),
|
||||
WithTTL(time.Second*10))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = i.Close() }()
|
||||
|
||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
||||
Services: []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// move forward 5 seconds
|
||||
tm = tm.Add(time.Second * 5)
|
||||
_, err = i.Report(ctx, ®istrypb.RegisterRequest{
|
||||
Services: []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "https://authenticate.example.com"},
|
||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
lst, err := i.List(ctx, ®istrypb.ListRequest{
|
||||
Kinds: []registrypb.ServiceKind{
|
||||
registrypb.ServiceKind_AUTHORIZE,
|
||||
registrypb.ServiceKind_PROXY,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "https://authorize.example.com"},
|
||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||
}, lst.GetServices(), "should list selected services")
|
||||
|
||||
// move forward 6 seconds
|
||||
tm = tm.Add(time.Second * 6)
|
||||
lst, err = i.List(ctx, ®istrypb.ListRequest{
|
||||
Kinds: []registrypb.ServiceKind{
|
||||
registrypb.ServiceKind_AUTHORIZE,
|
||||
registrypb.ServiceKind_PROXY,
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_PROXY, Endpoint: "https://proxy.example.com"},
|
||||
}, lst.GetServices(), "should expire old services")
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestWatch(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*15)
|
||||
defer clearTimeout()
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
tm := time.Now()
|
||||
i, err := New(rawURL,
|
||||
WithGetNow(func() time.Time {
|
||||
return tm
|
||||
}),
|
||||
WithTTL(time.Second*10))
|
||||
require.NoError(t, err)
|
||||
|
||||
li, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer li.Close()
|
||||
|
||||
srv := grpc.NewServer()
|
||||
registrypb.RegisterRegistryServer(srv, i)
|
||||
eg, ctx := errgroup.WithContext(ctx)
|
||||
eg.Go(func() error {
|
||||
<-ctx.Done()
|
||||
srv.Stop()
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
return srv.Serve(li)
|
||||
})
|
||||
eg.Go(func() error {
|
||||
defer cancel()
|
||||
|
||||
cc, err := grpc.Dial(li.Addr().String(), grpc.WithInsecure())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client := registrypb.NewRegistryClient(cc)
|
||||
|
||||
// store the initial services
|
||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||
Services: []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream, err := client.Watch(ctx, ®istrypb.ListRequest{
|
||||
Kinds: []registrypb.ServiceKind{
|
||||
registrypb.ServiceKind_AUTHORIZE,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = stream.CloseSend() }()
|
||||
|
||||
lst, err := stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||
}, lst.GetServices())
|
||||
|
||||
// update authenticate
|
||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||
Services: []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHENTICATE, Endpoint: "http://authenticate1.example.com"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// add an authorize
|
||||
_, err = client.Report(ctx, ®istrypb.RegisterRequest{
|
||||
Services: []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
lst, err = stream.Recv()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
assert.Equal(t, []*registrypb.Service{
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize1.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize2.example.com"},
|
||||
{Kind: registrypb.ServiceKind_AUTHORIZE, Endpoint: "http://authorize3.example.com"},
|
||||
}, lst.GetServices())
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, eg.Wait())
|
||||
return nil
|
||||
}))
|
||||
}
|
|
@ -150,6 +150,15 @@ func (x *PutResponse) GetRecord() *Record {
|
|||
return records[0]
|
||||
}
|
||||
|
||||
// GetRecord gets the first record, or nil if there are none.
|
||||
func (x *PatchResponse) GetRecord() *Record {
|
||||
records := x.GetRecords()
|
||||
if len(records) == 0 {
|
||||
return nil
|
||||
}
|
||||
return records[0]
|
||||
}
|
||||
|
||||
// SetFilterByID sets the filter to an id.
|
||||
func (x *QueryRequest) SetFilterByID(id string) {
|
||||
x.Filter = &structpb.Struct{Fields: map[string]*structpb.Value{
|
||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -6,6 +6,7 @@ option go_package = "github.com/pomerium/pomerium/pkg/grpc/databroker";
|
|||
import "google/protobuf/any.proto";
|
||||
import "google/protobuf/duration.proto";
|
||||
import "google/protobuf/empty.proto";
|
||||
import "google/protobuf/field_mask.proto";
|
||||
import "google/protobuf/struct.proto";
|
||||
import "google/protobuf/timestamp.proto";
|
||||
|
||||
|
@ -60,6 +61,15 @@ message PutResponse {
|
|||
repeated Record records = 2;
|
||||
}
|
||||
|
||||
message PatchRequest {
|
||||
repeated Record records = 1;
|
||||
google.protobuf.FieldMask field_mask = 2;
|
||||
}
|
||||
message PatchResponse {
|
||||
uint64 server_version = 1;
|
||||
repeated Record records = 2;
|
||||
}
|
||||
|
||||
message SetOptionsRequest {
|
||||
string type = 1;
|
||||
Options options = 2;
|
||||
|
@ -114,6 +124,8 @@ service DataBrokerService {
|
|||
rpc ListTypes(google.protobuf.Empty) returns (ListTypesResponse);
|
||||
// Put saves a record.
|
||||
rpc Put(PutRequest) returns (PutResponse);
|
||||
// Patch updates specific fields of an existing record.
|
||||
rpc Patch(PatchRequest) returns (PatchResponse);
|
||||
// Query queries for records.
|
||||
rpc Query(QueryRequest) returns (QueryResponse);
|
||||
// ReleaseLease releases a distributed mutex lease.
|
||||
|
|
|
@ -133,6 +133,26 @@ func (mr *MockDataBrokerServiceClientMockRecorder) ListTypes(ctx, in interface{}
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTypes", reflect.TypeOf((*MockDataBrokerServiceClient)(nil).ListTypes), varargs...)
|
||||
}
|
||||
|
||||
// Patch mocks base method.
|
||||
func (m *MockDataBrokerServiceClient) Patch(ctx context.Context, in *databroker.PatchRequest, opts ...grpc.CallOption) (*databroker.PatchResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{ctx, in}
|
||||
for _, a := range opts {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Patch", varargs...)
|
||||
ret0, _ := ret[0].(*databroker.PatchResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Patch indicates an expected call of Patch.
|
||||
func (mr *MockDataBrokerServiceClientMockRecorder) Patch(ctx, in interface{}, opts ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
varargs := append([]interface{}{ctx, in}, opts...)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockDataBrokerServiceClient)(nil).Patch), varargs...)
|
||||
}
|
||||
|
||||
// Put mocks base method.
|
||||
func (m *MockDataBrokerServiceClient) Put(ctx context.Context, in *databroker.PutRequest, opts ...grpc.CallOption) (*databroker.PutResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
@ -587,6 +607,21 @@ func (mr *MockDataBrokerServiceServerMockRecorder) ListTypes(arg0, arg1 interfac
|
|||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ListTypes", reflect.TypeOf((*MockDataBrokerServiceServer)(nil).ListTypes), arg0, arg1)
|
||||
}
|
||||
|
||||
// Patch mocks base method.
|
||||
func (m *MockDataBrokerServiceServer) Patch(arg0 context.Context, arg1 *databroker.PatchRequest) (*databroker.PatchResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Patch", arg0, arg1)
|
||||
ret0, _ := ret[0].(*databroker.PatchResponse)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Patch indicates an expected call of Patch.
|
||||
func (mr *MockDataBrokerServiceServerMockRecorder) Patch(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Patch", reflect.TypeOf((*MockDataBrokerServiceServer)(nil).Patch), arg0, arg1)
|
||||
}
|
||||
|
||||
// Put mocks base method.
|
||||
func (m *MockDataBrokerServiceServer) Put(arg0 context.Context, arg1 *databroker.PutRequest) (*databroker.PutResponse, error) {
|
||||
m.ctrl.T.Helper()
|
||||
|
|
|
@ -1,204 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/cipher"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
type encryptedRecordStream struct {
|
||||
underlying RecordStream
|
||||
backend *encryptedBackend
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Close() error {
|
||||
return e.underlying.Close()
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Next(wait bool) bool {
|
||||
return e.underlying.Next(wait)
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Record() *databroker.Record {
|
||||
r := e.underlying.Record()
|
||||
if r != nil {
|
||||
var err error
|
||||
r, err = e.backend.decryptRecord(r)
|
||||
if err != nil {
|
||||
e.err = err
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
func (e *encryptedRecordStream) Err() error {
|
||||
if e.err == nil {
|
||||
e.err = e.underlying.Err()
|
||||
}
|
||||
return e.err
|
||||
}
|
||||
|
||||
type encryptedBackend struct {
|
||||
underlying Backend
|
||||
cipher cipher.AEAD
|
||||
}
|
||||
|
||||
// NewEncryptedBackend creates a new encrypted backend.
|
||||
func NewEncryptedBackend(secret []byte, underlying Backend) (Backend, error) {
|
||||
c, err := cryptutil.NewAEADCipher(secret)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &encryptedBackend{
|
||||
underlying: underlying,
|
||||
cipher: c,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Close() error {
|
||||
return e.underlying.Close()
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Get(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
record, err := e.underlying.Get(ctx, recordType, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
record, err = e.decryptRecord(record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return record, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
||||
return e.underlying.GetOptions(ctx, recordType)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
||||
return e.underlying.Lease(ctx, leaseName, leaseID, ttl)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) ListTypes(ctx context.Context) ([]string, error) {
|
||||
return e.underlying.ListTypes(ctx)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Put(ctx context.Context, records []*databroker.Record) (uint64, error) {
|
||||
encryptedRecords := make([]*databroker.Record, len(records))
|
||||
for i, record := range records {
|
||||
encrypted, err := e.encrypt(record.GetData())
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
newRecord := proto.Clone(record).(*databroker.Record)
|
||||
newRecord.Data = encrypted
|
||||
encryptedRecords[i] = newRecord
|
||||
}
|
||||
|
||||
serverVersion, err := e.underlying.Put(ctx, encryptedRecords)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
for i, record := range records {
|
||||
record.ModifiedAt = encryptedRecords[i].ModifiedAt
|
||||
record.Version = encryptedRecords[i].Version
|
||||
}
|
||||
|
||||
return serverVersion, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
||||
return e.underlying.SetOptions(ctx, recordType, options)
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) Sync(ctx context.Context, recordType string, serverVersion, recordVersion uint64) (RecordStream, error) {
|
||||
stream, err := e.underlying.Sync(ctx, recordType, serverVersion, recordVersion)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &encryptedRecordStream{
|
||||
underlying: stream,
|
||||
backend: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) SyncLatest(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
filter FilterExpression,
|
||||
) (serverVersion, recordVersion uint64, stream RecordStream, err error) {
|
||||
serverVersion, recordVersion, stream, err = e.underlying.SyncLatest(ctx, recordType, filter)
|
||||
if err != nil {
|
||||
return serverVersion, recordVersion, nil, err
|
||||
}
|
||||
return serverVersion, recordVersion, &encryptedRecordStream{
|
||||
underlying: stream,
|
||||
backend: e,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decryptRecord(in *databroker.Record) (out *databroker.Record, err error) {
|
||||
data, err := e.decrypt(in.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Create a new record so that we don't re-use any internal state
|
||||
return &databroker.Record{
|
||||
Version: in.Version,
|
||||
Type: in.Type,
|
||||
Id: in.Id,
|
||||
Data: data,
|
||||
ModifiedAt: in.ModifiedAt,
|
||||
DeletedAt: in.DeletedAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) decrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
||||
if in == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
var encrypted wrapperspb.BytesValue
|
||||
err = in.UnmarshalTo(&encrypted)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
plaintext, err := cryptutil.Decrypt(e.cipher, encrypted.Value, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out = new(anypb.Any)
|
||||
err = proto.Unmarshal(plaintext, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func (e *encryptedBackend) encrypt(in *anypb.Any) (out *anypb.Any, err error) {
|
||||
plaintext, err := proto.Marshal(in)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
encrypted := cryptutil.Encrypt(e.cipher, plaintext, nil)
|
||||
out = protoutil.NewAny(&wrapperspb.BytesValue{
|
||||
Value: encrypted,
|
||||
})
|
||||
return out, nil
|
||||
}
|
|
@ -1,75 +0,0 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func TestEncryptedBackend(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
m := map[string]*anypb.Any{}
|
||||
backend := &mockBackend{
|
||||
put: func(ctx context.Context, records []*databroker.Record) (uint64, error) {
|
||||
for _, record := range records {
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version++
|
||||
m[record.GetId()] = record.GetData()
|
||||
}
|
||||
return 0, nil
|
||||
},
|
||||
get: func(ctx context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
data, ok := m[id]
|
||||
if !ok {
|
||||
return nil, errors.New("not found")
|
||||
}
|
||||
return &databroker.Record{
|
||||
Id: id,
|
||||
Data: data,
|
||||
Version: 1,
|
||||
ModifiedAt: timestamppb.Now(),
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
e, err := NewEncryptedBackend(cryptutil.NewKey(), backend)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
||||
data := protoutil.NewAny(wrapperspb.String("HELLO WORLD"))
|
||||
|
||||
rec := &databroker.Record{
|
||||
Type: "",
|
||||
Id: "TEST-1",
|
||||
Data: data,
|
||||
}
|
||||
_, err = e.Put(ctx, []*databroker.Record{rec})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
if assert.NotNil(t, m["TEST-1"], "key should be set") {
|
||||
assert.NotEqual(t, data.TypeUrl, m["TEST-1"].TypeUrl, "encrypted data should be a bytes type")
|
||||
assert.NotEqual(t, data.Value, m["TEST-1"].Value, "value should be encrypted")
|
||||
assert.NotNil(t, rec.ModifiedAt)
|
||||
assert.NotZero(t, rec.Version)
|
||||
}
|
||||
|
||||
record, err := e.Get(ctx, "", "TEST-1")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.Equal(t, data.TypeUrl, record.Data.TypeUrl, "type should be preserved")
|
||||
assert.Equal(t, data.Value, record.Data.Value, "value should be preserved")
|
||||
assert.NotEqual(t, data.TypeUrl, record.Type, "record type should be preserved")
|
||||
}
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/rs/zerolog"
|
||||
"golang.org/x/exp/maps"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -130,18 +131,25 @@ func (backend *Backend) Close() error {
|
|||
func (backend *Backend) Get(_ context.Context, recordType, id string) (*databroker.Record, error) {
|
||||
backend.mu.RLock()
|
||||
defer backend.mu.RUnlock()
|
||||
if record := backend.get(recordType, id); record != nil {
|
||||
return record, nil
|
||||
}
|
||||
return nil, storage.ErrNotFound
|
||||
}
|
||||
|
||||
// get gets a record from the in-memory store, assuming the RWMutex is held.
|
||||
func (backend *Backend) get(recordType, id string) *databroker.Record {
|
||||
records := backend.lookup[recordType]
|
||||
if records == nil {
|
||||
return nil, storage.ErrNotFound
|
||||
return nil
|
||||
}
|
||||
|
||||
record := records.Get(id)
|
||||
if record == nil {
|
||||
return nil, storage.ErrNotFound
|
||||
return nil
|
||||
}
|
||||
|
||||
return dup(record), nil
|
||||
return dup(record)
|
||||
}
|
||||
|
||||
// GetOptions returns the options for a type in the in-memory store.
|
||||
|
@ -216,19 +224,7 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
|
|||
Str("db_type", record.Type)
|
||||
})
|
||||
|
||||
backend.recordChange(record)
|
||||
|
||||
c, ok := backend.lookup[record.GetType()]
|
||||
if !ok {
|
||||
c = NewRecordCollection()
|
||||
backend.lookup[record.GetType()] = c
|
||||
}
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
c.Delete(record.GetId())
|
||||
} else {
|
||||
c.Put(dup(record))
|
||||
}
|
||||
backend.update(record)
|
||||
|
||||
recordTypes[record.GetType()] = struct{}{}
|
||||
}
|
||||
|
@ -239,6 +235,68 @@ func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (
|
|||
return backend.serverVersion, nil
|
||||
}
|
||||
|
||||
// update stores a record into the in-memory store, assuming the RWMutex is held.
|
||||
func (backend *Backend) update(record *databroker.Record) {
|
||||
backend.recordChange(record)
|
||||
|
||||
c, ok := backend.lookup[record.GetType()]
|
||||
if !ok {
|
||||
c = NewRecordCollection()
|
||||
backend.lookup[record.GetType()] = c
|
||||
}
|
||||
|
||||
if record.GetDeletedAt() != nil {
|
||||
c.Delete(record.GetId())
|
||||
} else {
|
||||
c.Put(dup(record))
|
||||
}
|
||||
}
|
||||
|
||||
// Patch updates the specified fields of existing record(s).
|
||||
func (backend *Backend) Patch(
|
||||
ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask,
|
||||
) (serverVersion uint64, patchedRecords []*databroker.Record, err error) {
|
||||
backend.mu.Lock()
|
||||
defer backend.mu.Unlock()
|
||||
defer backend.onChange.Broadcast(ctx)
|
||||
|
||||
serverVersion = backend.serverVersion
|
||||
patchedRecords = make([]*databroker.Record, 0, len(records))
|
||||
|
||||
for _, record := range records {
|
||||
err = backend.patch(record, fields)
|
||||
if storage.IsNotFound(err) {
|
||||
// Skip any record that does not currently exist.
|
||||
continue
|
||||
} else if err != nil {
|
||||
return
|
||||
}
|
||||
patchedRecords = append(patchedRecords, record)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// patch updates the specified fields of an existing record, assuming the RWMutex is held.
|
||||
func (backend *Backend) patch(record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
|
||||
if record == nil {
|
||||
return fmt.Errorf("cannot patch using a nil record")
|
||||
}
|
||||
|
||||
existing := backend.get(record.GetType(), record.GetId())
|
||||
if existing == nil {
|
||||
return storage.ErrNotFound
|
||||
}
|
||||
|
||||
if err := storage.PatchRecord(existing, record, fields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
backend.update(record)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetOptions sets the options for a type in the in-memory store.
|
||||
func (backend *Backend) SetOptions(_ context.Context, recordType string, options *databroker.Options) error {
|
||||
backend.mu.Lock()
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/storagetest"
|
||||
)
|
||||
|
||||
func TestBackend(t *testing.T) {
|
||||
|
@ -72,6 +73,9 @@ func TestBackend(t *testing.T) {
|
|||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("patch", func(t *testing.T) {
|
||||
storagetest.TestBackendPatch(t, ctx, backend)
|
||||
})
|
||||
}
|
||||
|
||||
func TestExpiry(t *testing.T) {
|
||||
|
|
36
pkg/storage/patch.go
Normal file
36
pkg/storage/patch.go
Normal file
|
@ -0,0 +1,36 @@
|
|||
package storage
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
// PatchRecord extracts the data from existing and record, updates the existing
|
||||
// data subject to the provided field mask, and stores the result back into
|
||||
// record. The existing record is not modified.
|
||||
func PatchRecord(existing, record *databroker.Record, fields *fieldmaskpb.FieldMask) error {
|
||||
dst, err := existing.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not unmarshal existing record data: %w", err)
|
||||
}
|
||||
|
||||
src, err := record.GetData().UnmarshalNew()
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not unmarshal new record data: %w", err)
|
||||
}
|
||||
|
||||
if err := protoutil.OverwriteMasked(dst, src, fields); err != nil {
|
||||
return fmt.Errorf("cannot patch record: %w", err)
|
||||
}
|
||||
|
||||
record.Data, err = anypb.New(dst)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not marshal new record data: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
45
pkg/storage/patch_test.go
Normal file
45
pkg/storage/patch_test.go
Normal file
|
@ -0,0 +1,45 @@
|
|||
package storage_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestPatchRecord(t *testing.T) {
|
||||
tm := timestamppb.New(time.Date(2023, 10, 31, 12, 0, 0, 0, time.UTC))
|
||||
|
||||
s1 := &session.Session{Id: "session-id"}
|
||||
a1, _ := anypb.New(s1)
|
||||
r1 := &databroker.Record{Data: a1}
|
||||
|
||||
s2 := &session.Session{Id: "new-session-id", AccessedAt: tm}
|
||||
a2, _ := anypb.New(s2)
|
||||
r2 := &databroker.Record{Data: a2}
|
||||
|
||||
originalR1 := proto.Clone(r1).(*databroker.Record)
|
||||
|
||||
m, _ := fieldmaskpb.New(&session.Session{}, "accessed_at")
|
||||
|
||||
storage.PatchRecord(r1, r2, m)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"data": {
|
||||
"@type": "type.googleapis.com/session.Session",
|
||||
"accessedAt": "2023-10-31T12:00:00Z",
|
||||
"id": "session-id"
|
||||
}
|
||||
}`, r2)
|
||||
|
||||
// The existing record should not be modified.
|
||||
testutil.AssertProtoEqual(t, originalR1, r1)
|
||||
}
|
|
@ -10,6 +10,7 @@ import (
|
|||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
|
@ -140,7 +141,7 @@ func (backend *Backend) Get(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
return getRecord(ctx, conn, recordType, recordID)
|
||||
return getRecord(ctx, conn, recordType, recordID, lockModeNone)
|
||||
}
|
||||
|
||||
// GetOptions returns the options for the given record type.
|
||||
|
@ -239,6 +240,42 @@ func (backend *Backend) Put(
|
|||
return serverVersion, err
|
||||
}
|
||||
|
||||
// Patch updates specific fields of existing records in Postgres.
|
||||
func (backend *Backend) Patch(
|
||||
ctx context.Context,
|
||||
records []*databroker.Record,
|
||||
fields *fieldmaskpb.FieldMask,
|
||||
) (uint64, []*databroker.Record, error) {
|
||||
ctx, cancel := contextutil.Merge(ctx, backend.closeCtx)
|
||||
defer cancel()
|
||||
|
||||
serverVersion, pool, err := backend.init(ctx)
|
||||
if err != nil {
|
||||
return serverVersion, nil, err
|
||||
}
|
||||
|
||||
patchedRecords := make([]*databroker.Record, 0, len(records))
|
||||
|
||||
now := timestamppb.Now()
|
||||
|
||||
for _, record := range records {
|
||||
record = dup(record)
|
||||
record.ModifiedAt = now
|
||||
err := patchRecord(ctx, pool, record, fields)
|
||||
if storage.IsNotFound(err) {
|
||||
continue
|
||||
} else if err != nil {
|
||||
err = fmt.Errorf("storage/postgres: error patching record %q of type %q: %w",
|
||||
record.GetId(), record.GetType(), err)
|
||||
return serverVersion, patchedRecords, err
|
||||
}
|
||||
patchedRecords = append(patchedRecords, record)
|
||||
}
|
||||
|
||||
err = signalRecordChange(ctx, pool)
|
||||
return serverVersion, patchedRecords, err
|
||||
}
|
||||
|
||||
// SetOptions sets the options for the given record type.
|
||||
func (backend *Backend) SetOptions(
|
||||
ctx context.Context,
|
||||
|
|
|
@ -18,6 +18,7 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
"github.com/pomerium/pomerium/pkg/storage/storagetest"
|
||||
)
|
||||
|
||||
const maxWait = time.Minute * 10
|
||||
|
@ -188,6 +189,10 @@ func TestBackend(t *testing.T) {
|
|||
assert.Equal(t, []string{"capacity-test", "latest-test", "sync-test", "test-1", "unknown"}, types)
|
||||
})
|
||||
|
||||
t.Run("patch", func(t *testing.T) {
|
||||
storagetest.TestBackendPatch(t, ctx, backend)
|
||||
})
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -12,10 +12,12 @@ import (
|
|||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/reflect/protoregistry"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -160,15 +162,24 @@ func getOptions(ctx context.Context, q querier, recordType string) (*databroker.
|
|||
return options, nil
|
||||
}
|
||||
|
||||
func getRecord(ctx context.Context, q querier, recordType, recordID string) (*databroker.Record, error) {
|
||||
type lockMode string
|
||||
|
||||
const (
|
||||
lockModeNone lockMode = ""
|
||||
lockModeUpdate lockMode = "FOR UPDATE"
|
||||
)
|
||||
|
||||
func getRecord(
|
||||
ctx context.Context, q querier, recordType, recordID string, lockMode lockMode,
|
||||
) (*databroker.Record, error) {
|
||||
var version uint64
|
||||
var data []byte
|
||||
var modifiedAt pgtype.Timestamptz
|
||||
err := q.QueryRow(ctx, `
|
||||
SELECT version, data, modified_at
|
||||
FROM `+schemaName+`.`+recordsTableName+`
|
||||
WHERE type=$1 AND id=$2
|
||||
`, recordType, recordID).Scan(&version, &data, &modifiedAt)
|
||||
WHERE type=$1 AND id=$2 `+string(lockMode),
|
||||
recordType, recordID).Scan(&version, &data, &modifiedAt)
|
||||
if isNotFound(err) {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
|
@ -378,6 +389,34 @@ func putRecordAndChange(ctx context.Context, q querier, record *databroker.Recor
|
|||
return nil
|
||||
}
|
||||
|
||||
// patchRecord updates specific fields of an existing record.
|
||||
func patchRecord(
|
||||
ctx context.Context, p *pgxpool.Pool, record *databroker.Record, fields *fieldmaskpb.FieldMask,
|
||||
) error {
|
||||
tx, err := p.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = tx.Rollback(ctx) }()
|
||||
|
||||
existing, err := getRecord(ctx, tx, record.GetType(), record.GetId(), lockModeUpdate)
|
||||
if isNotFound(err) {
|
||||
return storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := storage.PatchRecord(existing, record, fields); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := putRecordAndChange(ctx, tx, record); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func putService(ctx context.Context, q querier, svc *registry.Service, expiresAt time.Time) error {
|
||||
query := `
|
||||
INSERT INTO ` + schemaName + `.` + servicesTableName + ` (kind, endpoint, expires_at)
|
||||
|
|
|
@ -1,31 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
|
||||
pomeriumconfig "github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
type logger struct {
|
||||
}
|
||||
|
||||
func (l logger) Printf(ctx context.Context, format string, v ...interface{}) {
|
||||
log.Info(ctx).Str("service", "redis").Msgf(format, v...)
|
||||
}
|
||||
|
||||
func init() {
|
||||
redis.SetLogger(logger{})
|
||||
}
|
||||
|
||||
func recordOperation(ctx context.Context, startTime time.Time, operation string, err error) {
|
||||
metrics.RecordStorageOperation(ctx, &metrics.StorageOperationTags{
|
||||
Operation: operation,
|
||||
Error: err,
|
||||
Backend: pomeriumconfig.StorageRedisName,
|
||||
}, time.Since(startTime))
|
||||
}
|
|
@ -1,37 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"time"
|
||||
)
|
||||
|
||||
type config struct {
|
||||
tls *tls.Config
|
||||
expiry time.Duration
|
||||
}
|
||||
|
||||
// Option customizes a Backend.
|
||||
type Option func(*config)
|
||||
|
||||
// WithTLSConfig sets the tls.Config which Backend uses.
|
||||
func WithTLSConfig(tlsConfig *tls.Config) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.tls = tlsConfig
|
||||
}
|
||||
}
|
||||
|
||||
// WithExpiry sets the expiry for changes.
|
||||
func WithExpiry(expiry time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.expiry = expiry
|
||||
}
|
||||
}
|
||||
|
||||
func getConfig(options ...Option) *config {
|
||||
cfg := new(config)
|
||||
WithExpiry(time.Hour * 24)(cfg)
|
||||
for _, o := range options {
|
||||
o(cfg)
|
||||
}
|
||||
return cfg
|
||||
}
|
|
@ -1,537 +0,0 @@
|
|||
// Package redis implements the storage.Backend interface for redis.
|
||||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cenkalti/backoff/v4"
|
||||
"github.com/go-redis/redis/v8"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/redisutil"
|
||||
"github.com/pomerium/pomerium/internal/signal"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
const (
|
||||
maxTransactionRetries = 100
|
||||
watchPollInterval = 30 * time.Second
|
||||
|
||||
// we rely on transactions in redis, so all redis-cluster keys need to be
|
||||
// on the same node. Using a `hash tag` gives us this capability.
|
||||
serverVersionKey = redisutil.KeyPrefix + "server_version"
|
||||
lastVersionKey = redisutil.KeyPrefix + "last_version"
|
||||
lastVersionChKey = redisutil.KeyPrefix + "last_version_ch"
|
||||
recordHashKey = redisutil.KeyPrefix + "records"
|
||||
recordTypesSetKey = redisutil.KeyPrefix + "record_types"
|
||||
changesSetKey = redisutil.KeyPrefix + "changes"
|
||||
optionsKey = redisutil.KeyPrefix + "options"
|
||||
|
||||
recordTypeChangesKeyTpl = redisutil.KeyPrefix + "changes.%s"
|
||||
leaseKeyTpl = redisutil.KeyPrefix + "lease.%s"
|
||||
)
|
||||
|
||||
// custom errors
|
||||
var (
|
||||
ErrExceededMaxRetries = errors.New("redis: transaction reached maximum number of retries")
|
||||
)
|
||||
|
||||
// Backend implements the storage.Backend on top of redis.
|
||||
//
|
||||
// What's stored:
|
||||
//
|
||||
// - last_version: an integer recordVersion number
|
||||
// - last_version_ch: a PubSub channel for recordVersion number updates
|
||||
// - records: a Hash of records. The hash key is {recordType}/{recordID}, the hash value the protobuf record.
|
||||
// - changes: a Sorted Set of all the changes. The score is the recordVersion number, the member the protobuf record.
|
||||
// - options: a Hash of options. The hash key is {recordType}, the hash value the protobuf options.
|
||||
// - changes.{recordType}: a Sorted Set of the changes for a record type. The score is the current time,
|
||||
// the value the record id.
|
||||
//
|
||||
// Records stored in these keys are typically encrypted.
|
||||
type Backend struct {
|
||||
cfg *config
|
||||
|
||||
client redis.UniversalClient
|
||||
onChange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New creates a new redis storage backend.
|
||||
func New(rawURL string, options ...Option) (*Backend, error) {
|
||||
ctx := context.TODO()
|
||||
cfg := getConfig(options...)
|
||||
backend := &Backend{
|
||||
cfg: cfg,
|
||||
closed: make(chan struct{}),
|
||||
onChange: signal.New(),
|
||||
}
|
||||
var err error
|
||||
backend.client, err = redisutil.NewClientFromURL(rawURL, backend.cfg.tls)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
metrics.AddRedisMetrics(backend.client.PoolStats)
|
||||
go backend.listenForVersionChanges(ctx)
|
||||
if cfg.expiry != 0 {
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-backend.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
}
|
||||
|
||||
backend.removeChangesBefore(ctx, time.Now().Add(-cfg.expiry))
|
||||
}
|
||||
}()
|
||||
}
|
||||
return backend, nil
|
||||
}
|
||||
|
||||
// Close closes the underlying redis connection and any watchers.
|
||||
func (backend *Backend) Close() error {
|
||||
var err error
|
||||
backend.closeOnce.Do(func() {
|
||||
err = backend.client.Close()
|
||||
close(backend.closed)
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
// Get gets a record from redis.
|
||||
func (backend *Backend) Get(ctx context.Context, recordType, id string) (_ *databroker.Record, err error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.redis.Get")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "get", err) }(time.Now())
|
||||
|
||||
key, field := getHashKey(recordType, id)
|
||||
cmd := backend.client.HGet(ctx, key, field)
|
||||
raw, err := cmd.Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, storage.ErrNotFound
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(raw), &record)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
|
||||
// GetOptions gets the options for the given record type.
|
||||
func (backend *Backend) GetOptions(ctx context.Context, recordType string) (*databroker.Options, error) {
|
||||
raw, err := backend.client.HGet(ctx, optionsKey, recordType).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// treat no options as an empty set of options
|
||||
return new(databroker.Options), nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var options databroker.Options
|
||||
err = proto.Unmarshal([]byte(raw), &options)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &options, nil
|
||||
}
|
||||
|
||||
// Lease acquires or renews a lease.
|
||||
func (backend *Backend) Lease(ctx context.Context, leaseName, leaseID string, ttl time.Duration) (bool, error) {
|
||||
acquired := false
|
||||
key := getLeaseKey(leaseName)
|
||||
err := backend.client.Watch(ctx, func(tx *redis.Tx) error {
|
||||
currentID, err := tx.Get(ctx, key).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// lease hasn't been set yet
|
||||
} else if err != nil {
|
||||
return err
|
||||
} else if leaseID != currentID {
|
||||
// lease has already been taken
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err = tx.Pipelined(ctx, func(p redis.Pipeliner) error {
|
||||
if ttl <= 0 {
|
||||
p.Del(ctx, key)
|
||||
} else {
|
||||
p.Set(ctx, key, leaseID, ttl)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
acquired = ttl > 0
|
||||
return nil
|
||||
}, key)
|
||||
// if the transaction failed someone else must've acquired the lease
|
||||
if errors.Is(err, redis.TxFailedErr) {
|
||||
acquired = false
|
||||
err = nil
|
||||
}
|
||||
return acquired, err
|
||||
}
|
||||
|
||||
// ListTypes lists all the known record types.
|
||||
func (backend *Backend) ListTypes(ctx context.Context) (types []string, err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.ListTypes")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "listTypes", err) }(time.Now())
|
||||
|
||||
cmd := backend.client.SMembers(ctx, recordTypesSetKey)
|
||||
types, err = cmd.Result()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sort.Strings(types)
|
||||
return types, nil
|
||||
}
|
||||
|
||||
// Put puts a record into redis.
|
||||
func (backend *Backend) Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error) {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.Put")
|
||||
defer span.End()
|
||||
defer func(start time.Time) { recordOperation(ctx, start, "put", err) }(time.Now())
|
||||
|
||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
err = backend.put(ctx, records)
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
recordTypes := map[string]struct{}{}
|
||||
for _, record := range records {
|
||||
recordTypes[record.GetType()] = struct{}{}
|
||||
}
|
||||
for recordType := range recordTypes {
|
||||
err = backend.enforceOptions(ctx, recordType)
|
||||
if err != nil {
|
||||
return serverVersion, err
|
||||
}
|
||||
}
|
||||
|
||||
return serverVersion, nil
|
||||
}
|
||||
|
||||
// SetOptions sets the options for the given record type.
|
||||
func (backend *Backend) SetOptions(ctx context.Context, recordType string, options *databroker.Options) error {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.SetOptions")
|
||||
defer span.End()
|
||||
|
||||
bs, err := proto.Marshal(options)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// update the options in the hash set
|
||||
err = backend.client.HSet(ctx, optionsKey, recordType, bs).Err()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// possibly re-enforce options
|
||||
err = backend.enforceOptions(ctx, recordType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Sync returns a record stream of any records changed after the specified recordVersion.
|
||||
func (backend *Backend) Sync(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
serverVersion, recordVersion uint64,
|
||||
) (storage.RecordStream, error) {
|
||||
return newSyncRecordStream(ctx, backend, recordType, serverVersion, recordVersion), nil
|
||||
}
|
||||
|
||||
// SyncLatest returns a record stream of all the records. Some records may be returned twice if the are updated while the
|
||||
// stream is streaming.
|
||||
func (backend *Backend) SyncLatest(
|
||||
ctx context.Context,
|
||||
recordType string,
|
||||
expr storage.FilterExpression,
|
||||
) (serverVersion, recordVersion uint64, stream storage.RecordStream, err error) {
|
||||
serverVersion, err = backend.getOrCreateServerVersion(ctx)
|
||||
if err != nil {
|
||||
return serverVersion, recordVersion, nil, err
|
||||
}
|
||||
|
||||
recordVersion, err = backend.client.Get(ctx, lastVersionKey).Uint64()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
// this happens if there are no records
|
||||
} else if err != nil {
|
||||
return serverVersion, recordVersion, nil, err
|
||||
}
|
||||
|
||||
stream, err = newSyncLatestRecordStream(ctx, backend, recordType, expr)
|
||||
return serverVersion, recordVersion, stream, err
|
||||
}
|
||||
|
||||
func (backend *Backend) put(ctx context.Context, records []*databroker.Record) error {
|
||||
return backend.incrementVersion(ctx,
|
||||
func(tx *redis.Tx, version uint64) error {
|
||||
for i, record := range records {
|
||||
record.ModifiedAt = timestamppb.Now()
|
||||
record.Version = version + uint64(i)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
func(p redis.Pipeliner, version uint64) error {
|
||||
for i, record := range records {
|
||||
bs, err := proto.Marshal(record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, field := getHashKey(record.GetType(), record.GetId())
|
||||
if record.DeletedAt != nil {
|
||||
p.HDel(ctx, key, field)
|
||||
} else {
|
||||
p.HSet(ctx, key, field, bs)
|
||||
p.ZAdd(ctx, getRecordTypeChangesKey(record.GetType()), &redis.Z{
|
||||
Score: float64(record.GetModifiedAt().GetSeconds()) + float64(i)/float64(len(records)),
|
||||
Member: record.GetId(),
|
||||
})
|
||||
}
|
||||
p.ZAdd(ctx, changesSetKey, &redis.Z{
|
||||
Score: float64(version) + float64(i),
|
||||
Member: bs,
|
||||
})
|
||||
p.SAdd(ctx, recordTypesSetKey, record.GetType())
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// enforceOptions enforces the options for the given record type.
|
||||
func (backend *Backend) enforceOptions(ctx context.Context, recordType string) error {
|
||||
ctx, span := trace.StartSpan(ctx, "databroker.redis.enforceOptions")
|
||||
defer span.End()
|
||||
|
||||
options, err := backend.GetOptions(ctx, recordType)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// nothing to do if capacity isn't set
|
||||
if options.Capacity == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
key := getRecordTypeChangesKey(recordType)
|
||||
|
||||
// find oldest records that exceed the capacity
|
||||
recordIDs, err := backend.client.ZRevRange(ctx, key, int64(*options.Capacity), -1).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// for each record, delete it
|
||||
for _, recordID := range recordIDs {
|
||||
record, err := backend.Get(ctx, recordType, recordID)
|
||||
if err == nil {
|
||||
// mark the record as deleted and re-submit
|
||||
record.DeletedAt = timestamppb.Now()
|
||||
err = backend.put(ctx, []*databroker.Record{record})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
} else if errors.Is(err, storage.ErrNotFound) {
|
||||
// ignore
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// remove the member from the collection
|
||||
_, err = backend.client.ZRem(ctx, key, recordID).Result()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// incrementVersion increments the last recordVersion key, runs the code in `query`, then attempts to commit the code in
|
||||
// `commit`. If the last recordVersion changes in the interim, we will retry the transaction.
|
||||
func (backend *Backend) incrementVersion(ctx context.Context,
|
||||
query func(tx *redis.Tx, recordVersion uint64) error,
|
||||
commit func(p redis.Pipeliner, recordVersion uint64) error,
|
||||
) error {
|
||||
// code is modeled on https://pkg.go.dev/github.com/go-redis/redis/v8#example-Client.Watch
|
||||
txf := func(tx *redis.Tx) error {
|
||||
version, err := tx.Get(ctx, lastVersionKey).Uint64()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
version = 0
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
version++
|
||||
|
||||
err = query(tx, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// the `commit` code is run in a transaction so that the EXEC cmd will run for the original redis watch
|
||||
_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
|
||||
err := commit(p, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
p.Set(ctx, lastVersionKey, version, 0)
|
||||
p.Publish(ctx, lastVersionChKey, version)
|
||||
return nil
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
for i := 0; i < maxTransactionRetries; i++ {
|
||||
err := backend.client.Watch(ctx, txf, lastVersionKey)
|
||||
if errors.Is(err, redis.TxFailedErr) {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue // retry
|
||||
} else if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil // tx was successful
|
||||
}
|
||||
|
||||
return ErrExceededMaxRetries
|
||||
}
|
||||
|
||||
func (backend *Backend) listenForVersionChanges(ctx context.Context) {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
go func() {
|
||||
<-backend.closed
|
||||
cancel()
|
||||
}()
|
||||
|
||||
bo := backoff.NewExponentialBackOff()
|
||||
bo.MaxElapsedTime = 0
|
||||
|
||||
outer:
|
||||
for {
|
||||
pubsub := backend.client.Subscribe(ctx, lastVersionChKey)
|
||||
for {
|
||||
msg, err := pubsub.Receive(ctx)
|
||||
if err != nil {
|
||||
_ = pubsub.Close()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-time.After(bo.NextBackOff()):
|
||||
}
|
||||
continue outer
|
||||
}
|
||||
bo.Reset()
|
||||
|
||||
switch msg.(type) {
|
||||
case *redis.Message:
|
||||
backend.onChange.Broadcast(ctx)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backend *Backend) removeChangesBefore(ctx context.Context, cutoff time.Time) {
|
||||
for {
|
||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: "-inf",
|
||||
Max: "+inf",
|
||||
Offset: 0,
|
||||
Count: 1,
|
||||
})
|
||||
results, err := cmd.Result()
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("redis: error retrieving changes for expiration")
|
||||
return
|
||||
}
|
||||
|
||||
// nothing left to do
|
||||
if len(results) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(results[0]), &record)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
||||
record.ModifiedAt = timestamppb.New(cutoff.Add(-time.Second)) // set the modified so will delete it
|
||||
}
|
||||
|
||||
// if the record's modified timestamp is after the cutoff, we're all done, so break
|
||||
if record.GetModifiedAt().AsTime().After(cutoff) {
|
||||
break
|
||||
}
|
||||
|
||||
// remove the record
|
||||
err = backend.client.ZRem(ctx, changesSetKey, results[0]).Err()
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("redis: error removing member")
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (backend *Backend) getOrCreateServerVersion(ctx context.Context) (serverVersion uint64, err error) {
|
||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
||||
// if the server version hasn't been set yet, set it to a random value and immediately retrieve it
|
||||
// this should properly handle a data race by only setting the key if it doesn't already exist
|
||||
if errors.Is(err, redis.Nil) {
|
||||
_, _ = backend.client.SetNX(ctx, serverVersionKey, cryptutil.NewRandomUInt64(), 0).Result()
|
||||
serverVersion, err = backend.client.Get(ctx, serverVersionKey).Uint64()
|
||||
}
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("redis: error retrieving server version: %w", err)
|
||||
}
|
||||
return serverVersion, err
|
||||
}
|
||||
|
||||
func getLeaseKey(leaseName string) string {
|
||||
return fmt.Sprintf(leaseKeyTpl, leaseName)
|
||||
}
|
||||
|
||||
func getRecordTypeChangesKey(recordType string) string {
|
||||
return fmt.Sprintf(recordTypeChangesKeyTpl, recordType)
|
||||
}
|
||||
|
||||
func getHashKey(recordType, id string) (key, field string) {
|
||||
return recordHashKey, fmt.Sprintf("%s/%s", recordType, id)
|
||||
}
|
|
@ -1,312 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func TestBackend(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
handler := func(t *testing.T, useTLS bool, rawURL string) error {
|
||||
ctx := context.Background()
|
||||
var opts []Option
|
||||
if useTLS {
|
||||
opts = append(opts, WithTLSConfig(testutil.RedisTLSConfig()))
|
||||
}
|
||||
backend, err := New(rawURL, opts...)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("get missing record", func(t *testing.T) {
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("get record", func(t *testing.T) {
|
||||
data := new(anypb.Any)
|
||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
Data: data,
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, serverVersion, sv)
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
require.NoError(t, err)
|
||||
if assert.NotNil(t, record) {
|
||||
assert.Equal(t, data, record.Data)
|
||||
assert.Nil(t, record.DeletedAt)
|
||||
assert.Equal(t, "abcd", record.Id)
|
||||
assert.NotNil(t, record.ModifiedAt)
|
||||
assert.Equal(t, "TYPE", record.Type)
|
||||
assert.Equal(t, uint64(1), record.Version)
|
||||
}
|
||||
})
|
||||
t.Run("delete record", func(t *testing.T) {
|
||||
sv, err := backend.Put(ctx, []*databroker.Record{{
|
||||
Type: "TYPE",
|
||||
Id: "abcd",
|
||||
DeletedAt: timestamppb.Now(),
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, serverVersion, sv)
|
||||
record, err := backend.Get(ctx, "TYPE", "abcd")
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, record)
|
||||
})
|
||||
t.Run("list types", func(t *testing.T) {
|
||||
types, err := backend.ListTypes(ctx)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, []string{"TYPE"}, types)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
t.Run("no-tls", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
return handler(t, false, rawURL)
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("tls", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, testutil.WithTestRedis(true, func(rawURL string) error {
|
||||
return handler(t, true, rawURL)
|
||||
}))
|
||||
})
|
||||
|
||||
if runtime.GOOS == "linux" {
|
||||
t.Run("cluster", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, testutil.WithTestRedisCluster(func(rawURL string) error {
|
||||
return handler(t, false, rawURL)
|
||||
}))
|
||||
})
|
||||
|
||||
t.Run("sentinel", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
require.NoError(t, testutil.WithTestRedisSentinel(func(rawURL string) error {
|
||||
return handler(t, false, rawURL)
|
||||
}))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestChangeSignal(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, time.Second*30)
|
||||
defer clearTimeout()
|
||||
|
||||
done := make(chan struct{})
|
||||
var eg errgroup.Group
|
||||
eg.Go(func() error {
|
||||
backend, err := New(rawURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
ch := backend.onChange.Bind()
|
||||
defer backend.onChange.Unbind(ch)
|
||||
|
||||
select {
|
||||
case <-ch:
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
|
||||
// signal the second backend that we've received the change
|
||||
close(done)
|
||||
|
||||
return nil
|
||||
})
|
||||
eg.Go(func() error {
|
||||
backend, err := New(rawURL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
// put a new value to trigger a change
|
||||
for {
|
||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
||||
Type: "TYPE",
|
||||
Id: "ID",
|
||||
}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-done:
|
||||
return nil
|
||||
case <-time.After(time.Millisecond * 100):
|
||||
}
|
||||
}
|
||||
})
|
||||
assert.NoError(t, eg.Wait(), "expected signal to be fired when another backend triggers a change")
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestExpiry(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
backend, err := New(rawURL, WithExpiry(0))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
serverVersion, err := backend.getOrCreateServerVersion(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 1000; i++ {
|
||||
_, err := backend.Put(ctx, []*databroker.Record{{
|
||||
Type: "TYPE",
|
||||
Id: fmt.Sprint(i),
|
||||
}})
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
stream, err := backend.Sync(ctx, "TYPE", serverVersion, 0)
|
||||
require.NoError(t, err)
|
||||
var records []*databroker.Record
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 1000)
|
||||
|
||||
backend.removeChangesBefore(ctx, time.Now().Add(time.Second))
|
||||
|
||||
stream, err = backend.Sync(ctx, "TYPE", serverVersion, 0)
|
||||
require.NoError(t, err)
|
||||
records = nil
|
||||
for stream.Next(false) {
|
||||
records = append(records, stream.Record())
|
||||
}
|
||||
_ = stream.Close()
|
||||
require.Len(t, records, 0)
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestCapacity(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10)
|
||||
defer clearTimeout()
|
||||
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
backend, err := New(rawURL, WithExpiry(0))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
err = backend.SetOptions(ctx, "EXAMPLE", &databroker.Options{
|
||||
Capacity: proto.Uint64(3),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
_, err = backend.Put(ctx, []*databroker.Record{{
|
||||
Type: "EXAMPLE",
|
||||
Id: fmt.Sprint(i),
|
||||
}})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
_, _, stream, err := backend.SyncLatest(ctx, "EXAMPLE", nil)
|
||||
require.NoError(t, err)
|
||||
defer stream.Close()
|
||||
|
||||
records, err := storage.RecordStreamToList(stream)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, records, 3)
|
||||
|
||||
var ids []string
|
||||
for _, r := range records {
|
||||
ids = append(ids, r.GetId())
|
||||
}
|
||||
assert.Equal(t, []string{"7", "8", "9"}, ids, "should contain recent records")
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
||||
|
||||
func TestLease(t *testing.T) {
|
||||
if os.Getenv("GITHUB_ACTION") != "" && runtime.GOOS == "darwin" {
|
||||
t.Skip("Github action can not run docker on MacOS")
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
require.NoError(t, testutil.WithTestRedis(false, func(rawURL string) error {
|
||||
backend, err := New(rawURL)
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = backend.Close() }()
|
||||
|
||||
{
|
||||
ok, err := backend.Lease(ctx, "test", "a", time.Second*30)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok, "expected a to acquire the lease")
|
||||
}
|
||||
{
|
||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok, "expected b to fail to acquire the lease")
|
||||
}
|
||||
{
|
||||
ok, err := backend.Lease(ctx, "test", "a", 0)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, ok, "expected a to clear the lease")
|
||||
}
|
||||
{
|
||||
ok, err := backend.Lease(ctx, "test", "b", time.Second*30)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok, "expected b to to acquire the lease")
|
||||
}
|
||||
|
||||
return nil
|
||||
}))
|
||||
}
|
|
@ -1,172 +0,0 @@
|
|||
package redis
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/go-redis/redis/v8"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
func newSyncRecordStream(
|
||||
ctx context.Context,
|
||||
backend *Backend,
|
||||
recordType string,
|
||||
serverVersion uint64,
|
||||
recordVersion uint64,
|
||||
) storage.RecordStream {
|
||||
changed := backend.onChange.Bind()
|
||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
||||
// 1. stream all record changes
|
||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
||||
ticker := time.NewTicker(watchPollInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
currentServerVersion, err := backend.getOrCreateServerVersion(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if serverVersion != currentServerVersion {
|
||||
return nil, storage.ErrInvalidServerVersion
|
||||
}
|
||||
|
||||
record, err := nextChangedRecord(ctx, backend, recordType, &recordVersion)
|
||||
if err == nil {
|
||||
return record, nil
|
||||
} else if !errors.Is(err, storage.ErrStreamDone) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !block {
|
||||
return nil, storage.ErrStreamDone
|
||||
}
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-ticker.C:
|
||||
case <-changed:
|
||||
}
|
||||
}
|
||||
},
|
||||
}, func() {
|
||||
backend.onChange.Unbind(changed)
|
||||
})
|
||||
}
|
||||
|
||||
func newSyncLatestRecordStream(
|
||||
ctx context.Context,
|
||||
backend *Backend,
|
||||
recordType string,
|
||||
expr storage.FilterExpression,
|
||||
) (storage.RecordStream, error) {
|
||||
filter, err := storage.RecordStreamFilterFromFilterExpression(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if recordType != "" {
|
||||
filter = filter.And(func(record *databroker.Record) (keep bool) {
|
||||
return record.GetType() == recordType
|
||||
})
|
||||
}
|
||||
|
||||
var cursor uint64
|
||||
scannedOnce := false
|
||||
var scannedRecords []*databroker.Record
|
||||
generator := storage.FilteredRecordStreamGenerator(
|
||||
func(ctx context.Context, block bool) (*databroker.Record, error) {
|
||||
for {
|
||||
if len(scannedRecords) > 0 {
|
||||
record := scannedRecords[0]
|
||||
scannedRecords = scannedRecords[1:]
|
||||
return record, nil
|
||||
}
|
||||
|
||||
// the cursor is reset to 0 after iteration is complete
|
||||
if scannedOnce && cursor == 0 {
|
||||
return nil, storage.ErrStreamDone
|
||||
}
|
||||
|
||||
var err error
|
||||
scannedRecords, err = nextScannedRecords(ctx, backend, &cursor)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
scannedOnce = true
|
||||
}
|
||||
},
|
||||
filter,
|
||||
)
|
||||
|
||||
return storage.NewRecordStream(ctx, backend.closed, []storage.RecordStreamGenerator{
|
||||
generator,
|
||||
}, nil), nil
|
||||
}
|
||||
|
||||
func nextScannedRecords(ctx context.Context, backend *Backend, cursor *uint64) ([]*databroker.Record, error) {
|
||||
var values []string
|
||||
var err error
|
||||
values, *cursor, err = backend.client.HScan(ctx, recordHashKey, *cursor, "", 0).Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, storage.ErrStreamDone
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if len(values) == 0 {
|
||||
return nil, storage.ErrStreamDone
|
||||
}
|
||||
|
||||
var records []*databroker.Record
|
||||
for i := 1; i < len(values); i += 2 {
|
||||
var record databroker.Record
|
||||
err := proto.Unmarshal([]byte(values[i]), &record)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
||||
continue
|
||||
}
|
||||
records = append(records, &record)
|
||||
}
|
||||
return records, nil
|
||||
}
|
||||
|
||||
func nextChangedRecord(ctx context.Context, backend *Backend, recordType string, recordVersion *uint64) (*databroker.Record, error) {
|
||||
for {
|
||||
cmd := backend.client.ZRangeByScore(ctx, changesSetKey, &redis.ZRangeBy{
|
||||
Min: fmt.Sprintf("(%d", *recordVersion),
|
||||
Max: "+inf",
|
||||
Offset: 0,
|
||||
Count: 1,
|
||||
})
|
||||
results, err := cmd.Result()
|
||||
if errors.Is(err, redis.Nil) {
|
||||
return nil, storage.ErrStreamDone
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if len(results) == 0 {
|
||||
return nil, storage.ErrStreamDone
|
||||
}
|
||||
|
||||
result := results[0]
|
||||
var record databroker.Record
|
||||
err = proto.Unmarshal([]byte(result), &record)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("redis: invalid record detected")
|
||||
*recordVersion++
|
||||
continue
|
||||
}
|
||||
|
||||
*recordVersion = record.GetVersion()
|
||||
if recordType != "" && record.GetType() != recordType {
|
||||
continue
|
||||
}
|
||||
|
||||
return &record, nil
|
||||
}
|
||||
}
|
|
@ -11,6 +11,7 @@ import (
|
|||
"google.golang.org/grpc/status"
|
||||
"google.golang.org/protobuf/reflect/protoreflect"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
|
@ -37,6 +38,8 @@ type Backend interface {
|
|||
ListTypes(ctx context.Context) ([]string, error)
|
||||
// Put is used to insert or update records.
|
||||
Put(ctx context.Context, records []*databroker.Record) (serverVersion uint64, err error)
|
||||
// Patch is used to update specific fields of existing records.
|
||||
Patch(ctx context.Context, records []*databroker.Record, fields *fieldmaskpb.FieldMask) (serverVersion uint64, patchedRecords []*databroker.Record, err error)
|
||||
// SetOptions sets the options for a type.
|
||||
SetOptions(ctx context.Context, recordType string, options *databroker.Options) error
|
||||
// Sync syncs record changes after the specified version.
|
||||
|
|
187
pkg/storage/storagetest/storagetest.go
Normal file
187
pkg/storage/storagetest/storagetest.go
Normal file
|
@ -0,0 +1,187 @@
|
|||
// Package storagetest contains test cases for use in verifying the behavior of
|
||||
// a storage.Backend implementation.
|
||||
package storagetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/testutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/session"
|
||||
"github.com/pomerium/pomerium/pkg/storage"
|
||||
)
|
||||
|
||||
// TestBackendPatch verifies the behavior of the backend Patch() method.
|
||||
func TestBackendPatch(t *testing.T, ctx context.Context, backend storage.Backend) { //nolint:revive
|
||||
mkRecord := func(s *session.Session) *databroker.Record {
|
||||
a, _ := anypb.New(s)
|
||||
return &databroker.Record{
|
||||
Type: a.TypeUrl,
|
||||
Id: s.Id,
|
||||
Data: a,
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("basic", func(t *testing.T) {
|
||||
// Populate an initial set of session records.
|
||||
s1 := &session.Session{
|
||||
Id: "session-1",
|
||||
IdToken: &session.IDToken{Issuer: "issuer-1"},
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-1"},
|
||||
}
|
||||
s2 := &session.Session{
|
||||
Id: "session-2",
|
||||
IdToken: &session.IDToken{Issuer: "issuer-2"},
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-2"},
|
||||
}
|
||||
s3 := &session.Session{
|
||||
Id: "session-3",
|
||||
IdToken: &session.IDToken{Issuer: "issuer-3"},
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-3"},
|
||||
}
|
||||
initial := []*databroker.Record{mkRecord(s1), mkRecord(s2), mkRecord(s3)}
|
||||
|
||||
_, err := backend.Put(ctx, initial)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Now patch just the oauth_token field.
|
||||
u1 := &session.Session{
|
||||
Id: "session-1",
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-1-new"},
|
||||
}
|
||||
u2 := &session.Session{
|
||||
Id: "session-4-does-not-exist",
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-4-new"},
|
||||
}
|
||||
u3 := &session.Session{
|
||||
Id: "session-3",
|
||||
OauthToken: &session.OAuthToken{AccessToken: "access-token-3-new"},
|
||||
}
|
||||
|
||||
mask, _ := fieldmaskpb.New(&session.Session{}, "oauth_token")
|
||||
|
||||
_, updated, err := backend.Patch(
|
||||
ctx, []*databroker.Record{mkRecord(u1), mkRecord(u2), mkRecord(u3)}, mask)
|
||||
require.NoError(t, err)
|
||||
|
||||
// The OAuthToken message should be updated but the IDToken message should
|
||||
// be unchanged, as it was not included in the field mask. The results
|
||||
// should indicate that only two records were updated (one did not exist).
|
||||
assert.Equal(t, 2, len(updated))
|
||||
assert.Greater(t, updated[0].Version, initial[0].Version)
|
||||
assert.Greater(t, updated[1].Version, initial[2].Version)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"@type": "type.googleapis.com/session.Session",
|
||||
"id": "session-1",
|
||||
"idToken": {
|
||||
"issuer": "issuer-1"
|
||||
},
|
||||
"oauthToken": {
|
||||
"accessToken": "access-token-1-new"
|
||||
}
|
||||
}`, updated[0].Data)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"@type": "type.googleapis.com/session.Session",
|
||||
"id": "session-3",
|
||||
"idToken": {
|
||||
"issuer": "issuer-3"
|
||||
},
|
||||
"oauthToken": {
|
||||
"accessToken": "access-token-3-new"
|
||||
}
|
||||
}`, updated[1].Data)
|
||||
|
||||
// Verify that the updates will indeed be seen by a subsequent Get().
|
||||
// Note: first truncate the modified_at timestamps to 1 µs precision, as
|
||||
// that is the maximum precision supported by Postgres.
|
||||
r1, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-1")
|
||||
truncateTimestamps(updated[0].ModifiedAt, r1.ModifiedAt)
|
||||
testutil.AssertProtoEqual(t, updated[0], r1)
|
||||
r3, _ := backend.Get(ctx, "type.googleapis.com/session.Session", "session-3")
|
||||
truncateTimestamps(updated[1].ModifiedAt, r3.ModifiedAt)
|
||||
testutil.AssertProtoEqual(t, updated[1], r3)
|
||||
})
|
||||
|
||||
t.Run("concurrent", func(t *testing.T) {
|
||||
if n := gomaxprocs(); n < 2 {
|
||||
t.Skipf("skipping concurrent test (GOMAXPROCS = %d)", n)
|
||||
}
|
||||
|
||||
rs1 := make([]*databroker.Record, 1)
|
||||
rs2 := make([]*databroker.Record, 1)
|
||||
|
||||
s1 := session.Session{Id: "concurrent", OauthToken: &session.OAuthToken{}}
|
||||
s2 := session.Session{Id: "concurrent", OauthToken: &session.OAuthToken{}}
|
||||
|
||||
// Store an initial version of a session record.
|
||||
rs1[0] = mkRecord(&s1)
|
||||
_, err := backend.Put(ctx, rs1)
|
||||
require.NoError(t, err)
|
||||
|
||||
fmAccessToken, err := fieldmaskpb.New(&session.Session{}, "oauth_token.access_token")
|
||||
require.NoError(t, err)
|
||||
fmRefreshToken, err := fieldmaskpb.New(&session.Session{}, "oauth_token.refresh_token")
|
||||
require.NoError(t, err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Repeatedly make Patch calls to update the session from two separate
|
||||
// goroutines (one updating just the access token, the other updating
|
||||
// just the refresh token.) Verify that no updates are lost.
|
||||
for i := 0; i < 100; i++ {
|
||||
access := fmt.Sprintf("access-%d", i)
|
||||
s1.OauthToken.AccessToken = access
|
||||
rs1[0] = mkRecord(&s1)
|
||||
|
||||
refresh := fmt.Sprintf("refresh-%d", i)
|
||||
s2.OauthToken.RefreshToken = refresh
|
||||
rs2[0] = mkRecord(&s2)
|
||||
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
_, _, _ = backend.Patch(ctx, rs1, fmAccessToken)
|
||||
wg.Done()
|
||||
}()
|
||||
go func() {
|
||||
_, _, _ = backend.Patch(ctx, rs2, fmRefreshToken)
|
||||
wg.Done()
|
||||
}()
|
||||
wg.Wait()
|
||||
|
||||
r, err := backend.Get(ctx, "type.googleapis.com/session.Session", "concurrent")
|
||||
require.NoError(t, err)
|
||||
data, err := r.Data.UnmarshalNew()
|
||||
require.NoError(t, err)
|
||||
s := data.(*session.Session)
|
||||
require.Equal(t, access, s.OauthToken.AccessToken)
|
||||
require.Equal(t, refresh, s.OauthToken.RefreshToken)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// truncateTimestamps truncates Timestamp messages to 1 µs precision.
|
||||
func truncateTimestamps(ts ...*timestamppb.Timestamp) {
|
||||
for _, t := range ts {
|
||||
t.Nanos = (t.Nanos / 1000) * 1000
|
||||
}
|
||||
}
|
||||
|
||||
func gomaxprocs() int {
|
||||
env := os.Getenv("GOMAXPROCS")
|
||||
if n, err := strconv.Atoi(env); err == nil {
|
||||
return n
|
||||
}
|
||||
return runtime.NumCPU()
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue