cache: support databroker option changes (#1294)

This commit is contained in:
Caleb Doxsey 2020-08-18 07:27:20 -06:00 committed by GitHub
parent 31205c0c29
commit a1378c81f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 408 additions and 179 deletions

95
cache/cache.go vendored
View file

@ -15,6 +15,7 @@ import (
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/identity/manager"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
@ -35,18 +36,7 @@ type Cache struct {
}
// New creates a new cache service.
func New(opts config.Options) (*Cache, error) {
if err := validate(opts); err != nil {
return nil, fmt.Errorf("cache: bad option: %w", err)
}
authenticator, err := identity.NewAuthenticator(opts.GetOauthOptions())
if err != nil {
return nil, fmt.Errorf("cache: failed to create authenticator: %w", err)
}
directoryProvider := directory.GetProvider(&opts)
func New(cfg *config.Config) (*Cache, error) {
localListener, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return nil, err
@ -56,7 +46,7 @@ func New(opts config.Options) (*Cache, error) {
// if we no longer register with that grpc Server
localGRPCServer := grpc.NewServer()
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.Services)
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
localGRPCConnection, err := grpc.DialContext(
@ -68,30 +58,33 @@ func New(opts config.Options) (*Cache, error) {
return nil, err
}
dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts)
if err != nil {
return nil, err
}
dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection)
dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg)
manager := manager.New(
authenticator,
directoryProvider,
dataBrokerClient,
manager.WithGroupRefreshInterval(opts.RefreshDirectoryInterval),
manager.WithGroupRefreshTimeout(opts.RefreshDirectoryTimeout),
)
return &Cache{
c := &Cache{
dataBrokerServer: dataBrokerServer,
manager: manager,
localListener: localListener,
localGRPCServer: localGRPCServer,
localGRPCConnection: localGRPCConnection,
deprecatedCacheClusterDomain: opts.GetDataBrokerURL().Hostname(),
dataBrokerStorageType: opts.DataBrokerStorageType,
}, nil
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
}
err = c.update(cfg)
if err != nil {
return nil, err
}
return c, nil
}
// OnConfigChange is called whenever configuration is changed.
func (c *Cache) OnConfigChange(cfg *config.Config) {
err := c.update(cfg)
if err != nil {
log.Error().Err(err).Msg("cache: error updating configuration")
}
c.dataBrokerServer.OnConfigChange(cfg)
}
// Register registers all the gRPC services with the given server.
@ -121,9 +114,45 @@ func (c *Cache) Run(ctx context.Context) error {
return t.Wait()
}
func (c *Cache) update(cfg *config.Config) error {
if err := validate(cfg.Options); err != nil {
return fmt.Errorf("cache: bad option: %w", err)
}
authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions())
if err != nil {
return fmt.Errorf("cache: failed to create authenticator: %w", err)
}
directoryProvider := directory.GetProvider(directory.Options{
ServiceAccount: cfg.Options.ServiceAccount,
Provider: cfg.Options.Provider,
ProviderURL: cfg.Options.ProviderURL,
QPS: cfg.Options.QPS,
})
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
options := []manager.Option{
manager.WithAuthenticator(authenticator),
manager.WithDirectoryProvider(directoryProvider),
manager.WithDataBrokerClient(dataBrokerClient),
manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval),
manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout),
}
if c.manager == nil {
c.manager = manager.New(options...)
} else {
c.manager.UpdateConfig(options...)
}
return nil
}
// validate checks that proper configuration settings are set to create
// a cache instance
func validate(o config.Options) error {
func validate(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
}

2
cache/cache_test.go vendored
View file

@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.opts.Provider = "google"
_, err := New(tt.opts)
_, err := New(&config.Config{Options: &tt.opts})
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return

65
cache/databroker.go vendored
View file

@ -1,55 +1,38 @@
package cache
import (
"crypto/tls"
"crypto/x509"
"encoding/base64"
"fmt"
"io/ioutil"
"google.golang.org/grpc"
"github.com/pomerium/pomerium/config"
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"github.com/pomerium/pomerium/internal/databroker"
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// A DataBrokerServer implements the data broker service interface.
type DataBrokerServer struct {
databroker.DataBrokerServiceServer
*databroker.Server
}
// NewDataBrokerServer creates a new databroker service server.
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) {
key, err := base64.StdEncoding.DecodeString(opts.SharedKey)
if err != nil || len(key) != cryptutil.DefaultKeySize {
return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
}
caCertPool := x509.NewCertPool()
if caCert, err := ioutil.ReadFile(opts.DataBrokerStorageCAFile); err == nil {
caCertPool.AppendCertsFromPEM(caCert)
} else {
log.Warn().Err(err).Msg("failed to read databroker CA file")
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
// nolint: gosec
InsecureSkipVerify: opts.DataBrokerStorageCertSkipVerify,
}
if opts.DataBrokerCertificate != nil {
tlsConfig.Certificates = []tls.Certificate{*opts.DataBrokerCertificate}
}
internalSrv := internal_databroker.New(
internal_databroker.WithSecret(key),
internal_databroker.WithStorageType(opts.DataBrokerStorageType),
internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString),
internal_databroker.WithStorageTLSConfig(tlsConfig),
)
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
return srv, nil
func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer {
srv := &DataBrokerServer{}
srv.Server = databroker.New(srv.getOptions(cfg)...)
databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv)
return srv
}
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) {
srv.UpdateConfig(srv.getOptions(cfg)...)
}
func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
return []databroker.ServerOption{
databroker.WithSharedKey(cfg.Options.SharedKey),
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString),
databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile),
databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate),
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
}
}

