pomerium/internal/databroker/config_source.go
Commit fe31799eb5 by Joe Kralicky, 2024-10-25: Fix many instances of contexts and loggers not being propagated ()
This also replaces instances where we manually write "return ctx.Err()" with "return context.Cause(ctx)", which is functionally identical but will also correctly propagate cause errors if present.
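For illustration only, a minimal sketch (the helper below is hypothetical, not part of the file that follows): context.Cause(ctx) returns ctx.Err() unless a more specific cause was recorded via context.WithCancelCause, so returning it preserves that cause when one is present.

	// waitUntilDone is a hypothetical helper illustrating the pattern above.
	func waitUntilDone(ctx context.Context) error {
		<-ctx.Done()
		// Previously: return ctx.Err() — always context.Canceled or
		// context.DeadlineExceeded. context.Cause(ctx) instead returns the error
		// passed to a cancel(cause) call if there was one, and ctx.Err() otherwise.
		return context.Cause(ctx)
	}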

package databroker

import (
	"context"
	"fmt"
	"maps"
	"slices"
	"sync"
	"time"

	"golang.org/x/sync/errgroup"

	"github.com/pomerium/pomerium/config"
	"github.com/pomerium/pomerium/internal/errgrouputil"
	"github.com/pomerium/pomerium/internal/hashutil"
	"github.com/pomerium/pomerium/internal/log"
	"github.com/pomerium/pomerium/internal/telemetry/metrics"
	"github.com/pomerium/pomerium/internal/telemetry/trace"
	"github.com/pomerium/pomerium/pkg/cryptutil"
	"github.com/pomerium/pomerium/pkg/grpc"
	configpb "github.com/pomerium/pomerium/pkg/grpc/config"
	"github.com/pomerium/pomerium/pkg/grpc/databroker"
	"github.com/pomerium/pomerium/pkg/grpcutil"
	"github.com/pomerium/pomerium/pkg/health"
)
// ConfigSource provides a new Config source that decorates an underlying config with
// configuration derived from the data broker.
type ConfigSource struct {
	mu                     sync.RWMutex
	outboundGRPCConnection *grpc.CachedOutboundGRPClientConn
	computedConfig         *config.Config
	underlyingConfig       *config.Config
	dbConfigs              map[string]dbConfig
	updaterHash            uint64
	cancel                 func()
	enableValidation       bool

	config.ChangeDispatcher
}

type dbConfig struct {
	*configpb.Config
	version uint64
}
// EnableConfigValidation is a type that can be used to enable config validation.
type EnableConfigValidation bool

// NewConfigSource creates a new ConfigSource.
func NewConfigSource(
	ctx context.Context,
	underlying config.Source,
	enableValidation EnableConfigValidation,
	listeners ...config.ChangeListener,
) *ConfigSource {
	src := &ConfigSource{
		enableValidation:       bool(enableValidation),
		dbConfigs:              map[string]dbConfig{},
		outboundGRPCConnection: new(grpc.CachedOutboundGRPClientConn),
	}
	for _, li := range listeners {
		src.OnConfigChange(ctx, li)
	}
	underlying.OnConfigChange(ctx, func(ctx context.Context, cfg *config.Config) {
		src.mu.Lock()
		src.underlyingConfig = cfg.Clone()
		src.mu.Unlock()
		src.rebuild(ctx, firstTime(false))
	})
	src.underlyingConfig = underlying.GetConfig()
	src.rebuild(ctx, firstTime(true))
	return src
}
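
// Illustrative usage (a sketch, not part of this package's API or tests; the
// "underlying" source and the change listener below are hypothetical):
//
//	var underlying config.Source // e.g. a file-based config source
//	src := NewConfigSource(ctx, underlying, EnableConfigValidation(true),
//		func(ctx context.Context, cfg *config.Config) {
//			log.Ctx(ctx).Info().Msg("databroker: config updated")
//		})
//	cfg := src.GetConfig() // underlying config merged with databroker settings and routes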
// GetConfig gets the current config.
func (src *ConfigSource) GetConfig() *config.Config {
	src.mu.RLock()
	defer src.mu.RUnlock()
	return src.computedConfig
}

type firstTime bool

// rebuild recomputes the merged configuration from the underlying config and the
// databroker configs, (re)starts the updater if needed, and notifies change
// listeners unless this is the first build.
func (src *ConfigSource) rebuild(ctx context.Context, firstTime firstTime) {
	_, span := trace.StartSpan(ctx, "databroker.config_source.rebuild")
	defer span.End()

	now := time.Now()
	src.mu.Lock()
	defer src.mu.Unlock()
	log.Ctx(ctx).Debug().Str("lock-wait", time.Since(now).String()).Msg("databroker: rebuilding configuration")

	cfg := src.underlyingConfig.Clone()

	// start the updater
	src.runUpdater(ctx, cfg)

	now = time.Now()
	err := src.buildNewConfigLocked(ctx, cfg)
	if err != nil {
		health.ReportError(health.BuildDatabrokerConfig, err)
		log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to build new config")
		return
	}
	health.ReportOK(health.BuildDatabrokerConfig)
	log.Ctx(ctx).Debug().Str("elapsed", time.Since(now).String()).Msg("databroker: built new config")

	src.computedConfig = cfg
	if !firstTime {
		src.Trigger(ctx, cfg)
	}

	metrics.SetConfigInfo(ctx, cfg.Options.Services, "databroker", cfg.Checksum(), true)
}
// buildNewConfigLocked applies databroker settings to cfg and builds policies from
// databroker routes, validating both concurrently. src.mu must be held.
func (src *ConfigSource) buildNewConfigLocked(ctx context.Context, cfg *config.Config) error {
	eg, ctx := errgroup.WithContext(ctx)
	eg.Go(func() error {
		src.applySettingsLocked(ctx, cfg)
		err := cfg.Options.Validate()
		if err != nil {
			return fmt.Errorf("validating settings: %w", err)
		}
		return nil
	})

	var policyBuilders []errgrouputil.BuilderFunc[config.Policy]
	for _, cfgpb := range src.dbConfigs {
		for _, routepb := range cfgpb.GetRoutes() {
			routepb := routepb
			policyBuilders = append(policyBuilders, func(ctx context.Context) (*config.Policy, error) {
				p, err := src.buildPolicyFromProto(ctx, routepb)
				if err != nil {
					return nil, fmt.Errorf("error building route id=%s: %w", routepb.GetId(), err)
				}
				return p, nil
			})
		}
	}

	var policies []*config.Policy
	eg.Go(func() error {
		var errs []error
		policies, errs = errgrouputil.Build(ctx, policyBuilders...)
		if len(errs) > 0 {
			for _, err := range errs {
				log.Ctx(ctx).Error().Msg(err.Error())
			}
			return fmt.Errorf("error building policies")
		}
		return nil
	})

	err := eg.Wait()
	if err != nil {
		return err
	}

	src.addPolicies(ctx, cfg, policies)
	return nil
}
// applySettingsLocked merges the settings from each databroker config, in sorted
// key order, into cfg.Options. src.mu must be held.
func (src *ConfigSource) applySettingsLocked(ctx context.Context, cfg *config.Config) {
	ids := slices.Sorted(maps.Keys(src.dbConfigs))

	var certsIndex *cryptutil.CertificatesIndex
	if src.enableValidation {
		certsIndex = cryptutil.NewCertificatesIndex()
		for _, cert := range cfg.Options.GetX509Certificates() {
			certsIndex.Add(cert)
		}
	}

	for i := 0; i < len(ids) && ctx.Err() == nil; i++ {
		cfgpb := src.dbConfigs[ids[i]]
		cfg.Options.ApplySettings(ctx, certsIndex, cfgpb.Settings)
	}
}

// buildPolicyFromProto converts a route protobuf into a config.Policy, validating
// it when validation is enabled.
func (src *ConfigSource) buildPolicyFromProto(_ context.Context, routepb *configpb.Route) (*config.Policy, error) {
	policy, err := config.NewPolicyFromProto(routepb)
	if err != nil {
		return nil, fmt.Errorf("error building policy from protobuf: %w", err)
	}
	if !src.enableValidation {
		return policy, nil
	}

	err = policy.Validate()
	if err != nil {
		return nil, fmt.Errorf("error validating policy: %w", err)
	}

	return policy, nil
}
// addPolicies appends databroker-derived policies to cfg.Options.AdditionalPolicies,
// skipping any whose route ID already exists in the config.
func (src *ConfigSource) addPolicies(ctx context.Context, cfg *config.Config, policies []*config.Policy) {
	seen := make(map[uint64]struct{}, len(policies)+cfg.Options.NumPolicies())
	for policy := range cfg.Options.GetAllPolicies() {
		id, err := policy.RouteID()
		if err != nil {
			log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id")
			continue
		}
		seen[id] = struct{}{}
	}

	additionalPolicies := make([]config.Policy, 0, len(policies))
	for _, policy := range policies {
		if policy == nil {
			continue
		}
		id, err := policy.RouteID()
		if err != nil {
			log.Ctx(ctx).Err(err).Str("policy", policy.String()).Msg("databroker: error getting route id")
			continue
		}
		if _, ok := seen[id]; ok {
			log.Ctx(ctx).Debug().Str("policy", policy.String()).Msg("databroker: policy already exists")
			continue
		}
		additionalPolicies = append(additionalPolicies, *policy)
		seen[id] = struct{}{}
	}
	config.SortPolicies(additionalPolicies)

	// add the additional policies here since calling `Validate` will reset them.
	cfg.Options.AdditionalPolicies = append(cfg.Options.AdditionalPolicies, additionalPolicies...)
}
// runUpdater starts (or restarts) the databroker config syncer whenever the
// outbound gRPC connection options derived from cfg change.
func (src *ConfigSource) runUpdater(ctx context.Context, cfg *config.Config) {
	sharedKey, _ := cfg.Options.GetSharedKey()
	connectionOptions := &grpc.OutboundOptions{
		OutboundPort:   cfg.OutboundPort,
		InstallationID: cfg.Options.InstallationID,
		ServiceName:    cfg.Options.Services,
		SignedJWTKey:   sharedKey,
	}
	h, err := hashutil.Hash(connectionOptions)
	if err != nil {
		log.Fatal().Err(err).Send()
	}
	// nothing changed, so don't restart the updater
	if src.updaterHash == h {
		return
	}
	src.updaterHash = h

	if src.cancel != nil {
		src.cancel()
		src.cancel = nil
	}
	ctx, src.cancel = context.WithCancel(ctx)

	cc, err := src.outboundGRPCConnection.Get(ctx, connectionOptions)
	if err != nil {
		log.Ctx(ctx).Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker")
		return
	}

	client := databroker.NewDataBrokerServiceClient(cc)
	syncer := databroker.NewSyncer(ctx, "databroker", &syncerHandler{
		client: client,
		src:    src,
	}, databroker.WithTypeURL(grpcutil.GetTypeURL(new(configpb.Config))),
		databroker.WithFastForward())
	go func() {
		log.Ctx(ctx).Debug().
			Str("outbound_port", cfg.OutboundPort).
			Msg("config: starting databroker config source syncer")
		_ = grpc.WaitForReady(ctx, cc, time.Second*10)
		_ = syncer.Run(ctx)
	}()
}
// syncerHandler implements the handler expected by databroker.NewSyncer, keeping
// src.dbConfigs in sync with Config records stored in the data broker.
type syncerHandler struct {
	src    *ConfigSource
	client databroker.DataBrokerServiceClient
}

func (s *syncerHandler) GetDataBrokerServiceClient() databroker.DataBrokerServiceClient {
	return s.client
}

func (s *syncerHandler) ClearRecords(_ context.Context) {
	s.src.mu.Lock()
	s.src.dbConfigs = map[string]dbConfig{}
	s.src.mu.Unlock()
}

func (s *syncerHandler) UpdateRecords(ctx context.Context, _ uint64, records []*databroker.Record) {
	if len(records) == 0 {
		return
	}

	s.src.mu.Lock()
	for _, record := range records {
		if record.GetDeletedAt() != nil {
			delete(s.src.dbConfigs, record.GetId())
			continue
		}

		var cfgpb configpb.Config
		err := record.GetData().UnmarshalTo(&cfgpb)
		if err != nil {
			log.Ctx(ctx).Error().Err(err).Msg("databroker: error decoding config")
			delete(s.src.dbConfigs, record.GetId())
			continue
		}

		s.src.dbConfigs[record.GetId()] = dbConfig{&cfgpb, record.Version}
	}
	s.src.mu.Unlock()

	s.src.rebuild(ctx, firstTime(false))
}