diff --git a/cache/cache.go b/cache/cache.go index 3bdb96c49..9140fe1f4 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -15,6 +15,7 @@ import ( "github.com/pomerium/pomerium/internal/directory" "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity/manager" + "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -35,18 +36,7 @@ type Cache struct { } // New creates a new cache service. -func New(opts config.Options) (*Cache, error) { - if err := validate(opts); err != nil { - return nil, fmt.Errorf("cache: bad option: %w", err) - } - - authenticator, err := identity.NewAuthenticator(opts.GetOauthOptions()) - if err != nil { - return nil, fmt.Errorf("cache: failed to create authenticator: %w", err) - } - - directoryProvider := directory.GetProvider(&opts) - +func New(cfg *config.Config) (*Cache, error) { localListener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, err @@ -56,7 +46,7 @@ func New(opts config.Options) (*Cache, error) { // if we no longer register with that grpc Server localGRPCServer := grpc.NewServer() - clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.Services) + clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services) clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure()) localGRPCConnection, err := grpc.DialContext( @@ -68,30 +58,33 @@ func New(opts config.Options) (*Cache, error) { return nil, err } - dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts) - if err != nil { - return nil, err - } - dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection) - - manager := manager.New( - authenticator, - directoryProvider, - dataBrokerClient, - manager.WithGroupRefreshInterval(opts.RefreshDirectoryInterval), - manager.WithGroupRefreshTimeout(opts.RefreshDirectoryTimeout), - ) - - return &Cache{ - dataBrokerServer: dataBrokerServer, - manager: manager, + dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg) + c := &Cache{ + dataBrokerServer: dataBrokerServer, localListener: localListener, localGRPCServer: localGRPCServer, localGRPCConnection: localGRPCConnection, - deprecatedCacheClusterDomain: opts.GetDataBrokerURL().Hostname(), - dataBrokerStorageType: opts.DataBrokerStorageType, - }, nil + deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(), + dataBrokerStorageType: cfg.Options.DataBrokerStorageType, + } + + err = c.update(cfg) + if err != nil { + return nil, err + } + + return c, nil +} + +// OnConfigChange is called whenever configuration is changed. +func (c *Cache) OnConfigChange(cfg *config.Config) { + err := c.update(cfg) + if err != nil { + log.Error().Err(err).Msg("cache: error updating configuration") + } + + c.dataBrokerServer.OnConfigChange(cfg) } // Register registers all the gRPC services with the given server. @@ -121,9 +114,45 @@ func (c *Cache) Run(ctx context.Context) error { return t.Wait() } +func (c *Cache) update(cfg *config.Config) error { + if err := validate(cfg.Options); err != nil { + return fmt.Errorf("cache: bad option: %w", err) + } + + authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions()) + if err != nil { + return fmt.Errorf("cache: failed to create authenticator: %w", err) + } + + directoryProvider := directory.GetProvider(directory.Options{ + ServiceAccount: cfg.Options.ServiceAccount, + Provider: cfg.Options.Provider, + ProviderURL: cfg.Options.ProviderURL, + QPS: cfg.Options.QPS, + }) + + dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection) + + options := []manager.Option{ + manager.WithAuthenticator(authenticator), + manager.WithDirectoryProvider(directoryProvider), + manager.WithDataBrokerClient(dataBrokerClient), + manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval), + manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout), + } + + if c.manager == nil { + c.manager = manager.New(options...) + } else { + c.manager.UpdateConfig(options...) + } + + return nil +} + // validate checks that proper configuration settings are set to create // a cache instance -func validate(o config.Options) error { +func validate(o *config.Options) error { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { return fmt.Errorf("invalid 'SHARED_SECRET': %w", err) } diff --git a/cache/cache_test.go b/cache/cache_test.go index dccefdd69..cabad127b 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -30,7 +30,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.opts.Provider = "google" - _, err := New(tt.opts) + _, err := New(&config.Config{Options: &tt.opts}) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/cache/databroker.go b/cache/databroker.go index 19961ea0d..0e5397e8d 100644 --- a/cache/databroker.go +++ b/cache/databroker.go @@ -1,55 +1,38 @@ package cache import ( - "crypto/tls" - "crypto/x509" - "encoding/base64" - "fmt" - "io/ioutil" - "google.golang.org/grpc" "github.com/pomerium/pomerium/config" - internal_databroker "github.com/pomerium/pomerium/internal/databroker" - "github.com/pomerium/pomerium/internal/log" - "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/internal/databroker" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" ) // A DataBrokerServer implements the data broker service interface. type DataBrokerServer struct { - databroker.DataBrokerServiceServer + *databroker.Server } // NewDataBrokerServer creates a new databroker service server. -func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) { - key, err := base64.StdEncoding.DecodeString(opts.SharedKey) - if err != nil || len(key) != cryptutil.DefaultKeySize { - return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) - } - - caCertPool := x509.NewCertPool() - if caCert, err := ioutil.ReadFile(opts.DataBrokerStorageCAFile); err == nil { - caCertPool.AppendCertsFromPEM(caCert) - } else { - log.Warn().Err(err).Msg("failed to read databroker CA file") - } - tlsConfig := &tls.Config{ - RootCAs: caCertPool, - // nolint: gosec - InsecureSkipVerify: opts.DataBrokerStorageCertSkipVerify, - } - if opts.DataBrokerCertificate != nil { - tlsConfig.Certificates = []tls.Certificate{*opts.DataBrokerCertificate} - } - - internalSrv := internal_databroker.New( - internal_databroker.WithSecret(key), - internal_databroker.WithStorageType(opts.DataBrokerStorageType), - internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString), - internal_databroker.WithStorageTLSConfig(tlsConfig), - ) - srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv} - databroker.RegisterDataBrokerServiceServer(grpcServer, srv) - return srv, nil +func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer { + srv := &DataBrokerServer{} + srv.Server = databroker.New(srv.getOptions(cfg)...) + databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv) + return srv +} + +// OnConfigChange updates the underlying databroker server whenever configuration is changed. +func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) { + srv.UpdateConfig(srv.getOptions(cfg)...) +} + +func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { + return []databroker.ServerOption{ + databroker.WithSharedKey(cfg.Options.SharedKey), + databroker.WithStorageType(cfg.Options.DataBrokerStorageType), + databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString), + databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile), + databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate), + databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify), + } } diff --git a/cache/databroker_test.go b/cache/databroker_test.go index 6fef9ad23..d7ad260d1 100644 --- a/cache/databroker_test.go +++ b/cache/databroker_test.go @@ -27,7 +27,7 @@ func init() { lis = bufconn.Listen(bufSize) s := grpc.NewServer() internalSrv := internal_databroker.New() - srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv} + srv := &DataBrokerServer{Server: internalSrv} databroker.RegisterDataBrokerServiceServer(s, srv) go func() { diff --git a/cache/memberlist_test.go b/cache/memberlist_test.go index 7f8791f05..2ff12b5b9 100644 --- a/cache/memberlist_test.go +++ b/cache/memberlist_test.go @@ -15,10 +15,12 @@ import ( ) func TestCache_runMemberList(t *testing.T) { - c, err := New(config.Options{ - SharedKey: cryptutil.NewBase64Key(), - DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"}, - Provider: "google", + c, err := New(&config.Config{ + Options: &config.Options{ + SharedKey: cryptutil.NewBase64Key(), + DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"}, + Provider: "google", + }, }) require.NoError(t, err) diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index 0d7f949b2..77472bc2f 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -88,7 +88,7 @@ func Run(ctx context.Context, configFile string) error { } var cacheServer *cache.Cache if config.IsCache(cfg.Options.Services) { - cacheServer, err = setupCache(cfg.Options, controlPlane) + cacheServer, err = setupCache(src, cfg, controlPlane) if err != nil { return err } @@ -162,13 +162,15 @@ func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *control return svc, nil } -func setupCache(opt *config.Options, controlPlane *controlplane.Server) (*cache.Cache, error) { - svc, err := cache.New(*opt) +func setupCache(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*cache.Cache, error) { + svc, err := cache.New(cfg) if err != nil { return nil, fmt.Errorf("error creating config service: %w", err) } svc.Register(controlPlane.GRPCServer) log.Info().Msg("enabled cache service") + src.OnConfigChange(svc.OnConfigChange) + svc.OnConfigChange(cfg) return svc, nil } diff --git a/internal/databroker/config.go b/internal/databroker/config.go index 9e911b4ee..cc1e471f2 100644 --- a/internal/databroker/config.go +++ b/internal/databroker/config.go @@ -2,7 +2,11 @@ package databroker import ( "crypto/tls" + "encoding/base64" "time" + + "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/pkg/cryptutil" ) var ( @@ -21,7 +25,9 @@ type serverConfig struct { secret []byte storageType string storageConnectionString string - storageTLSConfig *tls.Config + storageCAFile string + storageCertSkipVerify bool + storageCertificate *tls.Certificate } func newServerConfig(options ...ServerOption) *serverConfig { @@ -54,10 +60,15 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption { } } -// WithSecret sets the secret in the config. -func WithSecret(secret []byte) ServerOption { +// WithSharedKey sets the secret in the config. +func WithSharedKey(sharedKey string) ServerOption { return func(cfg *serverConfig) { - cfg.secret = secret + key, err := base64.StdEncoding.DecodeString(sharedKey) + if err != nil || len(key) != cryptutil.DefaultKeySize { + log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) + return + } + cfg.secret = key } } @@ -75,9 +86,23 @@ func WithStorageConnectionString(connStr string) ServerOption { } } -// WithStorageTLSConfig sets the tls config for connection to storage. -func WithStorageTLSConfig(tlsConfig *tls.Config) ServerOption { +// WithStorageCAFile sets the CA file in the config. +func WithStorageCAFile(filePath string) ServerOption { return func(cfg *serverConfig) { - cfg.storageTLSConfig = tlsConfig + cfg.storageCAFile = filePath + } +} + +// WithStorageCertSkipVerify sets the storageCertSkipVerify in the config. +func WithStorageCertSkipVerify(storageCertSkipVerify bool) ServerOption { + return func(cfg *serverConfig) { + cfg.storageCertSkipVerify = storageCertSkipVerify + } +} + +// WithStorageCertificate sets the storageCertificate in the config. +func WithStorageCertificate(certificate *tls.Certificate) ServerOption { + return func(cfg *serverConfig) { + cfg.storageCertificate = certificate } } diff --git a/internal/databroker/server.go b/internal/databroker/server.go index 8e41d653e..7d33de21d 100644 --- a/internal/databroker/server.go +++ b/internal/databroker/server.go @@ -3,7 +3,10 @@ package databroker import ( "context" + "crypto/tls" + "crypto/x509" "fmt" + "io/ioutil" "reflect" "sort" "sync" @@ -11,6 +14,7 @@ import ( "github.com/golang/protobuf/ptypes" "github.com/golang/protobuf/ptypes/empty" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/rs/zerolog" "google.golang.org/grpc/codes" @@ -33,35 +37,39 @@ const ( syncBatchSize = 100 ) +// newUUID returns a new UUID. This make it easy to stub out in tests. +var newUUID = uuid.New + // Server implements the databroker service using an in memory database. type Server struct { - version string - cfg *serverConfig - log zerolog.Logger + cfg *serverConfig + log zerolog.Logger mu sync.RWMutex + version string byType map[string]storage.Backend onTypechange *signal.Signal } // New creates a new server. func New(options ...ServerOption) *Server { - cfg := newServerConfig(options...) srv := &Server{ - version: uuid.New().String(), - cfg: cfg, - log: log.With().Str("service", "databroker").Logger(), + log: log.With().Str("service", "databroker").Logger(), byType: make(map[string]storage.Backend), onTypechange: signal.New(), } - srv.initVersion() + srv.UpdateConfig(options...) go func() { - ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2) + ticker := time.NewTicker(time.Minute) defer ticker.Stop() for range ticker.C { + srv.mu.RLock() + tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter) + srv.mu.RUnlock() + var recordTypes []string srv.mu.RLock() for recordType := range srv.byType { @@ -70,11 +78,11 @@ func New(options ...ServerOption) *Server { srv.mu.RUnlock() for _, recordType := range recordTypes { - db, err := srv.getDB(recordType) + db, _, err := srv.getDB(recordType, true) if err != nil { continue } - db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter)) + db.ClearDeleted(context.Background(), tm) } } }() @@ -82,7 +90,7 @@ func New(options ...ServerOption) *Server { } func (srv *Server) initVersion() { - dbServerVersion, err := srv.getDB(recordTypeServerVersion) + dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false) if err != nil { log.Error().Err(err).Msg("failed to init server version") return @@ -98,12 +106,36 @@ func (srv *Server) initVersion() { return } + srv.version = newUUID().String() data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version}) if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil { srv.log.Warn().Err(err).Msg("failed to save server version.") } } +// UpdateConfig updates the server with the new options. +func (srv *Server) UpdateConfig(options ...ServerOption) { + srv.mu.Lock() + defer srv.mu.Unlock() + + cfg := newServerConfig(options...) + if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) { + log.Debug().Msg("databroker: no changes detected, re-using existing DBs") + return + } + srv.cfg = cfg + + for t, db := range srv.byType { + err := db.Close() + if err != nil { + log.Warn().Err(err).Msg("databroker: error closing backend") + } + delete(srv.byType, t) + } + + srv.initVersion() +} + // Delete deletes a record from the in-memory list. func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) { _, span := trace.StartSpan(ctx, "databroker.grpc.Delete") @@ -113,7 +145,7 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (* Str("id", req.GetId()). Msg("delete") - db, err := srv.getDB(req.GetType()) + db, _, err := srv.getDB(req.GetType(), true) if err != nil { return nil, err } @@ -134,7 +166,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr Str("id", req.GetId()). Msg("get") - db, err := srv.getDB(req.GetType()) + db, _, err := srv.getDB(req.GetType(), true) if err != nil { return nil, err } @@ -156,7 +188,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (* Str("type", req.GetType()). Msg("get all") - db, err := srv.getDB(req.GetType()) + db, version, err := srv.getDB(req.GetType(), true) if err != nil { return nil, err } @@ -167,7 +199,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (* } if len(all) == 0 { - return &databroker.GetAllResponse{ServerVersion: srv.version}, nil + return &databroker.GetAllResponse{ServerVersion: version}, nil } var recordVersion string @@ -182,7 +214,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (* } return &databroker.GetAllResponse{ - ServerVersion: srv.version, + ServerVersion: version, RecordVersion: recordVersion, Records: records, }, nil @@ -197,7 +229,7 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr Str("id", req.GetId()). Msg("set") - db, err := srv.getDB(req.GetType()) + db, version, err := srv.getDB(req.GetType(), true) if err != nil { return nil, err } @@ -210,11 +242,13 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr } return &databroker.SetResponse{ Record: record, - ServerVersion: srv.version, + ServerVersion: version, }, nil } -func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error { +func (srv *Server) doSync(ctx context.Context, + serverVersion string, recordVersion *string, + db storage.Backend, stream databroker.DataBrokerService_SyncServer) error { updated, err := db.List(ctx, *recordVersion) if err != nil { return err @@ -232,7 +266,7 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage j = len(updated) } if err := stream.Send(&databroker.SyncResponse{ - ServerVersion: srv.version, + ServerVersion: serverVersion, Records: updated[i:j], }); err != nil { return err @@ -251,34 +285,34 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke Str("record_version", req.GetRecordVersion()). Msg("sync") + db, serverVersion, err := srv.getDB(req.GetType(), true) + if err != nil { + return err + } + recordVersion := req.GetRecordVersion() // reset record version if the server versions don't match - if req.GetServerVersion() != srv.version { + if req.GetServerVersion() != serverVersion { recordVersion = "" // send the new server version to the client err := stream.Send(&databroker.SyncResponse{ - ServerVersion: srv.version, + ServerVersion: serverVersion, }) if err != nil { return err } } - db, err := srv.getDB(req.GetType()) - if err != nil { - return err - } - ctx := stream.Context() ch := db.Watch(ctx) // Do first sync, so we won't missed anything. - if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil { return err } for range ch { - if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil { + if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil { return err } } @@ -335,29 +369,56 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer } } -func (srv *Server) getDB(recordType string) (storage.Backend, error) { +func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) { // double-checked locking: // first try the read lock, then re-try with the write lock, and finally create a new db if nil - srv.mu.RLock() - db := srv.byType[recordType] - srv.mu.RUnlock() + if lock { + srv.mu.RLock() + } + db = srv.byType[recordType] + version = srv.version + if lock { + srv.mu.RUnlock() + } if db == nil { - srv.mu.Lock() + if lock { + srv.mu.Lock() + } db = srv.byType[recordType] + version = srv.version var err error if db == nil { db, err = srv.newDB(recordType) srv.byType[recordType] = db } - srv.mu.Unlock() + if lock { + srv.mu.Unlock() + } if err != nil { - return nil, err + return nil, "", err } } - return db, nil + return db, version, nil } func (srv *Server) newDB(recordType string) (db storage.Backend, err error) { + caCertPool := x509.NewCertPool() + if srv.cfg.storageCAFile != "" { + if caCert, err := ioutil.ReadFile(srv.cfg.storageCAFile); err == nil { + caCertPool.AppendCertsFromPEM(caCert) + } else { + log.Warn().Err(err).Msg("failed to read databroker CA file") + } + } + tlsConfig := &tls.Config{ + RootCAs: caCertPool, + // nolint: gosec + InsecureSkipVerify: srv.cfg.storageCertSkipVerify, + } + if srv.cfg.storageCertificate != nil { + tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate} + } + switch srv.cfg.storageType { case config.StorageInMemoryName: return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil @@ -366,7 +427,7 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) { srv.cfg.storageConnectionString, recordType, int64(srv.cfg.deletePermanentlyAfter.Seconds()), - redis.WithTLSConfig(srv.cfg.storageTLSConfig), + redis.WithTLSConfig(tlsConfig), ) if err != nil { return nil, fmt.Errorf("failed to create new redis storage: %w", err) diff --git a/internal/databroker/server_test.go b/internal/databroker/server_test.go index a03a34bd0..447fd1216 100644 --- a/internal/databroker/server_test.go +++ b/internal/databroker/server_test.go @@ -33,36 +33,46 @@ func newServer(cfg *serverConfig) *Server { func TestServer_initVersion(t *testing.T) { cfg := newServerConfig() t.Run("nil db", func(t *testing.T) { + srvVersion := uuid.New() + oldNewUUID := newUUID + newUUID = func() uuid.UUID { + return srvVersion + } + defer func() { newUUID = oldNewUUID }() + srv := newServer(cfg) - srvVersion := uuid.New().String() - srv.version = srvVersion srv.byType[recordTypeServerVersion] = nil srv.initVersion() - assert.Equal(t, srvVersion, srv.version) + assert.Equal(t, srvVersion.String(), srv.version) }) t.Run("new server with random version", func(t *testing.T) { + srvVersion := uuid.New() + oldNewUUID := newUUID + newUUID = func() uuid.UUID { + return srvVersion + } + defer func() { newUUID = oldNewUUID }() + srv := newServer(cfg) ctx := context.Background() - db, err := srv.getDB(recordTypeServerVersion) + db, _, err := srv.getDB(recordTypeServerVersion, false) require.NoError(t, err) r, err := db.Get(ctx, serverVersionKey) assert.Error(t, err) assert.Nil(t, r) - srvVersion := uuid.New().String() - srv.version = srvVersion srv.initVersion() - assert.Equal(t, srvVersion, srv.version) + assert.Equal(t, srvVersion.String(), srv.version) r, err = db.Get(ctx, serverVersionKey) require.NoError(t, err) assert.NotNil(t, r) var sv databroker.ServerVersion assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv)) - assert.Equal(t, srvVersion, sv.Version) + assert.Equal(t, srvVersion.String(), sv.Version) }) t.Run("init version twice should get the same version", func(t *testing.T) { srv := newServer(cfg) ctx := context.Background() - db, err := srv.getDB(recordTypeServerVersion) + db, _, err := srv.getDB(recordTypeServerVersion, false) require.NoError(t, err) r, err := db.Get(ctx, serverVersionKey) assert.Error(t, err) diff --git a/internal/directory/provider.go b/internal/directory/provider.go index f1b21c0d0..258e4134b 100644 --- a/internal/directory/provider.go +++ b/internal/directory/provider.go @@ -4,8 +4,10 @@ package directory import ( "context" "net/url" + "sync" + + "github.com/google/go-cmp/cmp" - "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/directory/azure" "github.com/pomerium/pomerium/internal/directory/github" "github.com/pomerium/pomerium/internal/directory/gitlab" @@ -27,8 +29,34 @@ type Provider interface { UserGroups(ctx context.Context) ([]*Group, []*User, error) } +// Options are the options specific to the provider. +type Options struct { + ServiceAccount string + Provider string + ProviderURL string + QPS float64 +} + +var globalProvider = struct { + sync.Mutex + provider Provider + options Options +}{} + // GetProvider gets the provider for the given options. -func GetProvider(options *config.Options) Provider { +func GetProvider(options Options) (provider Provider) { + globalProvider.Lock() + defer globalProvider.Unlock() + + if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) { + log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider") + return globalProvider.provider + } + defer func() { + globalProvider.provider = provider + globalProvider.options = options + }() + switch options.Provider { case azure.Name: serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount) diff --git a/internal/identity/manager/config.go b/internal/identity/manager/config.go index 2e292ae1b..8b991f7ad 100644 --- a/internal/identity/manager/config.go +++ b/internal/identity/manager/config.go @@ -1,6 +1,12 @@ package manager -import "time" +import ( + "sync/atomic" + "time" + + "github.com/pomerium/pomerium/internal/directory" + "github.com/pomerium/pomerium/pkg/grpc/databroker" +) var ( defaultGroupRefreshInterval = 10 * time.Minute @@ -10,6 +16,9 @@ var ( ) type config struct { + authenticator Authenticator + directory directory.Provider + dataBrokerClient databroker.DataBrokerServiceClient groupRefreshInterval time.Duration groupRefreshTimeout time.Duration sessionRefreshGracePeriod time.Duration @@ -31,6 +40,27 @@ func newConfig(options ...Option) *config { // An Option customizes the configuration used for the identity manager. type Option func(*config) +// WithAuthenticator sets the authenticator in the config. +func WithAuthenticator(authenticator Authenticator) Option { + return func(cfg *config) { + cfg.authenticator = authenticator + } +} + +// WithDirectoryProvider sets the directory provider in the config. +func WithDirectoryProvider(directoryProvider directory.Provider) Option { + return func(cfg *config) { + cfg.directory = directoryProvider + } +} + +// WithDataBrokerClient sets the databroker client in the config. +func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option { + return func(cfg *config) { + cfg.dataBrokerClient = dataBrokerClient + } +} + // WithGroupRefreshInterval sets the group refresh interval used by the manager. func WithGroupRefreshInterval(interval time.Duration) Option { return func(cfg *config) { @@ -58,3 +88,21 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option { cfg.sessionRefreshCoolOffDuration = dur } } + +type atomicConfig struct { + value atomic.Value +} + +func newAtomicConfig(cfg *config) *atomicConfig { + ac := new(atomicConfig) + ac.Store(cfg) + return ac +} + +func (ac *atomicConfig) Load() *config { + return ac.value.Load().(*config) +} + +func (ac *atomicConfig) Store(cfg *config) { + ac.value.Store(cfg) +} diff --git a/internal/identity/manager/manager.go b/internal/identity/manager/manager.go index a95eb63c1..f1d03d2a9 100644 --- a/internal/identity/manager/manager.go +++ b/internal/identity/manager/manager.go @@ -42,11 +42,8 @@ type ( // A Manager refreshes identity information using session and user data. type Manager struct { - cfg *config - authenticator Authenticator - directory directory.Provider - dataBrokerClient databroker.DataBrokerServiceClient - log zerolog.Logger + cfg *atomicConfig + log zerolog.Logger sessions sessionCollection sessionScheduler *scheduler.Scheduler @@ -67,17 +64,11 @@ type Manager struct { // New creates a new identity manager. func New( - authenticator Authenticator, - directoryProvider directory.Provider, - dataBrokerClient databroker.DataBrokerServiceClient, options ...Option, ) *Manager { mgr := &Manager{ - cfg: newConfig(options...), - authenticator: authenticator, - directory: directoryProvider, - dataBrokerClient: dataBrokerClient, - log: log.With().Str("service", "identity_manager").Logger(), + cfg: newAtomicConfig(newConfig()), + log: log.With().Str("service", "identity_manager").Logger(), sessions: sessionCollection{ BTree: btree.New(8), @@ -88,9 +79,15 @@ func New( }, userScheduler: scheduler.New(), } + mgr.UpdateConfig(options...) return mgr } +// UpdateConfig updates the manager with the new options. +func (mgr *Manager) UpdateConfig(options ...Option) { + mgr.cfg.Store(newConfig(options...)) +} + // Run runs the manager. This method blocks until an error occurs or the given context is canceled. func (mgr *Manager) Run(ctx context.Context) error { err := mgr.initDirectoryGroups(ctx) @@ -169,7 +166,7 @@ func (mgr *Manager) refreshLoop( // refresh groups if mgr.directoryNextRefresh.Before(now) { mgr.refreshDirectoryUserGroups(ctx) - mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval) + mgr.directoryNextRefresh = now.Add(mgr.cfg.Load().groupRefreshInterval) if mgr.directoryNextRefresh.Before(nextTime) { nextTime = mgr.directoryNextRefresh } @@ -211,10 +208,10 @@ func (mgr *Manager) refreshLoop( func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) { mgr.log.Info().Msg("refreshing directory users") - ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout) + ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout) defer clearTimeout() - directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx) + directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx) if err != nil { mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups") return @@ -238,7 +235,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director mgr.log.Warn().Err(err).Msg("failed to marshal directory group") return } - _, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{ + _, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{ Type: any.GetTypeUrl(), Id: newDG.GetId(), Data: any, @@ -258,7 +255,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director mgr.log.Warn().Err(err).Msg("failed to marshal directory group") return } - _, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{ + _, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{ Type: any.GetTypeUrl(), Id: curDG.GetId(), }) @@ -284,7 +281,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory. mgr.log.Warn().Err(err).Msg("failed to marshal directory user") return } - _, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{ + _, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{ Type: any.GetTypeUrl(), Id: newDU.GetId(), Data: any, @@ -304,7 +301,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory. mgr.log.Warn().Err(err).Msg("failed to marshal directory user") return } - _, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{ + _, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{ Type: any.GetTypeUrl(), Id: curDU.GetId(), }) @@ -349,7 +346,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string return } - newToken, err := mgr.authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) + newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s) if isTemporaryError(err) { mgr.log.Error().Err(err). Str("user_id", s.GetUserId()). @@ -366,7 +363,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string } s.OauthToken = ToOAuthToken(newToken) - res, err := session.Set(ctx, mgr.dataBrokerClient, s.Session) + res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session) if err != nil { mgr.log.Error().Err(err). Str("user_id", s.GetUserId()). @@ -401,7 +398,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) { continue } - err := mgr.authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) + err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u) if isTemporaryError(err) { mgr.log.Error().Err(err). Str("user_id", s.GetUserId()). @@ -417,7 +414,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) { continue } - record, err := user.Set(ctx, mgr.dataBrokerClient, u.User) + record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User) if err != nil { mgr.log.Error().Err(err). Str("user_id", s.GetUserId()). @@ -438,7 +435,7 @@ func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage) return err } - client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ + client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ Type: any.GetTypeUrl(), }) if err != nil { @@ -474,7 +471,7 @@ func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error return err } - client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ + client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ Type: any.GetTypeUrl(), }) if err != nil { @@ -510,7 +507,7 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error { return err } - res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{ + res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{ Type: any.GetTypeUrl(), }) if err != nil { @@ -541,7 +538,7 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory return err } - client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ + client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ Type: any.GetTypeUrl(), ServerVersion: mgr.directoryUsersServerVersion, RecordVersion: mgr.directoryUsersRecordVersion, @@ -579,7 +576,7 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error { return err } - res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{ + res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{ Type: any.GetTypeUrl(), }) if err != nil { @@ -610,7 +607,7 @@ func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *director return err } - client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ + client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{ Type: any.GetTypeUrl(), ServerVersion: mgr.directoryGroupsServerVersion, RecordVersion: mgr.directoryGroupsRecordVersion, @@ -652,8 +649,8 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, msg sessionMessage) { // update session s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId()) s.lastRefresh = time.Now() - s.gracePeriod = mgr.cfg.sessionRefreshGracePeriod - s.coolOffDuration = mgr.cfg.sessionRefreshCoolOffDuration + s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod + s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration s.Session = msg.session mgr.sessions.ReplaceOrInsert(s) mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId())) @@ -676,7 +673,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) { // only reset the refresh time if this is an existing user u.lastRefresh = time.Now() } - u.refreshInterval = mgr.cfg.groupRefreshInterval + u.refreshInterval = mgr.cfg.Load().groupRefreshInterval u.User = msg.user mgr.users.ReplaceOrInsert(u) mgr.userScheduler.Add(u.NextRefresh(), u.GetId()) @@ -697,7 +694,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session) }, } - _, err := user.Set(ctx, mgr.dataBrokerClient, u.User) + _, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User) if err != nil { mgr.log.Error().Err(err). Str("user_id", pbSession.GetUserId()). @@ -707,7 +704,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session) } func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) { - err := session.Delete(ctx, mgr.dataBrokerClient, pbSession.GetId()) + err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId()) if err != nil { mgr.log.Error().Err(err). Str("session_id", pbSession.GetId()). diff --git a/pkg/storage/inmemory/inmemory.go b/pkg/storage/inmemory/inmemory.go index 779cac4d7..3779ed499 100644 --- a/pkg/storage/inmemory/inmemory.go +++ b/pkg/storage/inmemory/inmemory.go @@ -52,6 +52,9 @@ type DB struct { byVersion *btree.BTree deletedIDs []string onchange *signal.Signal + + closeOnce sync.Once + closed chan struct{} } // NewDB creates a new in-memory database for the given record type. @@ -62,6 +65,7 @@ func NewDB(recordType string, btreeDegree int) *DB { byID: btree.New(btreeDegree), byVersion: btree.New(btreeDegree), onchange: s, + closed: make(chan struct{}), } } @@ -84,6 +88,14 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) { db.deletedIDs = remaining } +// Close closes the database. Any watchers will be closed. +func (db *DB) Close() error { + db.closeOnce.Do(func() { + close(db.closed) + }) + return nil +} + // Delete marks a record as deleted. func (db *DB) Delete(_ context.Context, id string) error { defer db.onchange.Broadcast() @@ -140,7 +152,10 @@ func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error { func (db *DB) Watch(ctx context.Context) <-chan struct{} { ch := db.onchange.Bind() go func() { - <-ctx.Done() + select { + case <-db.closed: + case <-ctx.Done(): + } close(ch) db.onchange.Unbind(ch) }() diff --git a/pkg/storage/redis/redis.go b/pkg/storage/redis/redis.go index 02075fd41..dd61ec61a 100644 --- a/pkg/storage/redis/redis.go +++ b/pkg/storage/redis/redis.go @@ -40,6 +40,9 @@ type DB struct { deletedSet string tlsConfig *tls.Config notifyChMu sync.Mutex + + closeOnce sync.Once + closed chan struct{} } // New returns new DB instance. @@ -50,6 +53,7 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option) versionSet: recordType + "_version_set", deletedSet: recordType + "_deleted_set", lastVersionKey: recordType + "_last_version", + closed: make(chan struct{}), } for _, o := range opts { @@ -79,6 +83,14 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option) return db, nil } +// Close closes the redis db connection. +func (db *DB) Close() error { + db.closeOnce.Do(func() { + close(db.closed) + }) + return nil +} + // Put sets new record for given id with input data. func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) { c := db.pool.Get() @@ -259,7 +271,10 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) { } for { select { + case <-db.closed: + return case <-ctx.Done(): + return default: } switch v := psc.Receive().(type) { @@ -271,12 +286,17 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) { } select { + case <-db.closed: + return case <-ctx.Done(): log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel") return default: db.notifyChMu.Lock() select { + case <-db.closed: + db.notifyChMu.Unlock() + return case <-ctx.Done(): db.notifyChMu.Unlock() log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel") @@ -313,6 +333,7 @@ func (db *DB) watch(ctx context.Context, ch chan struct{}) { db.doNotifyLoop(ctx, ch) }() select { + case <-db.closed: case <-ctx.Done(): case <-done: } diff --git a/pkg/storage/storage.go b/pkg/storage/storage.go index c7e11f13c..16e7c0891 100644 --- a/pkg/storage/storage.go +++ b/pkg/storage/storage.go @@ -12,6 +12,9 @@ import ( // Backend is the interface required for a storage backend. type Backend interface { + // Close closes the backend. + Close() error + // Put is used to insert or update a record. Put(ctx context.Context, id string, data *anypb.Any) error diff --git a/pkg/storage/storage_test.go b/pkg/storage/storage_test.go index d246dabe2..6ab9b5fca 100644 --- a/pkg/storage/storage_test.go +++ b/pkg/storage/storage_test.go @@ -4,8 +4,9 @@ import ( "context" "time" - "github.com/pomerium/pomerium/pkg/grpc/databroker" "google.golang.org/protobuf/types/known/anypb" + + "github.com/pomerium/pomerium/pkg/grpc/databroker" ) type mockBackend struct { @@ -18,6 +19,10 @@ type mockBackend struct { watch func(ctx context.Context) <-chan struct{} } +func (m *mockBackend) Close() error { + return nil +} + func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error { return m.put(ctx, id, data) }