mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 02:46:30 +02:00
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)", which behaves the same when no cause is set but also correctly propagates the cancellation cause when one is present.
324 lines
8.7 KiB
Go
324 lines
8.7 KiB
Go
package databroker
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"maps"
	"slices"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/errgrouputil"
	"github.com/pomerium/pomerium/internal/hashutil"
	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/telemetry/metrics"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"github.com/pomerium/pomerium/pkg/cryptutil"
	"github.com/pomerium/pomerium/pkg/grpc"
	configpb "github.com/pomerium/pomerium/pkg/grpc/config"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpcutil"
	"github.com/pomerium/pomerium/pkg/health"
)
|
|
|
|
// ConfigSource provides a new Config source that decorates an underlying config with
// configuration derived from the data broker.
//
// The embedded config.ChangeDispatcher is used to notify registered listeners
// (via Trigger) whenever a non-initial rebuild produces a new computed config.
type ConfigSource struct {
	// mu guards underlyingConfig, computedConfig, dbConfigs, updaterHash and
	// cancel. rebuild holds it for the whole rebuild, including while the
	// updater is (re)started.
	mu sync.RWMutex

	// outboundGRPCConnection is the cached connection runUpdater uses to reach
	// the databroker.
	outboundGRPCConnection *grpc.CachedOutboundGRPClientConn

	// computedConfig is the last successfully built config (underlying config
	// merged with databroker configs); returned by GetConfig. Nil until the
	// first successful build.
	computedConfig *config.Config

	// underlyingConfig is the latest config received from the underlying source.
	underlyingConfig *config.Config

	// dbConfigs holds the configs received from the databroker, keyed by record ID.
	dbConfigs map[string]dbConfig

	// updaterHash is the hash of the outbound connection options the running
	// syncer was started with; runUpdater skips restarting when it is unchanged.
	updaterHash uint64

	// cancel stops the currently running syncer, if any.
	cancel func()

	// enableValidation enables policy validation and certificate indexing when
	// applying databroker settings.
	enableValidation bool

	config.ChangeDispatcher
}
|
|
|
|
// dbConfig is a config received from the databroker, paired with the version
// of the databroker record it was decoded from.
type dbConfig struct {
	*configpb.Config
	version uint64
}
|
|
|
|
// EnableConfigValidation is a type that can be used to enable config validation.
//
// When true, policies built from databroker routes are validated, and a
// certificate index is passed along when databroker settings are applied.
type EnableConfigValidation bool
|
|
|
|
// NewConfigSource creates a new ConfigSource.
|
|
func NewConfigSource(
|
|
ctx context.Context,
|
|
underlying config.Source,
|
|
enableValidation EnableConfigValidation,
|
|
listeners ...config.ChangeListener,
|
|
) *ConfigSource {
|
|
src := &ConfigSource{
|
|
enableValidation: bool(enableValidation),
|
|
dbConfigs: map[string]dbConfig{},
|
|
outboundGRPCConnection: new(grpc.CachedOutboundGRPClientConn),
|
|
}
|
|
for _, li := range listeners {
|
|
src.OnConfigChange(ctx, li)
|
|
}
|
|
underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
|
|
src.mu.Lock()
|
|
src.underlyingConfig = cfg.Clone()
|
|
src.mu.Unlock()
|
|
|
|
src.rebuild(ctx, firstTime(false))
|
|
})
|
|
src.underlyingConfig = underlying.GetConfig()
|
|
src.rebuild(ctx, firstTime(true))
|
|
return src
|
|
}
|
|
|
|
// GetConfig gets the current config.
|
|
func (src *ConfigSource) GetConfig() *config.Config {
|
|
src.mu.RLock()
|
|
defer src.mu.RUnlock()
|
|
|
|
return src.computedConfig
|
|
}
|
|
|
|
// firstTime tells rebuild whether it is performing the initial build from
// NewConfigSource; the initial build stores the result without calling Trigger.
type firstTime bool
|
|
|
|
func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
|
|
_, span := trace.StartSpan(ctx, "databroker.config_source.rebuild")
|
|
defer span.End()
|
|
|
|
now := time.Now()
|
|
src.mu.Lock()
|
|
defer src.mu.Unlock()
|
|
log.Ctx(ctx).Debug().Str("lock-wait", time.Since(now).String()).Msg("databroker: rebuilding configuration")
|
|
|
|
cfg := src.underlyingConfig.Clone()
|
|
|
|
// start the updater
|
|
src.runUpdater(ctx, cfg)
|
|
|
|
now = time.Now()
|
|
err := src.buildNewConfigLocked(ctx, cfg)
|
|
if err != nil {
|
|
health.ReportError(health.BuildDatabrokerConfig, err)
|
|
log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to build new config")
|
|
return
|
|
}
|
|
health.ReportOK(health.BuildDatabrokerConfig)
|
|
log.Ctx(ctx).Debug().Str("elapsed", time.Since(now).String()).Msg("databroker: built new config")
|
|
|
|
src.computedConfig = cfg
|
|
if !firstTime {
|
|
src.Trigger(ctx, cfg)
|
|
}
|
|
|
|
metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true)
|
|
}
|
|
|
|
func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.Config) error {
|
|
eg, ctx := errgroup.WithContext(ctx)
|
|
eg.Go(func() error {
|
|
src.applySettingsLocked(ctx, cfg)
|
|
err := cfg.Options.Validate()
|
|
if err != nil {
|
|
return fmt.Errorf("validating settings: %w", err)
|
|
}
|
|
return nil
|
|
})
|
|
|
|
var policyBuilders []errgrouputil.BuilderFunc[config.Policy]
|
|
for _, cfgpb := range src.dbConfigs {
|
|
for _, routepb := range cfgpb.GetRoutes() {
|
|
routepb := routepb
|
|
policyBuilders = append(policyBuilders, func(ctx context.Context) (*config.Policy, error) {
|
|
p, err := src.buildPolicyFromProto(ctx, routepb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error building route id=%s: %w", routepb.GetId(), err)
|
|
}
|
|
return p, nil
|
|
})
|
|
}
|
|
}
|
|
|
|
var policies []*config.Policy
|
|
eg.Go(func() error {
|
|
var errs []error
|
|
policies, errs = errgrouputil.Build(ctx, policyBuilders...)
|
|
if len(errs) > 0 {
|
|
for _, err := range errs {
|
|
log.Ctx(ctx).Error().Msg(err.Error())
|
|
}
|
|
return fmt.Errorf("error building policies")
|
|
}
|
|
return nil
|
|
})
|
|
|
|
err := eg.Wait()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
src.addPolicies(ctx, cfg, policies)
|
|
return nil
|
|
}
|
|
|
|
func (src *ConfigSource) applySettingsLocked(ctx context.Context, cfg *config.Config) {
|
|
ids := slices.Sorted(maps.Keys(src.dbConfigs))
|
|
|
|
var certsIndex *cryptutil.CertificatesIndex
|
|
if src.enableValidation {
|
|
certsIndex = cryptutil.NewCertificatesIndex()
|
|
for _, cert := range cfg.Options.GetX509Certificates() {
|
|
certsIndex.Add(cert)
|
|
}
|
|
}
|
|
|
|
for i := 0; i < len(ids) && ctx.Err() == nil; i++ {
|
|
cfgpb := src.dbConfigs[ids[i]]
|
|
cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings)
|
|
}
|
|
}
|
|
|
|
func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *configpb.Route) (*config.Policy, error) {
|
|
policy, err := config.NewPolicyFromProto(routepb)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error building policy from protobuf: %w", err)
|
|
}
|
|
|
|
if !src.enableValidation {
|
|
return policy, nil
|
|
}
|
|
|
|
err = policy.Validate()
|
|
if err != nil {
|
|
return nil, fmt.Errorf("error validating policy: %w", err)
|
|
}
|
|
|
|
return policy, nil
|
|
}
|
|
|
|
func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) {
|
|
seen := make(map[uint64]struct{}, len(policies)+cfg.Options.NumPolicies())
|
|
for policy := range cfg.Options.GetAllPolicies() {
|
|
id, err := policy.RouteID()
|
|
if err != nil {
|
|
log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id")
|
|
continue
|
|
}
|
|
seen[id] = struct{}{}
|
|
}
|
|
|
|
additionalPolicies := make([]config.Policy, 0, len(policies))
|
|
for _, policy := range policies {
|
|
if policy == nil {
|
|
continue
|
|
}
|
|
|
|
id, err := policy.RouteID()
|
|
if err != nil {
|
|
log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id")
|
|
continue
|
|
}
|
|
if _, ok := seen[id]; ok {
|
|
log.Ctx(ctx).Debug().Str("policy", policy.String()).Msg("databroker: policy already exists")
|
|
continue
|
|
}
|
|
additionalPolicies = append(additionalPolicies, *policy)
|
|
seen[id] = struct{}{}
|
|
}
|
|
|
|
config.SortPolicies(additionalPolicies)
|
|
|
|
// add the additional policies here since calling `Validate` will reset them.
|
|
cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
|
|
}
|
|
|
|
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
|
|
sharedKey, _ := cfg.Options.GetSharedKey()
|
|
connectionOptions := &grpc.OutboundOptions{
|
|
OutboundPort: cfg.OutboundPort,
|
|
InstallationID: cfg.Options.InstallationID,
|
|
ServiceName: cfg.Options.Services,
|
|
SignedJWTKey: sharedKey,
|
|
}
|
|
h, err := hashutil.Hash(connectionOptions)
|
|
if err != nil {
|
|
log.Fatal().Err(err).Send()
|
|
}
|
|
// nothing changed, so don't restart the updater
|
|
if src.updaterHash == h {
|
|
return
|
|
}
|
|
src.updaterHash = h
|
|
|
|
if src.cancel != nil {
|
|
src.cancel()
|
|
src.cancel = nil
|
|
}
|
|
|
|
ctx, src.cancel = context.WithCancel(ctx)
|
|
|
|
cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker")
|
|
return
|
|
}
|
|
|
|
client := databroker.NewDataBrokerServiceClient(cc)
|
|
|
|
syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{
|
|
client: client,
|
|
src: src,
|
|
}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),
|
|
databroker.WithFastForward())
|
|
go func() {
|
|
log.Ctx(ctx).Debug().
|
|
Str("outbound_port", cfg.OutboundPort).
|
|
Msg("config: starting databroker config source syncer")
|
|
_ = grpc.WaitForReady(ctx, cc, time.Second*10)
|
|
_ = syncer.Run(ctx)
|
|
}()
|
|
}
|
|
|
|
type syncerHandler struct {
|
|
src *ConfigSource
|
|
client databroker.DataBrokerServiceClient
|
|
}
|
|
|
|
func (s *syncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
|
|
return s.client
|
|
}
|
|
|
|
func (s *syncerHandler) ClearRecords(_ context.Context) {
|
|
s.src.mu.Lock()
|
|
s.src.dbConfigs = map[string]dbConfig{}
|
|
s.src.mu.Unlock()
|
|
}
|
|
|
|
func (s *syncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
|
|
if len(records) == 0 {
|
|
return
|
|
}
|
|
|
|
s.src.mu.Lock()
|
|
for _, record := range records {
|
|
if record.GetDeletedAt() != nil {
|
|
delete(s.src.dbConfigs, record.GetId())
|
|
continue
|
|
}
|
|
|
|
var cfgpb configpb.Config
|
|
err := record.GetData().UnmarshalTo(&cfgpb)
|
|
if err != nil {
|
|
log.Ctx(ctx).Error().Err(err).Msg("databroker: error decoding config")
|
|
delete(s.src.dbConfigs, record.GetId())
|
|
continue
|
|
}
|
|
|
|
s.src.dbConfigs[record.GetId()] = dbConfig{&cfgpb, record.Version}
|
|
}
|
|
s.src.mu.Unlock()
|
|
|
|
s.src.rebuild(ctx, firstTime(false))
|
|
}
|