View file

@ -27,7 +27,7 @@ func init() {
lis = bufconn.Listen(bufSize)
s := grpc.NewServer()
internalSrv := internal_databroker.New()
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
srv := &DataBrokerServer{Server: internalSrv}
databroker.RegisterDataBrokerServiceServer(s, srv)
go func() {

View file

@ -15,10 +15,12 @@ import (
)
func TestCache_runMemberList(t *testing.T) {
c, err := New(config.Options{
c, err := New(&config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"},
Provider: "google",
},
})
require.NoError(t, err)

View file

@ -88,7 +88,7 @@ func Run(ctx context.Context, configFile string) error {
}
var cacheServer *cache.Cache
if config.IsCache(cfg.Options.Services) {
cacheServer, err = setupCache(cfg.Options, controlPlane)
cacheServer, err = setupCache(src, cfg, controlPlane)
if err != nil {
return err
}
@ -162,13 +162,15 @@ func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *control
return svc, nil
}
func setupCache(opt *config.Options, controlPlane *controlplane.Server) (*cache.Cache, error) {
svc, err := cache.New(*opt)
func setupCache(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*cache.Cache, error) {
svc, err := cache.New(cfg)
if err != nil {
return nil, fmt.Errorf("error creating config service: %w", err)
}
svc.Register(controlPlane.GRPCServer)
log.Info().Msg("enabled cache service")
src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(cfg)
return svc, nil
}

View file

@ -2,7 +2,11 @@ package databroker
import (
"crypto/tls"
"encoding/base64"
"time"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
var (
@ -21,7 +25,9 @@ type serverConfig struct {
secret []byte
storageType string
storageConnectionString string
storageTLSConfig *tls.Config
storageCAFile string
storageCertSkipVerify bool
storageCertificate *tls.Certificate
}
func newServerConfig(options ...ServerOption) *serverConfig {
@ -54,10 +60,15 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
}
}
// WithSecret sets the secret in the config.
func WithSecret(secret []byte) ServerOption {
// WithSharedKey sets the secret in the config.
func WithSharedKey(sharedKey string) ServerOption {
return func(cfg *serverConfig) {
cfg.secret = secret
key, err := base64.StdEncoding.DecodeString(sharedKey)
if err != nil || len(key) != cryptutil.DefaultKeySize {
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
return
}
cfg.secret = key
}
}
@ -75,9 +86,23 @@ func WithStorageConnectionString(connStr string) ServerOption {
}
}
// WithStorageTLSConfig sets the tls config for connection to storage.
func WithStorageTLSConfig(tlsConfig *tls.Config) ServerOption {
// WithStorageCAFile sets the CA file in the config.
func WithStorageCAFile(filePath string) ServerOption {
return func(cfg *serverConfig) {
cfg.storageTLSConfig = tlsConfig
cfg.storageCAFile = filePath
}
}
// WithStorageCertSkipVerify sets the storageCertSkipVerify in the config.
func WithStorageCertSkipVerify(storageCertSkipVerify bool) ServerOption {
return func(cfg *serverConfig) {
cfg.storageCertSkipVerify = storageCertSkipVerify
}
}
// WithStorageCertificate sets the storageCertificate in the config.
func WithStorageCertificate(certificate *tls.Certificate) ServerOption {
return func(cfg *serverConfig) {
cfg.storageCertificate = certificate
}
}

View file

@ -3,7 +3,10 @@ package databroker
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"io/ioutil"
"reflect"
"sort"
"sync"
@ -11,6 +14,7 @@ import (
"github.com/golang/protobuf/ptypes"
"github.com/golang/protobuf/ptypes/empty"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
"github.com/rs/zerolog"
"google.golang.org/grpc/codes"
@ -33,35 +37,39 @@ const (
syncBatchSize = 100
)
// newUUID returns a new UUID. This make it easy to stub out in tests.
var newUUID = uuid.New
// Server implements the databroker service using an in memory database.
type Server struct {
version string
cfg *serverConfig
log zerolog.Logger
mu sync.RWMutex
version string
byType map[string]storage.Backend
onTypechange *signal.Signal
}
// New creates a new server.
func New(options ...ServerOption) *Server {
cfg := newServerConfig(options...)
srv := &Server{
version: uuid.New().String(),
cfg: cfg,
log: log.With().Str("service", "databroker").Logger(),
byType: make(map[string]storage.Backend),
onTypechange: signal.New(),
}
srv.initVersion()
srv.UpdateConfig(options...)
go func() {
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
ticker := time.NewTicker(time.Minute)
defer ticker.Stop()
for range ticker.C {
srv.mu.RLock()
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
srv.mu.RUnlock()
var recordTypes []string
srv.mu.RLock()
for recordType := range srv.byType {
@ -70,11 +78,11 @@ func New(options ...ServerOption) *Server {
srv.mu.RUnlock()
for _, recordType := range recordTypes {
db, err := srv.getDB(recordType)
db, _, err := srv.getDB(recordType, true)
if err != nil {
continue
}
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
db.ClearDeleted(context.Background(), tm)
}
}
}()
@ -82,7 +90,7 @@ func New(options ...ServerOption) *Server {
}
func (srv *Server) initVersion() {
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
if err != nil {
log.Error().Err(err).Msg("failed to init server version")
return
@ -98,12 +106,36 @@ func (srv *Server) initVersion() {
return
}
srv.version = newUUID().String()
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
srv.log.Warn().Err(err).Msg("failed to save server version.")
}
}
// UpdateConfig updates the server with the new options.
func (srv *Server) UpdateConfig(options ...ServerOption) {
srv.mu.Lock()
defer srv.mu.Unlock()
cfg := newServerConfig(options...)
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
return
}
srv.cfg = cfg
for t, db := range srv.byType {
err := db.Close()
if err != nil {
log.Warn().Err(err).Msg("databroker: error closing backend")
}
delete(srv.byType, t)
}
srv.initVersion()
}
// Delete deletes a record from the in-memory list.
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
@ -113,7 +145,7 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
Str("id", req.GetId()).
Msg("delete")
db, err := srv.getDB(req.GetType())
db, _, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
@ -134,7 +166,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
Str("id", req.GetId()).
Msg("get")
db, err := srv.getDB(req.GetType())
db, _, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
@ -156,7 +188,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
Str("type", req.GetType()).
Msg("get all")
db, err := srv.getDB(req.GetType())
db, version, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
@ -167,7 +199,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
}
if len(all) == 0 {
return &databroker.GetAllResponse{ServerVersion: srv.version}, nil
return &databroker.GetAllResponse{ServerVersion: version}, nil
}
var recordVersion string
@ -182,7 +214,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
}
return &databroker.GetAllResponse{
ServerVersion: srv.version,
ServerVersion: version,
RecordVersion: recordVersion,
Records: records,
}, nil
@ -197,7 +229,7 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
Str("id", req.GetId()).
Msg("set")
db, err := srv.getDB(req.GetType())
db, version, err := srv.getDB(req.GetType(), true)
if err != nil {
return nil, err
}
@ -210,11 +242,13 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
}
return &databroker.SetResponse{
Record: record,
ServerVersion: srv.version,
ServerVersion: version,
}, nil
}
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
func (srv *Server) doSync(ctx context.Context,
serverVersion string, recordVersion *string,
db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
updated, err := db.List(ctx, *recordVersion)
if err != nil {
return err
@ -232,7 +266,7 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage
j = len(updated)
}
if err := stream.Send(&databroker.SyncResponse{
ServerVersion: srv.version,
ServerVersion: serverVersion,
Records: updated[i:j],
}); err != nil {
return err
@ -251,34 +285,34 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
Str("record_version", req.GetRecordVersion()).
Msg("sync")
db, serverVersion, err := srv.getDB(req.GetType(), true)
if err != nil {
return err
}
recordVersion := req.GetRecordVersion()
// reset record version if the server versions don't match
if req.GetServerVersion() != srv.version {
if req.GetServerVersion() != serverVersion {
recordVersion = ""
// send the new server version to the client
err := stream.Send(&databroker.SyncResponse{
ServerVersion: srv.version,
ServerVersion: serverVersion,
})
if err != nil {
return err
}
}
db, err := srv.getDB(req.GetType())
if err != nil {
return err
}
ctx := stream.Context()
ch := db.Watch(ctx)
// Do first sync, so we won't missed anything.
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
return err
}
for range ch {
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
return err
}
}
@ -335,29 +369,56 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
}
}
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
// double-checked locking:
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
if lock {
srv.mu.RLock()
db := srv.byType[recordType]
srv.mu.RUnlock()
if db == nil {
srv.mu.Lock()
}
db = srv.byType[recordType]
version = srv.version
if lock {
srv.mu.RUnlock()
}
if db == nil {
if lock {
srv.mu.Lock()
}
db = srv.byType[recordType]
version = srv.version
var err error
if db == nil {
db, err = srv.newDB(recordType)
srv.byType[recordType] = db
}
if lock {
srv.mu.Unlock()
}
if err != nil {
return nil, err
return nil, "", err
}
}
return db, nil
return db, version, nil
}
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
caCertPool := x509.NewCertPool()
if srv.cfg.storageCAFile != "" {
if caCert, err := ioutil.ReadFile(srv.cfg.storageCAFile); err == nil {
caCertPool.AppendCertsFromPEM(caCert)
} else {
log.Warn().Err(err).Msg("failed to read databroker CA file")
}
}
tlsConfig := &tls.Config{
RootCAs: caCertPool,
// nolint: gosec
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
}
if srv.cfg.storageCertificate != nil {
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
}
switch srv.cfg.storageType {
case config.StorageInMemoryName:
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
@ -366,7 +427,7 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
srv.cfg.storageConnectionString,
recordType,
int64(srv.cfg.deletePermanentlyAfter.Seconds()),
redis.WithTLSConfig(srv.cfg.storageTLSConfig),
redis.WithTLSConfig(tlsConfig),
)
if err != nil {
return nil, fmt.Errorf("failed to create new redis storage: %w", err)

View file

@ -33,36 +33,46 @@ func newServer(cfg *serverConfig) *Server {
func TestServer_initVersion(t *testing.T) {
cfg := newServerConfig()
t.Run("nil db", func(t *testing.T) {
srvVersion := uuid.New()
oldNewUUID := newUUID
newUUID = func() uuid.UUID {
return srvVersion
}
defer func() { newUUID = oldNewUUID }()
srv := newServer(cfg)
srvVersion := uuid.New().String()
srv.version = srvVersion
srv.byType[recordTypeServerVersion] = nil
srv.initVersion()
assert.Equal(t, srvVersion, srv.version)
assert.Equal(t, srvVersion.String(), srv.version)
})
t.Run("new server with random version", func(t *testing.T) {
srvVersion := uuid.New()
oldNewUUID := newUUID
newUUID = func() uuid.UUID {
return srvVersion
}
defer func() { newUUID = oldNewUUID }()
srv := newServer(cfg)
ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion)
db, _, err := srv.getDB(recordTypeServerVersion, false)
require.NoError(t, err)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)
assert.Nil(t, r)
srvVersion := uuid.New().String()
srv.version = srvVersion
srv.initVersion()
assert.Equal(t, srvVersion, srv.version)
assert.Equal(t, srvVersion.String(), srv.version)
r, err = db.Get(ctx, serverVersionKey)
require.NoError(t, err)
assert.NotNil(t, r)
var sv databroker.ServerVersion
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
assert.Equal(t, srvVersion, sv.Version)
assert.Equal(t, srvVersion.String(), sv.Version)
})
t.Run("init version twice should get the same version", func(t *testing.T) {
srv := newServer(cfg)
ctx := context.Background()
db, err := srv.getDB(recordTypeServerVersion)
db, _, err := srv.getDB(recordTypeServerVersion, false)
require.NoError(t, err)
r, err := db.Get(ctx, serverVersionKey)
assert.Error(t, err)

View file

@ -4,8 +4,10 @@ package directory
import (
"context"
"net/url"
"sync"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/directory/azure"
"github.com/pomerium/pomerium/internal/directory/github"
"github.com/pomerium/pomerium/internal/directory/gitlab"
@ -27,8 +29,34 @@ type Provider interface {
UserGroups(ctx context.Context) ([]*Group, []*User, error)
}
// Options are the options specific to the provider.
type Options struct {
ServiceAccount string
Provider string
ProviderURL string
QPS float64
}
var globalProvider = struct {
sync.Mutex
provider Provider
options Options
}{}
// GetProvider gets the provider for the given options.
func GetProvider(options *config.Options) Provider {
func GetProvider(options Options) (provider Provider) {
globalProvider.Lock()
defer globalProvider.Unlock()
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
return globalProvider.provider
}
defer func() {
globalProvider.provider = provider
globalProvider.options = options
}()
switch options.Provider {
case azure.Name:
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)

View file

@ -1,6 +1,12 @@
package manager
import "time"
import (
"sync/atomic"
"time"
"github.com/pomerium/pomerium/internal/directory"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
var (
defaultGroupRefreshInterval = 10 * time.Minute
@ -10,6 +16,9 @@ var (
)
type config struct {
authenticator Authenticator
directory directory.Provider
dataBrokerClient databroker.DataBrokerServiceClient
groupRefreshInterval time.Duration
groupRefreshTimeout time.Duration
sessionRefreshGracePeriod time.Duration
@ -31,6 +40,27 @@ func newConfig(options ...Option) *config {
// An Option customizes the configuration used for the identity manager.
type Option func(*config)
// WithAuthenticator sets the authenticator in the config.
func WithAuthenticator(authenticator Authenticator) Option {
return func(cfg *config) {
cfg.authenticator = authenticator
}
}
// WithDirectoryProvider sets the directory provider in the config.
func WithDirectoryProvider(directoryProvider directory.Provider) Option {
return func(cfg *config) {
cfg.directory = directoryProvider
}
}
// WithDataBrokerClient sets the databroker client in the config.
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
return func(cfg *config) {
cfg.dataBrokerClient = dataBrokerClient
}
}
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
func WithGroupRefreshInterval(interval time.Duration) Option {
return func(cfg *config) {
@ -58,3 +88,21 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option {
cfg.sessionRefreshCoolOffDuration = dur
}
}
type atomicConfig struct {
value atomic.Value
}
func newAtomicConfig(cfg *config) *atomicConfig {
ac := new(atomicConfig)
ac.Store(cfg)
return ac
}
func (ac *atomicConfig) Load() *config {
return ac.value.Load().(*config)
}
func (ac *atomicConfig) Store(cfg *config) {
ac.value.Store(cfg)
}

View file

@ -42,10 +42,7 @@ type (
// A Manager refreshes identity information using session and user data.
type Manager struct {
cfg *config
authenticator Authenticator
directory directory.Provider
dataBrokerClient databroker.DataBrokerServiceClient
cfg *atomicConfig
log zerolog.Logger
sessions sessionCollection
@ -67,16 +64,10 @@ type Manager struct {
// New creates a new identity manager.
func New(
authenticator Authenticator,
directoryProvider directory.Provider,
dataBrokerClient databroker.DataBrokerServiceClient,
options ...Option,
) *Manager {
mgr := &Manager{
cfg: newConfig(options...),
authenticator: authenticator,
directory: directoryProvider,
dataBrokerClient: dataBrokerClient,
cfg: newAtomicConfig(newConfig()),
log: log.With().Str("service", "identity_manager").Logger(),
sessions: sessionCollection{
@ -88,9 +79,15 @@ func New(
},
userScheduler: scheduler.New(),
}
mgr.UpdateConfig(options...)
return mgr
}
// UpdateConfig updates the manager with the new options.
func (mgr *Manager) UpdateConfig(options ...Option) {
mgr.cfg.Store(newConfig(options...))
}
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
func (mgr *Manager) Run(ctx context.Context) error {
err := mgr.initDirectoryGroups(ctx)
@ -169,7 +166,7 @@ func (mgr *Manager) refreshLoop(
// refresh groups
if mgr.directoryNextRefresh.Before(now) {
mgr.refreshDirectoryUserGroups(ctx)
mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
mgr.directoryNextRefresh = now.Add(mgr.cfg.Load().groupRefreshInterval)
if mgr.directoryNextRefresh.Before(nextTime) {
nextTime = mgr.directoryNextRefresh
}
@ -211,10 +208,10 @@ func (mgr *Manager) refreshLoop(
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
mgr.log.Info().Msg("refreshing directory users")
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout)
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
defer clearTimeout()
directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx)
directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx)
if err != nil {
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
return
@ -238,7 +235,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
}
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: newDG.GetId(),
Data: any,
@ -258,7 +255,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
return
}
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: curDG.GetId(),
})
@ -284,7 +281,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
return
}
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
Type: any.GetTypeUrl(),
Id: newDU.GetId(),
Data: any,
@ -304,7 +301,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
return
}
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
Type: any.GetTypeUrl(),
Id: curDU.GetId(),
})
@ -349,7 +346,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
return
}
newToken, err := mgr.authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
if isTemporaryError(err) {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -366,7 +363,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
}
s.OauthToken = ToOAuthToken(newToken)
res, err := session.Set(ctx, mgr.dataBrokerClient, s.Session)
res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
if err != nil {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -401,7 +398,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
err := mgr.authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
if isTemporaryError(err) {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -417,7 +414,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
continue
}
record, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
mgr.log.Error().Err(err).
Str("user_id", s.GetUserId()).
@ -438,7 +435,7 @@ func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage)
return err
}
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
@ -474,7 +471,7 @@ func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error
return err
}
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
@ -510,7 +507,7 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
return err
}
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
@ -541,7 +538,7 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory
return err
}
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
ServerVersion: mgr.directoryUsersServerVersion,
RecordVersion: mgr.directoryUsersRecordVersion,
@ -579,7 +576,7 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
return err
}
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
Type: any.GetTypeUrl(),
})
if err != nil {
@ -610,7 +607,7 @@ func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *director
return err
}
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
Type: any.GetTypeUrl(),
ServerVersion: mgr.directoryGroupsServerVersion,
RecordVersion: mgr.directoryGroupsRecordVersion,
@ -652,8 +649,8 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, msg sessionMessage) {
// update session
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
s.lastRefresh = time.Now()
s.gracePeriod = mgr.cfg.sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.sessionRefreshCoolOffDuration
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
s.Session = msg.session
mgr.sessions.ReplaceOrInsert(s)
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
@ -676,7 +673,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
// only reset the refresh time if this is an existing user
u.lastRefresh = time.Now()
}
u.refreshInterval = mgr.cfg.groupRefreshInterval
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
u.User = msg.user
mgr.users.ReplaceOrInsert(u)
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
@ -697,7 +694,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
},
}
_, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
_, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
if err != nil {
mgr.log.Error().Err(err).
Str("user_id", pbSession.GetUserId()).
@ -707,7 +704,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
}
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
err := session.Delete(ctx, mgr.dataBrokerClient, pbSession.GetId())
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
if err != nil {
mgr.log.Error().Err(err).
Str("session_id", pbSession.GetId()).

View file

@ -52,6 +52,9 @@ type DB struct {
byVersion *btree.BTree
deletedIDs []string
onchange *signal.Signal
closeOnce sync.Once
closed chan struct{}
}
// NewDB creates a new in-memory database for the given record type.
@ -62,6 +65,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
byID: btree.New(btreeDegree),
byVersion: btree.New(btreeDegree),
onchange: s,
closed: make(chan struct{}),
}
}
@ -84,6 +88,14 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
db.deletedIDs = remaining
}
// Close closes the database. Any watchers will be closed.
func (db *DB) Close() error {
db.closeOnce.Do(func() {
close(db.closed)
})
return nil
}
// Delete marks a record as deleted.
func (db *DB) Delete(_ context.Context, id string) error {
defer db.onchange.Broadcast()
@ -140,7 +152,10 @@ func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
ch := db.onchange.Bind()
go func() {
<-ctx.Done()
select {
case <-db.closed:
case <-ctx.Done():
}
close(ch)
db.onchange.Unbind(ch)
}()

View file

@ -40,6 +40,9 @@ type DB struct {
deletedSet string
tlsConfig *tls.Config
notifyChMu sync.Mutex
closeOnce sync.Once
closed chan struct{}
}
// New returns new DB instance.
@ -50,6 +53,7 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
versionSet: recordType + "_version_set",
deletedSet: recordType + "_deleted_set",
lastVersionKey: recordType + "_last_version",
closed: make(chan struct{}),
}
for _, o := range opts {
@ -79,6 +83,14 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
return db, nil
}
// Close closes the redis db connection.
func (db *DB) Close() error {
db.closeOnce.Do(func() {
close(db.closed)
})
return nil
}
// Put sets new record for given id with input data.
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
c := db.pool.Get()
@ -259,7 +271,10 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
}
for {
select {
case <-db.closed:
return
case <-ctx.Done():
return
default:
}
switch v := psc.Receive().(type) {
@ -271,12 +286,17 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
}
select {
case <-db.closed:
return
case <-ctx.Done():
log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel")
return
default:
db.notifyChMu.Lock()
select {
case <-db.closed:
db.notifyChMu.Unlock()
return
case <-ctx.Done():
db.notifyChMu.Unlock()
log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel")
@ -313,6 +333,7 @@ func (db *DB) watch(ctx context.Context, ch chan struct{}) {
db.doNotifyLoop(ctx, ch)
}()
select {
case <-db.closed:
case <-ctx.Done():
case <-done:
}

View file

@ -12,6 +12,9 @@ import (
// Backend is the interface required for a storage backend.
type Backend interface {
// Close closes the backend.
Close() error
// Put is used to insert or update a record.
Put(ctx context.Context, id string, data *anypb.Any) error

View file

@ -4,8 +4,9 @@ import (
"context"
"time"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
"google.golang.org/protobuf/types/known/anypb"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
type mockBackend struct {
@ -18,6 +19,10 @@ type mockBackend struct {
watch func(ctx context.Context) <-chan struct{}
}
func (m *mockBackend) Close() error {
return nil
}
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
return m.put(ctx, id, data)
}