mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
cache: support databroker option changes (#1294)
This commit is contained in:
parent
31205c0c29
commit
a1378c81f8
16 changed files with 408 additions and 179 deletions
97
cache/cache.go
vendored
97
cache/cache.go
vendored
|
@ -15,6 +15,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/internal/identity"
|
||||
"github.com/pomerium/pomerium/internal/identity/manager"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/telemetry"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -35,18 +36,7 @@ type Cache struct {
|
|||
}
|
||||
|
||||
// New creates a new cache service.
|
||||
func New(opts config.Options) (*Cache, error) {
|
||||
if err := validate(opts); err != nil {
|
||||
return nil, fmt.Errorf("cache: bad option: %w", err)
|
||||
}
|
||||
|
||||
authenticator, err := identity.NewAuthenticator(opts.GetOauthOptions())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cache: failed to create authenticator: %w", err)
|
||||
}
|
||||
|
||||
directoryProvider := directory.GetProvider(&opts)
|
||||
|
||||
func New(cfg *config.Config) (*Cache, error) {
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -56,7 +46,7 @@ func New(opts config.Options) (*Cache, error) {
|
|||
// if we no longer register with that grpc Server
|
||||
localGRPCServer := grpc.NewServer()
|
||||
|
||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(opts.Services)
|
||||
clientStatsHandler := telemetry.NewGRPCClientStatsHandler(cfg.Options.Services)
|
||||
clientDialOptions := clientStatsHandler.DialOptions(grpc.WithInsecure())
|
||||
|
||||
localGRPCConnection, err := grpc.DialContext(
|
||||
|
@ -68,30 +58,33 @@ func New(opts config.Options) (*Cache, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
dataBrokerServer, err := NewDataBrokerServer(localGRPCServer, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(localGRPCConnection)
|
||||
|
||||
manager := manager.New(
|
||||
authenticator,
|
||||
directoryProvider,
|
||||
dataBrokerClient,
|
||||
manager.WithGroupRefreshInterval(opts.RefreshDirectoryInterval),
|
||||
manager.WithGroupRefreshTimeout(opts.RefreshDirectoryTimeout),
|
||||
)
|
||||
|
||||
return &Cache{
|
||||
dataBrokerServer: dataBrokerServer,
|
||||
manager: manager,
|
||||
dataBrokerServer := NewDataBrokerServer(localGRPCServer, cfg)
|
||||
|
||||
c := &Cache{
|
||||
dataBrokerServer: dataBrokerServer,
|
||||
localListener: localListener,
|
||||
localGRPCServer: localGRPCServer,
|
||||
localGRPCConnection: localGRPCConnection,
|
||||
deprecatedCacheClusterDomain: opts.GetDataBrokerURL().Hostname(),
|
||||
dataBrokerStorageType: opts.DataBrokerStorageType,
|
||||
}, nil
|
||||
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(),
|
||||
dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
|
||||
}
|
||||
|
||||
err = c.update(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// OnConfigChange is called whenever configuration is changed.
|
||||
func (c *Cache) OnConfigChange(cfg *config.Config) {
|
||||
err := c.update(cfg)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("cache: error updating configuration")
|
||||
}
|
||||
|
||||
c.dataBrokerServer.OnConfigChange(cfg)
|
||||
}
|
||||
|
||||
// Register registers all the gRPC services with the given server.
|
||||
|
@ -121,9 +114,45 @@ func (c *Cache) Run(ctx context.Context) error {
|
|||
return t.Wait()
|
||||
}
|
||||
|
||||
func (c *Cache) update(cfg *config.Config) error {
|
||||
if err := validate(cfg.Options); err != nil {
|
||||
return fmt.Errorf("cache: bad option: %w", err)
|
||||
}
|
||||
|
||||
authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions())
|
||||
if err != nil {
|
||||
return fmt.Errorf("cache: failed to create authenticator: %w", err)
|
||||
}
|
||||
|
||||
directoryProvider := directory.GetProvider(directory.Options{
|
||||
ServiceAccount: cfg.Options.ServiceAccount,
|
||||
Provider: cfg.Options.Provider,
|
||||
ProviderURL: cfg.Options.ProviderURL,
|
||||
QPS: cfg.Options.QPS,
|
||||
})
|
||||
|
||||
dataBrokerClient := databroker.NewDataBrokerServiceClient(c.localGRPCConnection)
|
||||
|
||||
options := []manager.Option{
|
||||
manager.WithAuthenticator(authenticator),
|
||||
manager.WithDirectoryProvider(directoryProvider),
|
||||
manager.WithDataBrokerClient(dataBrokerClient),
|
||||
manager.WithGroupRefreshInterval(cfg.Options.RefreshDirectoryInterval),
|
||||
manager.WithGroupRefreshTimeout(cfg.Options.RefreshDirectoryTimeout),
|
||||
}
|
||||
|
||||
if c.manager == nil {
|
||||
c.manager = manager.New(options...)
|
||||
} else {
|
||||
c.manager.UpdateConfig(options...)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validate checks that proper configuration settings are set to create
|
||||
// a cache instance
|
||||
func validate(o config.Options) error {
|
||||
func validate(o *config.Options) error {
|
||||
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
|
||||
}
|
||||
|
|
2
cache/cache_test.go
vendored
2
cache/cache_test.go
vendored
|
@ -30,7 +30,7 @@ func TestNew(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.opts.Provider = "google"
|
||||
_, err := New(tt.opts)
|
||||
_, err := New(&config.Config{Options: &tt.opts})
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
65
cache/databroker.go
vendored
65
cache/databroker.go
vendored
|
@ -1,55 +1,38 @@
|
|||
package cache
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
internal_databroker "github.com/pomerium/pomerium/internal/databroker"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"github.com/pomerium/pomerium/internal/databroker"
|
||||
databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
// A DataBrokerServer implements the data broker service interface.
|
||||
type DataBrokerServer struct {
|
||||
databroker.DataBrokerServiceServer
|
||||
*databroker.Server
|
||||
}
|
||||
|
||||
// NewDataBrokerServer creates a new databroker service server.
|
||||
func NewDataBrokerServer(grpcServer *grpc.Server, opts config.Options) (*DataBrokerServer, error) {
|
||||
key, err := base64.StdEncoding.DecodeString(opts.SharedKey)
|
||||
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
||||
return nil, fmt.Errorf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
if caCert, err := ioutil.ReadFile(opts.DataBrokerStorageCAFile); err == nil {
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
// nolint: gosec
|
||||
InsecureSkipVerify: opts.DataBrokerStorageCertSkipVerify,
|
||||
}
|
||||
if opts.DataBrokerCertificate != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*opts.DataBrokerCertificate}
|
||||
}
|
||||
|
||||
internalSrv := internal_databroker.New(
|
||||
internal_databroker.WithSecret(key),
|
||||
internal_databroker.WithStorageType(opts.DataBrokerStorageType),
|
||||
internal_databroker.WithStorageConnectionString(opts.DataBrokerStorageConnectionString),
|
||||
internal_databroker.WithStorageTLSConfig(tlsConfig),
|
||||
)
|
||||
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
|
||||
databroker.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||
return srv, nil
|
||||
func NewDataBrokerServer(grpcServer *grpc.Server, cfg *config.Config) *DataBrokerServer {
|
||||
srv := &DataBrokerServer{}
|
||||
srv.Server = databroker.New(srv.getOptions(cfg)...)
|
||||
databrokerpb.RegisterDataBrokerServiceServer(grpcServer, srv)
|
||||
return srv
|
||||
}
|
||||
|
||||
// OnConfigChange updates the underlying databroker server whenever configuration is changed.
|
||||
func (srv *DataBrokerServer) OnConfigChange(cfg *config.Config) {
|
||||
srv.UpdateConfig(srv.getOptions(cfg)...)
|
||||
}
|
||||
|
||||
func (srv *DataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
|
||||
return []databroker.ServerOption{
|
||||
databroker.WithSharedKey(cfg.Options.SharedKey),
|
||||
databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
|
||||
databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString),
|
||||
databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile),
|
||||
databroker.WithStorageCertificate(cfg.Options.DataBrokerCertificate),
|
||||
databroker.WithStorageCertSkipVerify(cfg.Options.DataBrokerStorageCertSkipVerify),
|
||||
}
|
||||
}
|
||||
|
|
2
cache/databroker_test.go
vendored
2
cache/databroker_test.go
vendored
|
@ -27,7 +27,7 @@ func init() {
|
|||
lis = bufconn.Listen(bufSize)
|
||||
s := grpc.NewServer()
|
||||
internalSrv := internal_databroker.New()
|
||||
srv := &DataBrokerServer{DataBrokerServiceServer: internalSrv}
|
||||
srv := &DataBrokerServer{Server: internalSrv}
|
||||
databroker.RegisterDataBrokerServiceServer(s, srv)
|
||||
|
||||
go func() {
|
||||
|
|
10
cache/memberlist_test.go
vendored
10
cache/memberlist_test.go
vendored
|
@ -15,10 +15,12 @@ import (
|
|||
)
|
||||
|
||||
func TestCache_runMemberList(t *testing.T) {
|
||||
c, err := New(config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"},
|
||||
Provider: "google",
|
||||
c, err := New(&config.Config{
|
||||
Options: &config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
DataBrokerURL: &url.URL{Scheme: "http", Host: "member1"},
|
||||
Provider: "google",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
@ -88,7 +88,7 @@ func Run(ctx context.Context, configFile string) error {
|
|||
}
|
||||
var cacheServer *cache.Cache
|
||||
if config.IsCache(cfg.Options.Services) {
|
||||
cacheServer, err = setupCache(cfg.Options, controlPlane)
|
||||
cacheServer, err = setupCache(src, cfg, controlPlane)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -162,13 +162,15 @@ func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *control
|
|||
return svc, nil
|
||||
}
|
||||
|
||||
func setupCache(opt *config.Options, controlPlane *controlplane.Server) (*cache.Cache, error) {
|
||||
svc, err := cache.New(*opt)
|
||||
func setupCache(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*cache.Cache, error) {
|
||||
svc, err := cache.New(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error creating config service: %w", err)
|
||||
}
|
||||
svc.Register(controlPlane.GRPCServer)
|
||||
log.Info().Msg("enabled cache service")
|
||||
src.OnConfigChange(svc.OnConfigChange)
|
||||
svc.OnConfigChange(cfg)
|
||||
return svc, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -2,7 +2,11 @@ package databroker
|
|||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -21,7 +25,9 @@ type serverConfig struct {
|
|||
secret []byte
|
||||
storageType string
|
||||
storageConnectionString string
|
||||
storageTLSConfig *tls.Config
|
||||
storageCAFile string
|
||||
storageCertSkipVerify bool
|
||||
storageCertificate *tls.Certificate
|
||||
}
|
||||
|
||||
func newServerConfig(options ...ServerOption) *serverConfig {
|
||||
|
@ -54,10 +60,15 @@ func WithDeletePermanentlyAfter(dur time.Duration) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithSecret sets the secret in the config.
|
||||
func WithSecret(secret []byte) ServerOption {
|
||||
// WithSharedKey sets the secret in the config.
|
||||
func WithSharedKey(sharedKey string) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.secret = secret
|
||||
key, err := base64.StdEncoding.DecodeString(sharedKey)
|
||||
if err != nil || len(key) != cryptutil.DefaultKeySize {
|
||||
log.Error().Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
|
||||
return
|
||||
}
|
||||
cfg.secret = key
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,9 +86,23 @@ func WithStorageConnectionString(connStr string) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
// WithStorageTLSConfig sets the tls config for connection to storage.
|
||||
func WithStorageTLSConfig(tlsConfig *tls.Config) ServerOption {
|
||||
// WithStorageCAFile sets the CA file in the config.
|
||||
func WithStorageCAFile(filePath string) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.storageTLSConfig = tlsConfig
|
||||
cfg.storageCAFile = filePath
|
||||
}
|
||||
}
|
||||
|
||||
// WithStorageCertSkipVerify sets the storageCertSkipVerify in the config.
|
||||
func WithStorageCertSkipVerify(storageCertSkipVerify bool) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.storageCertSkipVerify = storageCertSkipVerify
|
||||
}
|
||||
}
|
||||
|
||||
// WithStorageCertificate sets the storageCertificate in the config.
|
||||
func WithStorageCertificate(certificate *tls.Certificate) ServerOption {
|
||||
return func(cfg *serverConfig) {
|
||||
cfg.storageCertificate = certificate
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,7 +3,10 @@ package databroker
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"sort"
|
||||
"sync"
|
||||
|
@ -11,6 +14,7 @@ import (
|
|||
|
||||
"github.com/golang/protobuf/ptypes"
|
||||
"github.com/golang/protobuf/ptypes/empty"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/uuid"
|
||||
"github.com/rs/zerolog"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
@ -33,35 +37,39 @@ const (
|
|||
syncBatchSize = 100
|
||||
)
|
||||
|
||||
// newUUID returns a new UUID. This make it easy to stub out in tests.
|
||||
var newUUID = uuid.New
|
||||
|
||||
// Server implements the databroker service using an in memory database.
|
||||
type Server struct {
|
||||
version string
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
cfg *serverConfig
|
||||
log zerolog.Logger
|
||||
|
||||
mu sync.RWMutex
|
||||
version string
|
||||
byType map[string]storage.Backend
|
||||
onTypechange *signal.Signal
|
||||
}
|
||||
|
||||
// New creates a new server.
|
||||
func New(options ...ServerOption) *Server {
|
||||
cfg := newServerConfig(options...)
|
||||
srv := &Server{
|
||||
version: uuid.New().String(),
|
||||
cfg: cfg,
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
log: log.With().Str("service", "databroker").Logger(),
|
||||
|
||||
byType: make(map[string]storage.Backend),
|
||||
onTypechange: signal.New(),
|
||||
}
|
||||
srv.initVersion()
|
||||
srv.UpdateConfig(options...)
|
||||
|
||||
go func() {
|
||||
ticker := time.NewTicker(cfg.deletePermanentlyAfter / 2)
|
||||
ticker := time.NewTicker(time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for range ticker.C {
|
||||
srv.mu.RLock()
|
||||
tm := time.Now().Add(-srv.cfg.deletePermanentlyAfter)
|
||||
srv.mu.RUnlock()
|
||||
|
||||
var recordTypes []string
|
||||
srv.mu.RLock()
|
||||
for recordType := range srv.byType {
|
||||
|
@ -70,11 +78,11 @@ func New(options ...ServerOption) *Server {
|
|||
srv.mu.RUnlock()
|
||||
|
||||
for _, recordType := range recordTypes {
|
||||
db, err := srv.getDB(recordType)
|
||||
db, _, err := srv.getDB(recordType, true)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
db.ClearDeleted(context.Background(), time.Now().Add(-cfg.deletePermanentlyAfter))
|
||||
db.ClearDeleted(context.Background(), tm)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
@ -82,7 +90,7 @@ func New(options ...ServerOption) *Server {
|
|||
}
|
||||
|
||||
func (srv *Server) initVersion() {
|
||||
dbServerVersion, err := srv.getDB(recordTypeServerVersion)
|
||||
dbServerVersion, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to init server version")
|
||||
return
|
||||
|
@ -98,12 +106,36 @@ func (srv *Server) initVersion() {
|
|||
return
|
||||
}
|
||||
|
||||
srv.version = newUUID().String()
|
||||
data, _ := ptypes.MarshalAny(&databroker.ServerVersion{Version: srv.version})
|
||||
if err := dbServerVersion.Put(context.Background(), serverVersionKey, data); err != nil {
|
||||
srv.log.Warn().Err(err).Msg("failed to save server version.")
|
||||
}
|
||||
}
|
||||
|
||||
// UpdateConfig updates the server with the new options.
|
||||
func (srv *Server) UpdateConfig(options ...ServerOption) {
|
||||
srv.mu.Lock()
|
||||
defer srv.mu.Unlock()
|
||||
|
||||
cfg := newServerConfig(options...)
|
||||
if cmp.Equal(cfg, srv.cfg, cmp.AllowUnexported(serverConfig{})) {
|
||||
log.Debug().Msg("databroker: no changes detected, re-using existing DBs")
|
||||
return
|
||||
}
|
||||
srv.cfg = cfg
|
||||
|
||||
for t, db := range srv.byType {
|
||||
err := db.Close()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("databroker: error closing backend")
|
||||
}
|
||||
delete(srv.byType, t)
|
||||
}
|
||||
|
||||
srv.initVersion()
|
||||
}
|
||||
|
||||
// Delete deletes a record from the in-memory list.
|
||||
func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*empty.Empty, error) {
|
||||
_, span := trace.StartSpan(ctx, "databroker.grpc.Delete")
|
||||
|
@ -113,7 +145,7 @@ func (srv *Server) Delete(ctx context.Context, req *databroker.DeleteRequest) (*
|
|||
Str("id", req.GetId()).
|
||||
Msg("delete")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -134,7 +166,7 @@ func (srv *Server) Get(ctx context.Context, req *databroker.GetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("get")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, _, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -156,7 +188,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
Str("type", req.GetType()).
|
||||
Msg("get all")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -167,7 +199,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
}
|
||||
|
||||
if len(all) == 0 {
|
||||
return &databroker.GetAllResponse{ServerVersion: srv.version}, nil
|
||||
return &databroker.GetAllResponse{ServerVersion: version}, nil
|
||||
}
|
||||
|
||||
var recordVersion string
|
||||
|
@ -182,7 +214,7 @@ func (srv *Server) GetAll(ctx context.Context, req *databroker.GetAllRequest) (*
|
|||
}
|
||||
|
||||
return &databroker.GetAllResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: version,
|
||||
RecordVersion: recordVersion,
|
||||
Records: records,
|
||||
}, nil
|
||||
|
@ -197,7 +229,7 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
Str("id", req.GetId()).
|
||||
Msg("set")
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
db, version, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -210,11 +242,13 @@ func (srv *Server) Set(ctx context.Context, req *databroker.SetRequest) (*databr
|
|||
}
|
||||
return &databroker.SetResponse{
|
||||
Record: record,
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: version,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
func (srv *Server) doSync(ctx context.Context,
|
||||
serverVersion string, recordVersion *string,
|
||||
db storage.Backend, stream databroker.DataBrokerService_SyncServer) error {
|
||||
updated, err := db.List(ctx, *recordVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -232,7 +266,7 @@ func (srv *Server) doSync(ctx context.Context, recordVersion *string, db storage
|
|||
j = len(updated)
|
||||
}
|
||||
if err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: serverVersion,
|
||||
Records: updated[i:j],
|
||||
}); err != nil {
|
||||
return err
|
||||
|
@ -251,34 +285,34 @@ func (srv *Server) Sync(req *databroker.SyncRequest, stream databroker.DataBroke
|
|||
Str("record_version", req.GetRecordVersion()).
|
||||
Msg("sync")
|
||||
|
||||
db, serverVersion, err := srv.getDB(req.GetType(), true)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
recordVersion := req.GetRecordVersion()
|
||||
// reset record version if the server versions don't match
|
||||
if req.GetServerVersion() != srv.version {
|
||||
if req.GetServerVersion() != serverVersion {
|
||||
recordVersion = ""
|
||||
// send the new server version to the client
|
||||
err := stream.Send(&databroker.SyncResponse{
|
||||
ServerVersion: srv.version,
|
||||
ServerVersion: serverVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
db, err := srv.getDB(req.GetType())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx := stream.Context()
|
||||
ch := db.Watch(ctx)
|
||||
|
||||
// Do first sync, so we won't missed anything.
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for range ch {
|
||||
if err := srv.doSync(ctx, &recordVersion, db, stream); err != nil {
|
||||
if err := srv.doSync(ctx, serverVersion, &recordVersion, db, stream); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -335,29 +369,56 @@ func (srv *Server) SyncTypes(req *emptypb.Empty, stream databroker.DataBrokerSer
|
|||
}
|
||||
}
|
||||
|
||||
func (srv *Server) getDB(recordType string) (storage.Backend, error) {
|
||||
func (srv *Server) getDB(recordType string, lock bool) (db storage.Backend, version string, err error) {
|
||||
// double-checked locking:
|
||||
// first try the read lock, then re-try with the write lock, and finally create a new db if nil
|
||||
srv.mu.RLock()
|
||||
db := srv.byType[recordType]
|
||||
srv.mu.RUnlock()
|
||||
if lock {
|
||||
srv.mu.RLock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
version = srv.version
|
||||
if lock {
|
||||
srv.mu.RUnlock()
|
||||
}
|
||||
if db == nil {
|
||||
srv.mu.Lock()
|
||||
if lock {
|
||||
srv.mu.Lock()
|
||||
}
|
||||
db = srv.byType[recordType]
|
||||
version = srv.version
|
||||
var err error
|
||||
if db == nil {
|
||||
db, err = srv.newDB(recordType)
|
||||
srv.byType[recordType] = db
|
||||
}
|
||||
srv.mu.Unlock()
|
||||
if lock {
|
||||
srv.mu.Unlock()
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
}
|
||||
return db, nil
|
||||
return db, version, nil
|
||||
}
|
||||
|
||||
func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
||||
caCertPool := x509.NewCertPool()
|
||||
if srv.cfg.storageCAFile != "" {
|
||||
if caCert, err := ioutil.ReadFile(srv.cfg.storageCAFile); err == nil {
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
} else {
|
||||
log.Warn().Err(err).Msg("failed to read databroker CA file")
|
||||
}
|
||||
}
|
||||
tlsConfig := &tls.Config{
|
||||
RootCAs: caCertPool,
|
||||
// nolint: gosec
|
||||
InsecureSkipVerify: srv.cfg.storageCertSkipVerify,
|
||||
}
|
||||
if srv.cfg.storageCertificate != nil {
|
||||
tlsConfig.Certificates = []tls.Certificate{*srv.cfg.storageCertificate}
|
||||
}
|
||||
|
||||
switch srv.cfg.storageType {
|
||||
case config.StorageInMemoryName:
|
||||
return inmemory.NewDB(recordType, srv.cfg.btreeDegree), nil
|
||||
|
@ -366,7 +427,7 @@ func (srv *Server) newDB(recordType string) (db storage.Backend, err error) {
|
|||
srv.cfg.storageConnectionString,
|
||||
recordType,
|
||||
int64(srv.cfg.deletePermanentlyAfter.Seconds()),
|
||||
redis.WithTLSConfig(srv.cfg.storageTLSConfig),
|
||||
redis.WithTLSConfig(tlsConfig),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create new redis storage: %w", err)
|
||||
|
|
|
@ -33,36 +33,46 @@ func newServer(cfg *serverConfig) *Server {
|
|||
func TestServer_initVersion(t *testing.T) {
|
||||
cfg := newServerConfig()
|
||||
t.Run("nil db", func(t *testing.T) {
|
||||
srvVersion := uuid.New()
|
||||
oldNewUUID := newUUID
|
||||
newUUID = func() uuid.UUID {
|
||||
return srvVersion
|
||||
}
|
||||
defer func() { newUUID = oldNewUUID }()
|
||||
|
||||
srv := newServer(cfg)
|
||||
srvVersion := uuid.New().String()
|
||||
srv.version = srvVersion
|
||||
srv.byType[recordTypeServerVersion] = nil
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion, srv.version)
|
||||
assert.Equal(t, srvVersion.String(), srv.version)
|
||||
})
|
||||
t.Run("new server with random version", func(t *testing.T) {
|
||||
srvVersion := uuid.New()
|
||||
oldNewUUID := newUUID
|
||||
newUUID = func() uuid.UUID {
|
||||
return srvVersion
|
||||
}
|
||||
defer func() { newUUID = oldNewUUID }()
|
||||
|
||||
srv := newServer(cfg)
|
||||
ctx := context.Background()
|
||||
db, err := srv.getDB(recordTypeServerVersion)
|
||||
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
require.NoError(t, err)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
assert.Nil(t, r)
|
||||
srvVersion := uuid.New().String()
|
||||
srv.version = srvVersion
|
||||
srv.initVersion()
|
||||
assert.Equal(t, srvVersion, srv.version)
|
||||
assert.Equal(t, srvVersion.String(), srv.version)
|
||||
r, err = db.Get(ctx, serverVersionKey)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, r)
|
||||
var sv databroker.ServerVersion
|
||||
assert.NoError(t, ptypes.UnmarshalAny(r.GetData(), &sv))
|
||||
assert.Equal(t, srvVersion, sv.Version)
|
||||
assert.Equal(t, srvVersion.String(), sv.Version)
|
||||
})
|
||||
t.Run("init version twice should get the same version", func(t *testing.T) {
|
||||
srv := newServer(cfg)
|
||||
ctx := context.Background()
|
||||
db, err := srv.getDB(recordTypeServerVersion)
|
||||
db, _, err := srv.getDB(recordTypeServerVersion, false)
|
||||
require.NoError(t, err)
|
||||
r, err := db.Get(ctx, serverVersionKey)
|
||||
assert.Error(t, err)
|
||||
|
|
|
@ -4,8 +4,10 @@ package directory
|
|||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
"sync"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/directory/azure"
|
||||
"github.com/pomerium/pomerium/internal/directory/github"
|
||||
"github.com/pomerium/pomerium/internal/directory/gitlab"
|
||||
|
@ -27,8 +29,34 @@ type Provider interface {
|
|||
UserGroups(ctx context.Context) ([]*Group, []*User, error)
|
||||
}
|
||||
|
||||
// Options are the options specific to the provider.
|
||||
type Options struct {
|
||||
ServiceAccount string
|
||||
Provider string
|
||||
ProviderURL string
|
||||
QPS float64
|
||||
}
|
||||
|
||||
var globalProvider = struct {
|
||||
sync.Mutex
|
||||
provider Provider
|
||||
options Options
|
||||
}{}
|
||||
|
||||
// GetProvider gets the provider for the given options.
|
||||
func GetProvider(options *config.Options) Provider {
|
||||
func GetProvider(options Options) (provider Provider) {
|
||||
globalProvider.Lock()
|
||||
defer globalProvider.Unlock()
|
||||
|
||||
if globalProvider.provider != nil && cmp.Equal(globalProvider.options, options) {
|
||||
log.Debug().Str("provider", options.Provider).Msg("directory: no change detected, reusing existing directory provider")
|
||||
return globalProvider.provider
|
||||
}
|
||||
defer func() {
|
||||
globalProvider.provider = provider
|
||||
globalProvider.options = options
|
||||
}()
|
||||
|
||||
switch options.Provider {
|
||||
case azure.Name:
|
||||
serviceAccount, err := azure.ParseServiceAccount(options.ServiceAccount)
|
||||
|
|
|
@ -1,6 +1,12 @@
|
|||
package manager
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/directory"
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
var (
|
||||
defaultGroupRefreshInterval = 10 * time.Minute
|
||||
|
@ -10,6 +16,9 @@ var (
|
|||
)
|
||||
|
||||
type config struct {
|
||||
authenticator Authenticator
|
||||
directory directory.Provider
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
groupRefreshInterval time.Duration
|
||||
groupRefreshTimeout time.Duration
|
||||
sessionRefreshGracePeriod time.Duration
|
||||
|
@ -31,6 +40,27 @@ func newConfig(options ...Option) *config {
|
|||
// An Option customizes the configuration used for the identity manager.
|
||||
type Option func(*config)
|
||||
|
||||
// WithAuthenticator sets the authenticator in the config.
|
||||
func WithAuthenticator(authenticator Authenticator) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.authenticator = authenticator
|
||||
}
|
||||
}
|
||||
|
||||
// WithDirectoryProvider sets the directory provider in the config.
|
||||
func WithDirectoryProvider(directoryProvider directory.Provider) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.directory = directoryProvider
|
||||
}
|
||||
}
|
||||
|
||||
// WithDataBrokerClient sets the databroker client in the config.
|
||||
func WithDataBrokerClient(dataBrokerClient databroker.DataBrokerServiceClient) Option {
|
||||
return func(cfg *config) {
|
||||
cfg.dataBrokerClient = dataBrokerClient
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroupRefreshInterval sets the group refresh interval used by the manager.
|
||||
func WithGroupRefreshInterval(interval time.Duration) Option {
|
||||
return func(cfg *config) {
|
||||
|
@ -58,3 +88,21 @@ func WithSessionRefreshCoolOffDuration(dur time.Duration) Option {
|
|||
cfg.sessionRefreshCoolOffDuration = dur
|
||||
}
|
||||
}
|
||||
|
||||
type atomicConfig struct {
|
||||
value atomic.Value
|
||||
}
|
||||
|
||||
func newAtomicConfig(cfg *config) *atomicConfig {
|
||||
ac := new(atomicConfig)
|
||||
ac.Store(cfg)
|
||||
return ac
|
||||
}
|
||||
|
||||
func (ac *atomicConfig) Load() *config {
|
||||
return ac.value.Load().(*config)
|
||||
}
|
||||
|
||||
func (ac *atomicConfig) Store(cfg *config) {
|
||||
ac.value.Store(cfg)
|
||||
}
|
||||
|
|
|
@ -42,11 +42,8 @@ type (
|
|||
|
||||
// A Manager refreshes identity information using session and user data.
|
||||
type Manager struct {
|
||||
cfg *config
|
||||
authenticator Authenticator
|
||||
directory directory.Provider
|
||||
dataBrokerClient databroker.DataBrokerServiceClient
|
||||
log zerolog.Logger
|
||||
cfg *atomicConfig
|
||||
log zerolog.Logger
|
||||
|
||||
sessions sessionCollection
|
||||
sessionScheduler *scheduler.Scheduler
|
||||
|
@ -67,17 +64,11 @@ type Manager struct {
|
|||
|
||||
// New creates a new identity manager.
|
||||
func New(
|
||||
authenticator Authenticator,
|
||||
directoryProvider directory.Provider,
|
||||
dataBrokerClient databroker.DataBrokerServiceClient,
|
||||
options ...Option,
|
||||
) *Manager {
|
||||
mgr := &Manager{
|
||||
cfg: newConfig(options...),
|
||||
authenticator: authenticator,
|
||||
directory: directoryProvider,
|
||||
dataBrokerClient: dataBrokerClient,
|
||||
log: log.With().Str("service", "identity_manager").Logger(),
|
||||
cfg: newAtomicConfig(newConfig()),
|
||||
log: log.With().Str("service", "identity_manager").Logger(),
|
||||
|
||||
sessions: sessionCollection{
|
||||
BTree: btree.New(8),
|
||||
|
@ -88,9 +79,15 @@ func New(
|
|||
},
|
||||
userScheduler: scheduler.New(),
|
||||
}
|
||||
mgr.UpdateConfig(options...)
|
||||
return mgr
|
||||
}
|
||||
|
||||
// UpdateConfig updates the manager with the new options.
|
||||
func (mgr *Manager) UpdateConfig(options ...Option) {
|
||||
mgr.cfg.Store(newConfig(options...))
|
||||
}
|
||||
|
||||
// Run runs the manager. This method blocks until an error occurs or the given context is canceled.
|
||||
func (mgr *Manager) Run(ctx context.Context) error {
|
||||
err := mgr.initDirectoryGroups(ctx)
|
||||
|
@ -169,7 +166,7 @@ func (mgr *Manager) refreshLoop(
|
|||
// refresh groups
|
||||
if mgr.directoryNextRefresh.Before(now) {
|
||||
mgr.refreshDirectoryUserGroups(ctx)
|
||||
mgr.directoryNextRefresh = now.Add(mgr.cfg.groupRefreshInterval)
|
||||
mgr.directoryNextRefresh = now.Add(mgr.cfg.Load().groupRefreshInterval)
|
||||
if mgr.directoryNextRefresh.Before(nextTime) {
|
||||
nextTime = mgr.directoryNextRefresh
|
||||
}
|
||||
|
@ -211,10 +208,10 @@ func (mgr *Manager) refreshLoop(
|
|||
func (mgr *Manager) refreshDirectoryUserGroups(ctx context.Context) {
|
||||
mgr.log.Info().Msg("refreshing directory users")
|
||||
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.groupRefreshTimeout)
|
||||
ctx, clearTimeout := context.WithTimeout(ctx, mgr.cfg.Load().groupRefreshTimeout)
|
||||
defer clearTimeout()
|
||||
|
||||
directoryGroups, directoryUsers, err := mgr.directory.UserGroups(ctx)
|
||||
directoryGroups, directoryUsers, err := mgr.cfg.Load().directory.UserGroups(ctx)
|
||||
if err != nil {
|
||||
mgr.log.Warn().Err(err).Msg("failed to refresh directory users and groups")
|
||||
return
|
||||
|
@ -238,7 +235,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: newDG.GetId(),
|
||||
Data: any,
|
||||
|
@ -258,7 +255,7 @@ func (mgr *Manager) mergeGroups(ctx context.Context, directoryGroups []*director
|
|||
mgr.log.Warn().Err(err).Msg("failed to marshal directory group")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: curDG.GetId(),
|
||||
})
|
||||
|
@ -284,7 +281,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Set(ctx, &databroker.SetRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: newDU.GetId(),
|
||||
Data: any,
|
||||
|
@ -304,7 +301,7 @@ func (mgr *Manager) mergeUsers(ctx context.Context, directoryUsers []*directory.
|
|||
mgr.log.Warn().Err(err).Msg("failed to marshal directory user")
|
||||
return
|
||||
}
|
||||
_, err = mgr.dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
_, err = mgr.cfg.Load().dataBrokerClient.Delete(ctx, &databroker.DeleteRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
Id: curDU.GetId(),
|
||||
})
|
||||
|
@ -349,7 +346,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
return
|
||||
}
|
||||
|
||||
newToken, err := mgr.authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
newToken, err := mgr.cfg.Load().authenticator.Refresh(ctx, FromOAuthToken(s.OauthToken), &s)
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -366,7 +363,7 @@ func (mgr *Manager) refreshSession(ctx context.Context, userID, sessionID string
|
|||
}
|
||||
s.OauthToken = ToOAuthToken(newToken)
|
||||
|
||||
res, err := session.Set(ctx, mgr.dataBrokerClient, s.Session)
|
||||
res, err := session.Set(ctx, mgr.cfg.Load().dataBrokerClient, s.Session)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -401,7 +398,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
err := mgr.authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
||||
err := mgr.cfg.Load().authenticator.UpdateUserInfo(ctx, FromOAuthToken(s.OauthToken), &u)
|
||||
if isTemporaryError(err) {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -417,7 +414,7 @@ func (mgr *Manager) refreshUser(ctx context.Context, userID string) {
|
|||
continue
|
||||
}
|
||||
|
||||
record, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
||||
record, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", s.GetUserId()).
|
||||
|
@ -438,7 +435,7 @@ func (mgr *Manager) syncSessions(ctx context.Context, ch chan<- sessionMessage)
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -474,7 +471,7 @@ func (mgr *Manager) syncUsers(ctx context.Context, ch chan<- userMessage) error
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -510,7 +507,7 @@ func (mgr *Manager) initDirectoryUsers(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -541,7 +538,7 @@ func (mgr *Manager) syncDirectoryUsers(ctx context.Context, ch chan<- *directory
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
ServerVersion: mgr.directoryUsersServerVersion,
|
||||
RecordVersion: mgr.directoryUsersRecordVersion,
|
||||
|
@ -579,7 +576,7 @@ func (mgr *Manager) initDirectoryGroups(ctx context.Context) error {
|
|||
return err
|
||||
}
|
||||
|
||||
res, err := mgr.dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
res, err := mgr.cfg.Load().dataBrokerClient.GetAll(ctx, &databroker.GetAllRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
})
|
||||
if err != nil {
|
||||
|
@ -610,7 +607,7 @@ func (mgr *Manager) syncDirectoryGroups(ctx context.Context, ch chan<- *director
|
|||
return err
|
||||
}
|
||||
|
||||
client, err := mgr.dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
client, err := mgr.cfg.Load().dataBrokerClient.Sync(ctx, &databroker.SyncRequest{
|
||||
Type: any.GetTypeUrl(),
|
||||
ServerVersion: mgr.directoryGroupsServerVersion,
|
||||
RecordVersion: mgr.directoryGroupsRecordVersion,
|
||||
|
@ -652,8 +649,8 @@ func (mgr *Manager) onUpdateSession(ctx context.Context, msg sessionMessage) {
|
|||
// update session
|
||||
s, _ := mgr.sessions.Get(msg.session.GetUserId(), msg.session.GetId())
|
||||
s.lastRefresh = time.Now()
|
||||
s.gracePeriod = mgr.cfg.sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.sessionRefreshCoolOffDuration
|
||||
s.gracePeriod = mgr.cfg.Load().sessionRefreshGracePeriod
|
||||
s.coolOffDuration = mgr.cfg.Load().sessionRefreshCoolOffDuration
|
||||
s.Session = msg.session
|
||||
mgr.sessions.ReplaceOrInsert(s)
|
||||
mgr.sessionScheduler.Add(s.NextRefresh(), toSessionSchedulerKey(msg.session.GetUserId(), msg.session.GetId()))
|
||||
|
@ -676,7 +673,7 @@ func (mgr *Manager) onUpdateUser(_ context.Context, msg userMessage) {
|
|||
// only reset the refresh time if this is an existing user
|
||||
u.lastRefresh = time.Now()
|
||||
}
|
||||
u.refreshInterval = mgr.cfg.groupRefreshInterval
|
||||
u.refreshInterval = mgr.cfg.Load().groupRefreshInterval
|
||||
u.User = msg.user
|
||||
mgr.users.ReplaceOrInsert(u)
|
||||
mgr.userScheduler.Add(u.NextRefresh(), u.GetId())
|
||||
|
@ -697,7 +694,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
|||
},
|
||||
}
|
||||
|
||||
_, err := user.Set(ctx, mgr.dataBrokerClient, u.User)
|
||||
_, err := user.Set(ctx, mgr.cfg.Load().dataBrokerClient, u.User)
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("user_id", pbSession.GetUserId()).
|
||||
|
@ -707,7 +704,7 @@ func (mgr *Manager) createUser(ctx context.Context, pbSession *session.Session)
|
|||
}
|
||||
|
||||
func (mgr *Manager) deleteSession(ctx context.Context, pbSession *session.Session) {
|
||||
err := session.Delete(ctx, mgr.dataBrokerClient, pbSession.GetId())
|
||||
err := session.Delete(ctx, mgr.cfg.Load().dataBrokerClient, pbSession.GetId())
|
||||
if err != nil {
|
||||
mgr.log.Error().Err(err).
|
||||
Str("session_id", pbSession.GetId()).
|
||||
|
|
|
@ -52,6 +52,9 @@ type DB struct {
|
|||
byVersion *btree.BTree
|
||||
deletedIDs []string
|
||||
onchange *signal.Signal
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// NewDB creates a new in-memory database for the given record type.
|
||||
|
@ -62,6 +65,7 @@ func NewDB(recordType string, btreeDegree int) *DB {
|
|||
byID: btree.New(btreeDegree),
|
||||
byVersion: btree.New(btreeDegree),
|
||||
onchange: s,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -84,6 +88,14 @@ func (db *DB) ClearDeleted(_ context.Context, cutoff time.Time) {
|
|||
db.deletedIDs = remaining
|
||||
}
|
||||
|
||||
// Close closes the database. Any watchers will be closed.
|
||||
func (db *DB) Close() error {
|
||||
db.closeOnce.Do(func() {
|
||||
close(db.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete marks a record as deleted.
|
||||
func (db *DB) Delete(_ context.Context, id string) error {
|
||||
defer db.onchange.Broadcast()
|
||||
|
@ -140,7 +152,10 @@ func (db *DB) Put(_ context.Context, id string, data *anypb.Any) error {
|
|||
func (db *DB) Watch(ctx context.Context) <-chan struct{} {
|
||||
ch := db.onchange.Bind()
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
select {
|
||||
case <-db.closed:
|
||||
case <-ctx.Done():
|
||||
}
|
||||
close(ch)
|
||||
db.onchange.Unbind(ch)
|
||||
}()
|
||||
|
|
|
@ -40,6 +40,9 @@ type DB struct {
|
|||
deletedSet string
|
||||
tlsConfig *tls.Config
|
||||
notifyChMu sync.Mutex
|
||||
|
||||
closeOnce sync.Once
|
||||
closed chan struct{}
|
||||
}
|
||||
|
||||
// New returns new DB instance.
|
||||
|
@ -50,6 +53,7 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
|
|||
versionSet: recordType + "_version_set",
|
||||
deletedSet: recordType + "_deleted_set",
|
||||
lastVersionKey: recordType + "_last_version",
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
|
||||
for _, o := range opts {
|
||||
|
@ -79,6 +83,14 @@ func New(rawURL, recordType string, deletePermanentAfter int64, opts ...Option)
|
|||
return db, nil
|
||||
}
|
||||
|
||||
// Close closes the redis db connection.
|
||||
func (db *DB) Close() error {
|
||||
db.closeOnce.Do(func() {
|
||||
close(db.closed)
|
||||
})
|
||||
return nil
|
||||
}
|
||||
|
||||
// Put sets new record for given id with input data.
|
||||
func (db *DB) Put(ctx context.Context, id string, data *anypb.Any) (err error) {
|
||||
c := db.pool.Get()
|
||||
|
@ -259,7 +271,10 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
|
|||
}
|
||||
for {
|
||||
select {
|
||||
case <-db.closed:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
switch v := psc.Receive().(type) {
|
||||
|
@ -271,12 +286,17 @@ func (db *DB) doNotifyLoop(ctx context.Context, ch chan struct{}) {
|
|||
}
|
||||
|
||||
select {
|
||||
case <-db.closed:
|
||||
return
|
||||
case <-ctx.Done():
|
||||
log.Warn().Err(ctx.Err()).Msg("context done, stop receive from redis channel")
|
||||
return
|
||||
default:
|
||||
db.notifyChMu.Lock()
|
||||
select {
|
||||
case <-db.closed:
|
||||
db.notifyChMu.Unlock()
|
||||
return
|
||||
case <-ctx.Done():
|
||||
db.notifyChMu.Unlock()
|
||||
log.Warn().Err(ctx.Err()).Msg("context done while holding notify lock, stop receive from redis channel")
|
||||
|
@ -313,6 +333,7 @@ func (db *DB) watch(ctx context.Context, ch chan struct{}) {
|
|||
db.doNotifyLoop(ctx, ch)
|
||||
}()
|
||||
select {
|
||||
case <-db.closed:
|
||||
case <-ctx.Done():
|
||||
case <-done:
|
||||
}
|
||||
|
|
|
@ -12,6 +12,9 @@ import (
|
|||
|
||||
// Backend is the interface required for a storage backend.
|
||||
type Backend interface {
|
||||
// Close closes the backend.
|
||||
Close() error
|
||||
|
||||
// Put is used to insert or update a record.
|
||||
Put(ctx context.Context, id string, data *anypb.Any) error
|
||||
|
||||
|
|
|
@ -4,8 +4,9 @@ import (
|
|||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
"google.golang.org/protobuf/types/known/anypb"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/grpc/databroker"
|
||||
)
|
||||
|
||||
type mockBackend struct {
|
||||
|
@ -18,6 +19,10 @@ type mockBackend struct {
|
|||
watch func(ctx context.Context) <-chan struct{}
|
||||
}
|
||||
|
||||
func (m *mockBackend) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *mockBackend) Put(ctx context.Context, id string, data *anypb.Any) error {
|
||||
return m.put(ctx, id, data)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue