mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 10:56:28 +02:00
cache: support databroker option changes (#1294)
This commit is contained in:
parent
31205c0c29
commit
a1378c81f8
16 changed files with 408 additions and 179 deletions
97
cache/cache.go
vendored
97
cache/cache.go
vendored
|
@ -15,6 +15,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/directory"
|
"github.com/pomerium/pomerium/internal/directory"
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
@ -35,18 +36,7 @@ type Cache struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new cache service.
|
// New creates a new cache service.
|
||||||
func New(opts config.Options) (*Cache, error) {
|
func New(cfg *config.Config) (*Cache, error) {
|
||||||
if err := validate(opts); err != nil {
|
|
||||||
return nil, fmt.Errorf("cache: bad option: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
authenticator, err := identity.NewAuthenticator(opts.GetOauthOptions())
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("cache: failed to create authenticator: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
directoryProvider := directory.GetProvider(&opts)
|
|
||||||
|
|
||||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -56,7 +46,7 @@ func New(opts config.Options) (*Cache, error) {
|
||||||
// if we no longer register with that grpc Server
|
// if we no longer register with that grpc Server
|
||||||
localGRPCServer := grpc.NewServer()
|
localGRPCServer := grpc.NewServer()
|
||||||
|
|
||||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.Services)
|
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
||||||
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
|
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
|
||||||
|
|
||||||
localGRPCConnection, err := grpc.DialContext(
|
localGRPCConnection, err := grpc.DialContext(
|
||||||
|
@ -68,30 +58,33 @@ func New(opts config.Options) (*Cache, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts)
|
dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg)
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection)
|
|
||||||
|
|
||||||
manager := manager.New(
|
|
||||||
authenticator,
|
|
||||||
directoryProvider,
|
|
||||||
dataBrokerClient,
|
|
||||||
manager.WithGroupRefreshInterval(opts.RefreshDirectoryInterval),
|
|
||||||
manager.WithGroupRefreshTimeout(opts.RefreshDirectoryTimeout),
|
|
||||||
)
|
|
||||||
|
|
||||||
return &Cache{
|
|
||||||
dataBrokerServer: dataBrokerServer,
|
|
||||||
manager: manager,
|
|
||||||
|
|
||||||
|
c := &Cache{
|
||||||
|
dataBrokerServer: dataBrokerServer,
|
||||||
localListener: localListener,
|
localListener: localListener,
|
||||||
localGRPCServer: localGRPCServer,
|
localGRPCServer: localGRPCServer,
|
||||||
localGRPCConnection: localGRPCConnection,
|
localGRPCConnection: localGRPCConnection,
|
||||||
deprecatedCacheClusterDomain: opts.GetDataBrokerURL().Hostname(),
|
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
|
||||||
dataBrokerStorageType: opts.DataBrokerStorageType,
|
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
|
||||||
}, nil
|
}
|
||||||
|
|
||||||
|
err = c.update(cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConfigChange is called whenever configuration is changed.
|
||||||
|
func (c *Cache) OnConfigChange(cfg *config.Config) {
|
||||||
|
err := c.update(cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("cache: error updating configuration")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.dataBrokerServer.OnConfigChange(cfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register registers all the gRPC services with the given server.
|
// Register registers all the gRPC services with the given server.
|
||||||
|
@ -121,9 +114,45 @@ func (c *Cache) Run(ctx context.Context) error {
|
||||||
return t.Wait()
|
return t.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Cache) update(cfg *config.Config) error {
|
||||||
|
if err := validate(cfg.Options); err != nil {
|
||||||
|
return fmt.Errorf("cache: bad option: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cache: failed to create authenticator: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
directoryProvider := directory.GetProvider(directory.Options{
|
||||||
|
ServiceAccount: cfg.Options.ServiceAccount,
|
||||||
|
Provider: cfg.Options.Provider,
|
||||||
|
ProviderURL: cfg.Options.ProviderURL,
|
||||||
|
QPS: cfg.Options.QPS,
|
||||||
|
})
|
||||||
|
|
||||||
|
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||||
|
|
||||||
|
options := []manager.Option{
|
||||||
|
manager.WithAuthenticator(authenticator),
|
||||||
|
manager.WithDirectoryProvider(directoryProvider),
|
||||||
|
manager.WithDataBrokerClient(dataBrokerClient),
|
||||||
|
manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval),
|
||||||
|
manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout),
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.manager == nil {
|
||||||
|
c.manager = manager.New(options...)
|
||||||
|
} else {
|
||||||
|
c.manager.UpdateConfig(options...)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// validate checks that proper configuration settings are set to create
|
// validate checks that proper configuration settings are set to create
|
||||||
// a cache instance
|
// a cache instance
|
||||||
func validate(o config.Options) error {
|
func validate(o *config.Options) error {
|
||||||
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
|
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
|
||||||
}
|
}
|
||||||
|
|
2
cache/cache_test.go
vendored
2
cache/cache_test.go
vendored
|
@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.opts.Provider = "google"
|
tt.opts.Provider = "google"
|
||||||
_, err := New(tt.opts)
|
_, err := New(&config.Config{Options: &tt.opts})
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
65
cache/databroker.go
vendored
65
cache/databroker.go
vendored
|
@ -1,55 +1,38 @@
|
||||||
package cache
|
package cache
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
|
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
"github.com/pomerium/pomerium/internal/databroker"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// A DataBrokerServer implements the data broker service interface.
|
// A DataBrokerServer implements the data broker service interface.
|
||||||
type DataBrokerServer struct {
|
type DataBrokerServer struct {
|
||||||
databroker.DataBrokerServiceServer
|
*databroker.Server
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDataBrokerServer creates a new databroker service server.
|
// NewDataBrokerServer creates a new databroker service server.
|
||||||
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) {
|
func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer {
|
||||||
key, err := base64.StdEncoding.DecodeString(opts.SharedKey)
|
srv := &DataBrokerServer{}
|
||||||
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
srv.Server = databroker.New(srv.getOptions(cfg)...)
|
||||||
return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||||
}
|
return srv
|
||||||
|
}
|
||||||
caCertPool := x509.NewCertPool()
|
|
||||||
if caCert, err := ioutil.ReadFile(opts.DataBrokerStorageCAFile); err == nil {
|
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
|
||||||
caCertPool.AppendCertsFromPEM(caCert)
|
func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) {
|
||||||
} else {
|
srv.UpdateConfig(srv.getOptions(cfg)...)
|
||||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
}
|
||||||
}
|
|
||||||
tlsConfig := &tls.Config{
|
func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
|
||||||
RootCAs: caCertPool,
|
return []databroker.ServerOption{
|
||||||
// nolint: gosec
|
databroker.WithSharedKey(cfg.Options.SharedKey),
|
||||||
InsecureSkipVerify: opts.DataBrokerStorageCertSkipVerify,
|
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
|
||||||
}
|
databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString),
|
||||||
if opts.DataBrokerCertificate != nil {
|
databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile),
|
||||||
tlsConfig.Certificates = []tls.Certificate{*opts.DataBrokerCertificate}
|
databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate),
|
||||||
}
|
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
|
||||||
|
}
|
||||||
internalSrv := internal_databroker.New(
|
|
||||||
internal_databroker.WithSecret(key),
|
|
||||||
internal_databroker.WithStorageType(opts.DataBrokerStorageType),
|
|
||||||
internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString),
|
|
||||||
internal_databroker.WithStorageTLSConfig(tlsConfig),
|
|
||||||
)
|
|
||||||
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
|
|
||||||
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
|
||||||
return srv, nil
|
|
||||||
}
|
}
|
||||||
|
|
2
cache/databroker_test.go
vendored
2
cache/databroker_test.go
vendored
|
@ -27,7 +27,7 @@ func init() {
|
||||||
lis = bufconn.Listen(bufSize)
|
lis = bufconn.Listen(bufSize)
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
internalSrv := internal_databroker.New()
|
internalSrv := internal_databroker.New()
|
||||||
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
|
srv := &DataBrokerServer{Server: internalSrv}
|
||||||
databroker.RegisterDataBrokerServiceServer(s, srv)
|
databroker.RegisterDataBrokerServiceServer(s, srv)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
|
10
cache/memberlist_test.go
vendored
10
cache/memberlist_test.go
vendored
|
@ -15,10 +15,12 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCache_runMemberList(t *testing.T) {
|
func TestCache_runMemberList(t *testing.T) {
|
||||||
c, err := New(config.Options{
|
c, err := New(&config.Config{
|
||||||
SharedKey: cryptutil.NewBase64Key(),
|
Options: &config.Options{
|
||||||
DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"},
|
SharedKey: cryptutil.NewBase64Key(),
|
||||||
Provider: "google",
|
DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"},
|
||||||
|
Provider: "google",
|
||||||
|
},
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
|
|
@ -88,7 +88,7 @@ func Run(ctx context.Context, configFile string) error {
|
||||||
}
|
}
|
||||||
var cacheServer *cache.Cache
|
var cacheServer *cache.Cache
|
||||||
if config.IsCache(cfg.Options.Services) {
|
if config.IsCache(cfg.Options.Services) {
|
||||||
cacheServer, err = setupCache(cfg.Options, controlPlane)
|
cacheServer, err = setupCache(src, cfg, controlPlane)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -162,13 +162,15 @@ func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *control
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupCache(opt *config.Options, controlPlane *controlplane.Server) (*cache.Cache, error) {
|
func setupCache(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*cache.Cache, error) {
|
||||||
svc, err := cache.New(*opt)
|
svc, err := cache.New(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating config service: %w", err)
|
return nil, fmt.Errorf("error creating config service: %w", err)
|
||||||
}
|
}
|
||||||
svc.Register(controlPlane.GRPCServer)
|
svc.Register(controlPlane.GRPCServer)
|
||||||
log.Info().Msg("enabled cache service")
|
log.Info().Msg("enabled cache service")
|
||||||
|
src.OnConfigChange(svc.OnConfigChange)
|
||||||
|
svc.OnConfigChange(cfg)
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,11 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
@ -21,7 +25,9 @@ type serverConfig struct {
|
||||||
secret []byte
|
secret []byte
|
||||||
storageType string
|
storageType string
|
||||||
storageConnectionString string
|
storageConnectionString string
|
||||||
storageTLSConfig *tls.Config
|
storageCAFile string
|
||||||
|
storageCertSkipVerify bool
|
||||||
|
storageCertificate *tls.Certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
func newServerConfig(options ...ServerOption) *serverConfig {
|
func newServerConfig(options ...ServerOption) *serverConfig {
|
||||||
|
@ -54,10 +60,15 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithSecret sets the secret in the config.
|
// WithSharedKey sets the secret in the config.
|
||||||
func WithSecret(secret []byte) ServerOption {
|
func WithSharedKey(sharedKey string) ServerOption {
|
||||||
return func(cfg *serverConfig) {
|
return func(cfg *serverConfig) {
|
||||||
cfg.secret = secret
|
key, err := base64.StdEncoding.DecodeString(sharedKey)
|
||||||
|
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
||||||
|
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cfg.secret = key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,9 +86,23 @@ func WithStorageConnectionString(connStr string) ServerOption {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithStorageTLSConfig sets the tls config for connection to storage.
|
// WithStorageCAFile sets the CA file in the config.
|
||||||
func WithStorageTLSConfig(tlsConfig *tls.Config) ServerOption {
|
func WithStorageCAFile(filePath string) ServerOption {
|
||||||
return func(cfg *serverConfig) {
|
return func(cfg *serverConfig) {
|
||||||
cfg.storageTLSConfig = tlsConfig
|
cfg.storageCAFile = filePath
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStorageCertSkipVerify sets the storageCertSkipVerify in the config.
|
||||||
|
func WithStorageCertSkipVerify(storageCertSkipVerify bool) ServerOption {
|
||||||
|
return func(cfg *serverConfig) {
|
||||||
|
cfg.storageCertSkipVerify = storageCertSkipVerify
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithStorageCertificate sets the storageCertificate in the config.
|
||||||
|
func WithStorageCertificate(certificate *tls.Certificate) ServerOption {
|
||||||
|
return func(cfg *serverConfig) {
|
||||||
|
cfg.storageCertificate = certificate
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,7 +3,10 @@ package databroker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"io/ioutil"
|
||||||
"reflect"
|
"reflect"
|
||||||
"sort"
|
"sort"
|
||||||
"sync"
|
"sync"
|
||||||
|
@ -11,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/golang/protobuf/ptypes"
|
"github.com/golang/protobuf/ptypes"
|
||||||
"github.com/golang/protobuf/ptypes/empty"
|
"github.com/golang/protobuf/ptypes/empty"
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
|
@ -33,35 +37,39 @@ const (
|
||||||
syncBatchSize = 100
|
syncBatchSize = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// newUUID returns a new UUID. This make it easy to stub out in tests.
|
||||||
|
var newUUID = uuid.New
|
||||||
|
|
||||||
// Server implements the databroker service using an in memory database.
|
// Server implements the databroker service using an in memory database.
|
||||||
type Server struct {
|
type Server struct {
|
||||||
version string
|
cfg *serverConfig
|
||||||
cfg *serverConfig
|
log zerolog.Logger
|
||||||
log zerolog.Logger
|
|
||||||
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
version string
|
||||||
byType map[string]storage.Backend
|
byType map[string]storage.Backend
|
||||||
onTypechange *signal.Signal
|
onTypechange *signal.Signal
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new server.
|
// New creates a new server.
|
||||||
func New(options ...ServerOption) *Server {
|
func New(options ...ServerOption) *Server {
|
||||||
cfg := newServerConfig(options...)
|
|
||||||
srv := &Server{
|
srv := &Server{
|
||||||
version: uuid.New().String(),
|
log: log.With().Str("service", "databroker").Logger(),
|
||||||
cfg: cfg,
|
|
||||||
log: log.With().Str("service", "databroker").Logger(),
|
|
||||||
|
|
||||||
byType: make(map[string]storage.Backend),
|
byType: make(map[string]storage.Backend),
|
||||||
onTypechange: signal.New(),
|
onTypechange: signal.New(),
|
||||||
}
|
}
|
||||||
srv.initVersion()
|
srv.UpdateConfig(options...)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
|
ticker := time.NewTicker(time.Minute)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
for range ticker.C {
|
for range ticker.C {
|
||||||
|
srv.mu.RLock()
|
||||||
|
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
|
||||||
|
srv.mu.RUnlock()
|
||||||
|
|
||||||
var recordTypes []string
|
var recordTypes []string
|
||||||
srv.mu.RLock()
|
srv.mu.RLock()
|
||||||
for recordType := range srv.byType {
|
for recordType := range srv.byType {
|
||||||
|
@ -70,11 +78,11 @@ func New(options ...ServerOption) *Server {
|
||||||
srv.mu.RUnlock()
|
srv.mu.RUnlock()
|
||||||
|
|
||||||
for _, recordType := range recordTypes {
|
for _, recordType := range recordTypes {
|
||||||
db, err := srv.getDB(recordType)
|
db, _, err := srv.getDB(recordType, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
db.ClearDeleted(context.Background(), tm)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -82,7 +90,7 @@ func New(options ...ServerOption) *Server {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) initVersion() {
|
func (srv *Server) initVersion() {
|
||||||
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
|
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("failed to init server version")
|
log.Error().Err(err).Msg("failed to init server version")
|
||||||
return
|
return
|
||||||
|
@ -98,12 +106,36 @@ func (srv *Server) initVersion() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
srv.version = newUUID().String()
|
||||||
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
||||||
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
||||||
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateConfig updates the server with the new options.
|
||||||
|
func (srv *Server) UpdateConfig(options ...ServerOption) {
|
||||||
|
srv.mu.Lock()
|
||||||
|
defer srv.mu.Unlock()
|
||||||
|
|
||||||
|
cfg := newServerConfig(options...)
|
||||||
|
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
||||||
|
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
srv.cfg = cfg
|
||||||
|
|
||||||
|
for t, db := range srv.byType {
|
||||||
|
err := db.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warn().Err(err).Msg("databroker: error closing backend")
|
||||||
|
}
|
||||||
|
delete(srv.byType, t)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.initVersion()
|
||||||
|
}
|
||||||
|
|
||||||
// Delete deletes a record from the in-memory list.
|
// Delete deletes a record from the in-memory list.
|
||||||
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
||||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
||||||
|
@ -113,7 +145,7 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("delete")
|
Msg("delete")
|
||||||
|
|
||||||
db, err := srv.getDB(req.GetType())
|
db, _, err := srv.getDB(req.GetType(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -134,7 +166,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("get")
|
Msg("get")
|
||||||
|
|
||||||
db, err := srv.getDB(req.GetType())
|
db, _, err := srv.getDB(req.GetType(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -156,7 +188,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
Str("type", req.GetType()).
|
Str("type", req.GetType()).
|
||||||
Msg("get all")
|
Msg("get all")
|
||||||
|
|
||||||
db, err := srv.getDB(req.GetType())
|
db, version, err := srv.getDB(req.GetType(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -167,7 +199,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(all) == 0 {
|
if len(all) == 0 {
|
||||||
return &databroker.GetAllResponse{ServerVersion: srv.version}, nil
|
return &databroker.GetAllResponse{ServerVersion: version}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var recordVersion string
|
var recordVersion string
|
||||||
|
@ -182,7 +214,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
||||||
}
|
}
|
||||||
|
|
||||||
return &databroker.GetAllResponse{
|
return &databroker.GetAllResponse{
|
||||||
ServerVersion: srv.version,
|
ServerVersion: version,
|
||||||
RecordVersion: recordVersion,
|
RecordVersion: recordVersion,
|
||||||
Records: records,
|
Records: records,
|
||||||
}, nil
|
}, nil
|
||||||
|
@ -197,7 +229,7 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
||||||
Str("id", req.GetId()).
|
Str("id", req.GetId()).
|
||||||
Msg("set")
|
Msg("set")
|
||||||
|
|
||||||
db, err := srv.getDB(req.GetType())
|
db, version, err := srv.getDB(req.GetType(), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -210,11 +242,13 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
||||||
}
|
}
|
||||||
return &databroker.SetResponse{
|
return &databroker.SetResponse{
|
||||||
Record: record,
|
Record: record,
|
||||||
ServerVersion: srv.version,
|
ServerVersion: version,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
func (srv *Server) doSync(ctx context.Context,
|
||||||
|
serverVersion string, recordVersion *string,
|
||||||
|
db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||||
updated, err := db.List(ctx, *recordVersion)
|
updated, err := db.List(ctx, *recordVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -232,7 +266,7 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage
|
||||||
j = len(updated)
|
j = len(updated)
|
||||||
}
|
}
|
||||||
if err := stream.Send(&databroker.SyncResponse{
|
if err := stream.Send(&databroker.SyncResponse{
|
||||||
ServerVersion: srv.version,
|
ServerVersion: serverVersion,
|
||||||
Records: updated[i:j],
|
Records: updated[i:j],
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -251,34 +285,34 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
||||||
Str("record_version", req.GetRecordVersion()).
|
Str("record_version", req.GetRecordVersion()).
|
||||||
Msg("sync")
|
Msg("sync")
|
||||||
|
|
||||||
|
db, serverVersion, err := srv.getDB(req.GetType(), true)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
recordVersion := req.GetRecordVersion()
|
recordVersion := req.GetRecordVersion()
|
||||||
// reset record version if the server versions don't match
|
// reset record version if the server versions don't match
|
||||||
if req.GetServerVersion() != srv.version {
|
if req.GetServerVersion() != serverVersion {
|
||||||
recordVersion = ""
|
recordVersion = ""
|
||||||
// send the new server version to the client
|
// send the new server version to the client
|
||||||
err := stream.Send(&databroker.SyncResponse{
|
err := stream.Send(&databroker.SyncResponse{
|
||||||
ServerVersion: srv.version,
|
ServerVersion: serverVersion,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
db, err := srv.getDB(req.GetType())
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := stream.Context()
|
ctx := stream.Context()
|
||||||
ch := db.Watch(ctx)
|
ch := db.Watch(ctx)
|
||||||
|
|
||||||
// Do first sync, so we won't missed anything.
|
// Do first sync, so we won't missed anything.
|
||||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
for range ch {
|
for range ch {
|
||||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -335,29 +369,56 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
|
||||||
// double-checked locking:
|
// double-checked locking:
|
||||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||||
srv.mu.RLock()
|
if lock {
|
||||||
db := srv.byType[recordType]
|
srv.mu.RLock()
|
||||||
srv.mu.RUnlock()
|
}
|
||||||
|
db = srv.byType[recordType]
|
||||||
|
version = srv.version
|
||||||
|
if lock {
|
||||||
|
srv.mu.RUnlock()
|
||||||
|
}
|
||||||
if db == nil {
|
if db == nil {
|
||||||
srv.mu.Lock()
|
if lock {
|
||||||
|
srv.mu.Lock()
|
||||||
|
}
|
||||||
db = srv.byType[recordType]
|
db = srv.byType[recordType]
|
||||||
|
version = srv.version
|
||||||
var err error
|
var err error
|
||||||
if db == nil {
|
if db == nil {
|
||||||
db, err = srv.newDB(recordType)
|
db, err = srv.newDB(recordType)
|
||||||
srv.byType[recordType] = db
|
srv.byType[recordType] = db
|
||||||
}
|
}
|
||||||
srv.mu.Unlock()
|
if lock {
|
||||||
|
srv.mu.Unlock()
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return db, nil
|
return db, version, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||||
|
caCertPool := x509.NewCertPool()
|
||||||
|
if srv.cfg.storageCAFile != "" {
|
||||||
|
if caCert, err := ioutil.ReadFile(srv.cfg.storageCAFile); err == nil {
|
||||||
|
caCertPool.AppendCertsFromPEM(caCert)
|
||||||
|
} else {
|
||||||
|
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tlsConfig := &tls.Config{
|
||||||
|
RootCAs: caCertPool,
|
||||||
|
// nolint: gosec
|
||||||
|
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
||||||
|
}
|
||||||
|
if srv.cfg.storageCertificate != nil {
|
||||||
|
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
||||||
|
}
|
||||||
|
|
||||||
switch srv.cfg.storageType {
|
switch srv.cfg.storageType {
|
||||||
case config.StorageInMemoryName:
|
case config.StorageInMemoryName:
|
||||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||||
|
@ -366,7 +427,7 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||||
srv.cfg.storageConnectionString,
|
srv.cfg.storageConnectionString,
|
||||||
recordType,
|
recordType,
|
||||||
int64(srv.cfg.deletePermanentlyAfter.Seconds()),
|
int64(srv.cfg.deletePermanentlyAfter.Seconds()),
|
||||||
redis.WithTLSConfig(srv.cfg.storageTLSConfig),
|
redis.WithTLSConfig(tlsConfig),
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||||
|
|
|
@ -33,36 +33,46 @@ func newServer(cfg *serverConfig) *Server {
|
||||||
func TestServer_initVersion(t *testing.T) {
|
func TestServer_initVersion(t *testing.T) {
|
||||||
cfg := newServerConfig()
|
cfg := newServerConfig()
|
||||||
t.Run("nil db", func(t *testing.T) {
|
t.Run("nil db", func(t *testing.T) {
|
||||||
|
srvVersion := uuid.New()
|
||||||
|
oldNewUUID := newUUID
|
||||||
|
newUUID = func() uuid.UUID {
|
||||||
|
return srvVersion
|
||||||
|
}
|
||||||
|
defer func() { newUUID = oldNewUUID }()
|
||||||
|
|
||||||
srv := newServer(cfg)
|
srv := newServer(cfg)
|
||||||
srvVersion := uuid.New().String()
|
|
||||||
srv.version = srvVersion
|
|
||||||
srv.byType[recordTypeServerVersion] = nil
|
srv.byType[recordTypeServerVersion] = nil
|
||||||
srv.initVersion()
|
srv.initVersion()
|
||||||
assert.Equal(t, srvVersion, srv.version)
|
assert.Equal(t, srvVersion.String(), srv.version)
|
||||||
})
|
})
|
||||||
t.Run("new server with random version", func(t *testing.T) {
|
t.Run("new server with random version", func(t *testing.T) {
|
||||||
|
srvVersion := uuid.New()
|
||||||
|
oldNewUUID := newUUID
|
||||||
|
newUUID = func() uuid.UUID {
|
||||||
|
return srvVersion
|
||||||
|
}
|
||||||
|
defer func() { newUUID = oldNewUUID }()
|
||||||
|
|
||||||
srv := newServer(cfg)
|
srv := newServer(cfg)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := srv.getDB(recordTypeServerVersion)
|
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, err := db.Get(ctx, serverVersionKey)
|
r, err := db.Get(ctx, serverVersionKey)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
assert.Nil(t, r)
|
assert.Nil(t, r)
|
||||||
srvVersion := uuid.New().String()
|
|
||||||
srv.version = srvVersion
|
|
||||||
srv.initVersion()
|
srv.initVersion()
|
||||||
assert.Equal(t, srvVersion, srv.version)
|
assert.Equal(t, srvVersion.String(), srv.version)
|
||||||
r, err = db.Get(ctx, serverVersionKey)
|
r, err = db.Get(ctx, serverVersionKey)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.NotNil(t, r)
|
assert.NotNil(t, r)
|
||||||
var sv databroker.ServerVersion
|
var sv databroker.ServerVersion
|
||||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||||
assert.Equal(t, srvVersion, sv.Version)
|
assert.Equal(t, srvVersion.String(), sv.Version)
|
||||||
})
|
})
|
||||||
t.Run("init version twice should get the same version", func(t *testing.T) {
|
t.Run("init version twice should get the same version", func(t *testing.T) {
|
||||||
srv := newServer(cfg)
|
srv := newServer(cfg)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
db, err := srv.getDB(recordTypeServerVersion)
|
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
r, err := db.Get(ctx, serverVersionKey)
|
r, err := db.Get(ctx, serverVersionKey)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
|
|
|
@ -4,8 +4,10 @@ package directory
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
|
||||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||||
"github.com/pomerium/pomerium/internal/directory/github"
|
"github.com/pomerium/pomerium/internal/directory/github"
|
||||||
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
||||||
|
@ -27,8 +29,34 @@ type Provider interface {
|
||||||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Options are the options specific to the provider.
|
||||||
|
type Options struct {
|
||||||
|
ServiceAccount string
|
||||||
|
Provider string
|
||||||
|
ProviderURL string
|
||||||
|
QPS float64
|
||||||
|
}
|
||||||
|
|
||||||
|
var globalProvider = struct {
|
||||||
|
sync.Mutex
|
||||||
|
provider Provider
|
||||||
|
options Options
|
||||||
|
}{}
|
||||||
|
|
||||||
// GetProvider gets the provider for the given options.
|
// GetProvider gets the provider for the given options.
|
||||||
func GetProvider(options *config.Options) Provider {
|
func GetProvider(options Options) (provider Provider) {
|
||||||
|
globalProvider.Lock()
|
||||||
|
defer globalProvider.Unlock()
|
||||||
|
|
||||||
|
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
|
||||||
|
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
|
||||||
|
return globalProvider.provider
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
globalProvider.provider = provider
|
||||||
|
globalProvider.options = options
|
||||||
|
}()
|
||||||
|
|
||||||
switch options.Provider {
|
switch options.Provider {
|
||||||
case azure.Name:
|
case azure.Name:
|
||||||
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
|
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
package manager
|
package manager
|
||||||
|
|
||||||
import "time"
|
import (
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/directory"
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
defaultGroupRefreshInterval = 10 * time.Minute
|
defaultGroupRefreshInterval = 10 * time.Minute
|
||||||
|
@ -10,6 +16,9 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type config struct {
|
type config struct {
|
||||||
|
authenticator Authenticator
|
||||||
|
directory directory.Provider
|
||||||
|
dataBrokerClient databroker.DataBrokerServiceClient
|
||||||
groupRefreshInterval time.Duration
|
groupRefreshInterval time.Duration
|
||||||
groupRefreshTimeout time.Duration
|
groupRefreshTimeout time.Duration
|
||||||
sessionRefreshGracePeriod time.Duration
|
sessionRefreshGracePeriod time.Duration
|
||||||
|
@ -31,6 +40,27 @@ func newConfig(options ...Option) *config {
|
||||||
// An Option customizes the configuration used for the identity manager.
|
// An Option customizes the configuration used for the identity manager.
|
||||||
type Option func(*config)
|
type Option func(*config)
|
||||||
|
|
||||||
|
// WithAuthenticator sets the authenticator in the config.
|
||||||
|
func WithAuthenticator(authenticator Authenticator) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.authenticator = authenticator
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDirectoryProvider sets the directory provider in the config.
|
||||||
|
func WithDirectoryProvider(directoryProvider directory.Provider) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.directory = directoryProvider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithDataBrokerClient sets the databroker client in the config.
|
||||||
|
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
|
||||||
|
return func(cfg *config) {
|
||||||
|
cfg.dataBrokerClient = dataBrokerClient
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
|
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
|
||||||
func WithGroupRefreshInterval(interval time.Duration) Option {
|
func WithGroupRefreshInterval(interval time.Duration) Option {
|
||||||
return func(cfg *config) {
|
return func(cfg *config) {
|
||||||
|
@ -58,3 +88,21 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option {
|
||||||
cfg.sessionRefreshCoolOffDuration = dur
|
cfg.sessionRefreshCoolOffDuration = dur
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type atomicConfig struct {
|
||||||
|
value atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
|
func newAtomicConfig(cfg *config) *atomicConfig {
|
||||||
|
ac := new(atomicConfig)
|
||||||
|
ac.Store(cfg)
|
||||||
|
return ac
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *atomicConfig) Load() *config {
|
||||||
|
return ac.value.Load().(*config)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ac *atomicConfig) Store(cfg *config) {
|
||||||
|
ac.value.Store(cfg)
|
||||||
|
}
|
||||||
|
|
|
@ -42,11 +42,8 @@ type (
|
||||||
|
|
||||||
// A Manager refreshes identity information using session and user data.
|
// A Manager refreshes identity information using session and user data.
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
cfg *config
|
cfg *atomicConfig
|
||||||
authenticator Authenticator
|
log zerolog.Logger
|
||||||
directory directory.Provider
|
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient
|
|
||||||
log zerolog.Logger
|
|
||||||
|
|
||||||
sessions sessionCollection
|
sessions sessionCollection
|
||||||
sessionScheduler *scheduler.Scheduler
|
sessionScheduler *scheduler.Scheduler
|
||||||
|
@ -67,17 +64,11 @@ type Manager struct {
|
||||||
|
|
||||||
// New creates a new identity manager.
|
// New creates a new identity manager.
|
||||||
func New(
|
func New(
|
||||||
authenticator Authenticator,
|
|
||||||
directoryProvider directory.Provider,
|
|
||||||
dataBrokerClient databroker.DataBrokerServiceClient,
|
|
||||||
options ...Option,
|
options ...Option,
|
||||||
) *Manager {
|
) *Manager {
|
||||||
mgr := &Manager{
|
mgr := &Manager{
|
||||||
cfg: newConfig(options...),
|
cfg: newAtomicConfig(newConfig()),
|
||||||
authenticator: authenticator,
|
log: log.With().Str("service", "identity_manager").Logger(),
|
||||||
directory: directoryProvider,
|
|
||||||
dataBrokerClient: dataBrokerClient,
|
|
||||||
log: log.With().Str("service", "identity_manager").Logger(),
|
|
||||||
|
|
||||||
sessions: sessionCollection{
|
sessions: sessionCollection{
|
||||||
BTree: btree.New(8),
|
BTree: btree.New(8),
|
||||||
|
@ -88,9 +79,15 @@ func New(
|
||||||
},
|
},
|
||||||
userScheduler: scheduler.New(),
|
userScheduler: scheduler.New(),
|
||||||
}
|
}
|
||||||
|
mgr.UpdateConfig(options...)
|
||||||
return mgr
|
return mgr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateConfig updates the manager with the new options.
|
||||||
|
func (mgr *Manager) UpdateConfig(options ...Option) {
|
||||||
|
mgr.cfg.Store(newConfig(options...))
|
||||||
|
}
|
||||||
|
|
||||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||||
func (mgr *Manager) Run(ctx context.Context) error {
|
func (mgr *Manager) Run(ctx context.Context) error {
|
||||||
err := mgr.initDirectoryGroups(ctx)
|
err := mgr.initDirectoryGroups(ctx)
|
||||||
|
@ -169,7 +166,7 @@ func (mgr *Manager) refreshLoop(
|
||||||
// refresh groups
|
// refresh groups
|
||||||
if mgr.directoryNextRefresh.Before(now) {
|
if mgr.directoryNextRefresh.Before(now) {
|
||||||
mgr.refreshDirectoryUserGroups(ctx)
|
mgr.refreshDirectoryUserGroups(ctx)
|
||||||
mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
|
mgr.directoryNextRefresh = now.Add(mgr.cfg.Load().groupRefreshInterval)
|
||||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||||
nextTime = mgr.directoryNextRefresh
|
nextTime = mgr.directoryNextRefresh
|
||||||
}
|
}
|
||||||
|
@ -211,10 +208,10 @@ func (mgr *Manager) refreshLoop(
|
||||||
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
||||||
mgr.log.Info().Msg("refreshing directory users")
|
mgr.log.Info().Msg("refreshing directory users")
|
||||||
|
|
||||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout)
|
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
|
||||||
defer clearTimeout()
|
defer clearTimeout()
|
||||||
|
|
||||||
directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx)
|
directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
|
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
|
||||||
return
|
return
|
||||||
|
@ -238,7 +235,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
||||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: newDG.GetId(),
|
Id: newDG.GetId(),
|
||||||
Data: any,
|
Data: any,
|
||||||
|
@ -258,7 +255,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
||||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: curDG.GetId(),
|
Id: curDG.GetId(),
|
||||||
})
|
})
|
||||||
|
@ -284,7 +281,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
||||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: newDU.GetId(),
|
Id: newDU.GetId(),
|
||||||
Data: any,
|
Data: any,
|
||||||
|
@ -304,7 +301,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
||||||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
Id: curDU.GetId(),
|
Id: curDU.GetId(),
|
||||||
})
|
})
|
||||||
|
@ -349,7 +346,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
newToken, err := mgr.authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||||
if isTemporaryError(err) {
|
if isTemporaryError(err) {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("user_id", s.GetUserId()).
|
Str("user_id", s.GetUserId()).
|
||||||
|
@ -366,7 +363,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
||||||
}
|
}
|
||||||
s.OauthToken = ToOAuthToken(newToken)
|
s.OauthToken = ToOAuthToken(newToken)
|
||||||
|
|
||||||
res, err := session.Set(ctx, mgr.dataBrokerClient, s.Session)
|
res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("user_id", s.GetUserId()).
|
Str("user_id", s.GetUserId()).
|
||||||
|
@ -401,7 +398,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := mgr.authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
||||||
if isTemporaryError(err) {
|
if isTemporaryError(err) {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("user_id", s.GetUserId()).
|
Str("user_id", s.GetUserId()).
|
||||||
|
@ -417,7 +414,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
record, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("user_id", s.GetUserId()).
|
Str("user_id", s.GetUserId()).
|
||||||
|
@ -438,7 +435,7 @@ func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -474,7 +471,7 @@ func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -510,7 +507,7 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -541,7 +538,7 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
ServerVersion: mgr.directoryUsersServerVersion,
|
ServerVersion: mgr.directoryUsersServerVersion,
|
||||||
RecordVersion: mgr.directoryUsersRecordVersion,
|
RecordVersion: mgr.directoryUsersRecordVersion,
|
||||||
|
@ -579,7 +576,7 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -610,7 +607,7 @@ func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *director
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||||
Type: any.GetTypeUrl(),
|
Type: any.GetTypeUrl(),
|
||||||
ServerVersion: mgr.directoryGroupsServerVersion,
|
ServerVersion: mgr.directoryGroupsServerVersion,
|
||||||
RecordVersion: mgr.directoryGroupsRecordVersion,
|
RecordVersion: mgr.directoryGroupsRecordVersion,
|
||||||
|
@ -652,8 +649,8 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, msg sessionMessage) {
|
||||||
// update session
|
// update session
|
||||||
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
|
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
|
||||||
s.lastRefresh = time.Now()
|
s.lastRefresh = time.Now()
|
||||||
s.gracePeriod = mgr.cfg.sessionRefreshGracePeriod
|
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||||
s.coolOffDuration = mgr.cfg.sessionRefreshCoolOffDuration
|
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||||
s.Session = msg.session
|
s.Session = msg.session
|
||||||
mgr.sessions.ReplaceOrInsert(s)
|
mgr.sessions.ReplaceOrInsert(s)
|
||||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||||
|
@ -676,7 +673,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
|
||||||
// only reset the refresh time if this is an existing user
|
// only reset the refresh time if this is an existing user
|
||||||
u.lastRefresh = time.Now()
|
u.lastRefresh = time.Now()
|
||||||
}
|
}
|
||||||
u.refreshInterval = mgr.cfg.groupRefreshInterval
|
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
|
||||||
u.User = msg.user
|
u.User = msg.user
|
||||||
mgr.users.ReplaceOrInsert(u)
|
mgr.users.ReplaceOrInsert(u)
|
||||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||||
|
@ -697,7 +694,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
_, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("user_id", pbSession.GetUserId()).
|
Str("user_id", pbSession.GetUserId()).
|
||||||
|
@ -707,7 +704,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||||
err := session.Delete(ctx, mgr.dataBrokerClient, pbSession.GetId())
|
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
mgr.log.Error().Err(err).
|
mgr.log.Error().Err(err).
|
||||||
Str("session_id", pbSession.GetId()).
|
Str("session_id", pbSession.GetId()).
|
||||||
|
|
|
@ -52,6 +52,9 @@ type DB struct {
|
||||||
byVersion *btree.BTree
|
byVersion *btree.BTree
|
||||||
deletedIDs []string
|
deletedIDs []string
|
||||||
onchange *signal.Signal
|
onchange *signal.Signal
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
closed chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDB creates a new in-memory database for the given record type.
|
// NewDB creates a new in-memory database for the given record type.
|
||||||
|
@ -62,6 +65,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
|
||||||
byID: btree.New(btreeDegree),
|
byID: btree.New(btreeDegree),
|
||||||
byVersion: btree.New(btreeDegree),
|
byVersion: btree.New(btreeDegree),
|
||||||
onchange: s,
|
onchange: s,
|
||||||
|
closed: make(chan struct{}),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -84,6 +88,14 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
||||||
db.deletedIDs = remaining
|
db.deletedIDs = remaining
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the database. Any watchers will be closed.
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
db.closeOnce.Do(func() {
|
||||||
|
close(db.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Delete marks a record as deleted.
|
// Delete marks a record as deleted.
|
||||||
func (db *DB) Delete(_ context.Context, id string) error {
|
func (db *DB) Delete(_ context.Context, id string) error {
|
||||||
defer db.onchange.Broadcast()
|
defer db.onchange.Broadcast()
|
||||||
|
@ -140,7 +152,10 @@ func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
||||||
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
||||||
ch := db.onchange.Bind()
|
ch := db.onchange.Bind()
|
||||||
go func() {
|
go func() {
|
||||||
<-ctx.Done()
|
select {
|
||||||
|
case <-db.closed:
|
||||||
|
case <-ctx.Done():
|
||||||
|
}
|
||||||
close(ch)
|
close(ch)
|
||||||
db.onchange.Unbind(ch)
|
db.onchange.Unbind(ch)
|
||||||
}()
|
}()
|
||||||
|
|
|
@ -40,6 +40,9 @@ type DB struct {
|
||||||
deletedSet string
|
deletedSet string
|
||||||
tlsConfig *tls.Config
|
tlsConfig *tls.Config
|
||||||
notifyChMu sync.Mutex
|
notifyChMu sync.Mutex
|
||||||
|
|
||||||
|
closeOnce sync.Once
|
||||||
|
closed chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
// New returns new DB instance.
|
// New returns new DB instance.
|
||||||
|
@ -50,6 +53,7 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
|
||||||
versionSet: recordType + "_version_set",
|
versionSet: recordType + "_version_set",
|
||||||
deletedSet: recordType + "_deleted_set",
|
deletedSet: recordType + "_deleted_set",
|
||||||
lastVersionKey: recordType + "_last_version",
|
lastVersionKey: recordType + "_last_version",
|
||||||
|
closed: make(chan struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, o := range opts {
|
for _, o := range opts {
|
||||||
|
@ -79,6 +83,14 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
|
||||||
return db, nil
|
return db, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the redis db connection.
|
||||||
|
func (db *DB) Close() error {
|
||||||
|
db.closeOnce.Do(func() {
|
||||||
|
close(db.closed)
|
||||||
|
})
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Put sets new record for given id with input data.
|
// Put sets new record for given id with input data.
|
||||||
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
|
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
|
||||||
c := db.pool.Get()
|
c := db.pool.Get()
|
||||||
|
@ -259,7 +271,10 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
|
case <-db.closed:
|
||||||
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
switch v := psc.Receive().(type) {
|
switch v := psc.Receive().(type) {
|
||||||
|
@ -271,12 +286,17 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
|
||||||
}
|
}
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
case <-db.closed:
|
||||||
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel")
|
log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel")
|
||||||
return
|
return
|
||||||
default:
|
default:
|
||||||
db.notifyChMu.Lock()
|
db.notifyChMu.Lock()
|
||||||
select {
|
select {
|
||||||
|
case <-db.closed:
|
||||||
|
db.notifyChMu.Unlock()
|
||||||
|
return
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
db.notifyChMu.Unlock()
|
db.notifyChMu.Unlock()
|
||||||
log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel")
|
log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel")
|
||||||
|
@ -313,6 +333,7 @@ func (db *DB) watch(ctx context.Context, ch chan struct{}) {
|
||||||
db.doNotifyLoop(ctx, ch)
|
db.doNotifyLoop(ctx, ch)
|
||||||
}()
|
}()
|
||||||
select {
|
select {
|
||||||
|
case <-db.closed:
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
case <-done:
|
case <-done:
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,9 @@ import (
|
||||||
|
|
||||||
// Backend is the interface required for a storage backend.
|
// Backend is the interface required for a storage backend.
|
||||||
type Backend interface {
|
type Backend interface {
|
||||||
|
// Close closes the backend.
|
||||||
|
Close() error
|
||||||
|
|
||||||
// Put is used to insert or update a record.
|
// Put is used to insert or update a record.
|
||||||
Put(ctx context.Context, id string, data *anypb.Any) error
|
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||||
|
|
||||||
|
|
|
@ -4,8 +4,9 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
|
||||||
"google.golang.org/protobuf/types/known/anypb"
|
"google.golang.org/protobuf/types/known/anypb"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||||
)
|
)
|
||||||
|
|
||||||
type mockBackend struct {
|
type mockBackend struct {
|
||||||
|
@ -18,6 +19,10 @@ type mockBackend struct {
|
||||||
watch func(ctx context.Context) <-chan struct{}
|
watch func(ctx context.Context) <-chan struct{}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *mockBackend) Close() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||||
return m.put(ctx, id, data)
|
return m.put(ctx, id, data)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue