package controller import ( "context" "encoding/base64" "errors" "fmt" "net" "net/url" "sync" "google.golang.org/grpc" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/retry" "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpcutil" ) // ErrBootstrapConfigurationChanged is returned when the bootstrap configuration has changed and the function needs to be restarted. var ErrBootstrapConfigurationChanged = errors.New("bootstrap configuration changed") type DatabrokerRestartRunner struct { lock sync.RWMutex cancel chan struct{} conn *grpc.ClientConn client databroker.DataBrokerServiceClient initError error } // NewDatabrokerRestartRunner is a helper to run a function that needs to be restarted when the underlying databroker configuration changes. func NewDatabrokerRestartRunner( ctx context.Context, src config.Source, ) *DatabrokerRestartRunner { p := new(DatabrokerRestartRunner) p.initLocked(ctx, src.GetConfig()) src.OnConfigChange(ctx, p.onConfigChange) return p } func (p *DatabrokerRestartRunner) Run( ctx context.Context, fn func(context.Context, databroker.DataBrokerServiceClient) error, ) error { return retry.WithBackoff(ctx, "databroker-restart", func(ctx context.Context) error { return p.runUntilDatabrokerChanges(ctx, fn) }) } // Close releases the resources used by the databroker provider. func (p *DatabrokerRestartRunner) Close() { p.lock.Lock() defer p.lock.Unlock() p.closeLocked() } func (p *DatabrokerRestartRunner) GetDatabrokerClient() (databroker.DataBrokerServiceClient, error) { client, _, err := p.getDatabrokerClient() return client, err } // GetDatabrokerClient returns the databroker client and a channel that will be closed when the client is no longer valid. func (p *DatabrokerRestartRunner) getDatabrokerClient() (databroker.DataBrokerServiceClient, <-chan struct{}, error) { p.lock.RLock() defer p.lock.RUnlock() if p.initError != nil { return nil, nil, p.initError } return p.client, p.cancel, nil } func (p *DatabrokerRestartRunner) onConfigChange(ctx context.Context, cfg *config.Config) { p.lock.Lock() defer p.lock.Unlock() p.closeLocked() p.initLocked(ctx, cfg) } func (p *DatabrokerRestartRunner) initLocked(ctx context.Context, cfg *config.Config) { conn, err := newDataBrokerConnection(ctx, cfg) if err != nil { p.initError = fmt.Errorf("databroker connection: %w", err) return } p.conn = conn p.client = databroker.NewDataBrokerServiceClient(conn) p.cancel = make(chan struct{}) p.initError = nil } func (p *DatabrokerRestartRunner) closeLocked() { if p.conn != nil { p.conn.Close() p.conn = nil } if p.cancel != nil { close(p.cancel) p.cancel = nil } p.initError = errors.New("databroker connection closed") } func (p *DatabrokerRestartRunner) runUntilDatabrokerChanges( ctx context.Context, fn func(context.Context, databroker.DataBrokerServiceClient) error, ) error { client, cancelCh, err := p.getDatabrokerClient() if err != nil { return fmt.Errorf("get databroker client: %w", err) } ctx, cancel := context.WithCancelCause(ctx) defer cancel(context.Canceled) go func() { select { case <-ctx.Done(): case <-cancelCh: cancel(ErrBootstrapConfigurationChanged) } }() return fn(ctx, client) } func newDataBrokerConnection(ctx context.Context, cfg *config.Config) (*grpc.ClientConn, error) { sharedSecret, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) if err != nil { return nil, fmt.Errorf("decode shared_secret: %w", err) } if len(sharedSecret) != 32 { return nil, fmt.Errorf("shared_secret: expected 32 bytes, got %d", len(sharedSecret)) } return grpcutil.NewGRPCClientConn(ctx, &grpcutil.Options{ Address: &url.URL{ Scheme: "http", Host: net.JoinHostPort("localhost", cfg.GRPCPort), }, ServiceName: "databroker", SignedJWTKey: sharedSecret, }) }