pomerium/internal/databroker/config_source.go

251 lines
6 KiB
Go

package databroker
import (
"context"
"errors"
"sync"
"time"
"github.com/cenkalti/backoff/v4"
"github.com/golang/protobuf/ptypes"
"github.com/mitchellh/hashstructure"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/grpc"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/grpc/databroker"
)
// configTypeURL is the protobuf Any type URL of configpb.Config. It is
// computed once at package initialisation and used as the record type when
// syncing from the databroker.
var configTypeURL string

func init() {
	// The marshal error is deliberately discarded; if packing failed the
	// nil-safe getter below would simply yield an empty type URL.
	packed, _ := ptypes.MarshalAny(&configpb.Config{})
	configTypeURL = packed.GetTypeUrl()
}
// ConfigSource provides a new Config source that decorates an underlying config with
// configuration derived from the data broker.
//
// The effective config returned by GetConfig is recomputed by rebuild whenever
// the underlying config changes or new records arrive from the databroker.
type ConfigSource struct {
// mu guards the mutable fields below.
mu sync.RWMutex
// computedConfig is the result of the last successful rebuild; it is what
// GetConfig returns.
computedConfig *config.Config
// underlyingConfig is the most recent config obtained from the underlying
// source; rebuild clones it before applying databroker settings.
underlyingConfig *config.Config
// dbConfigs holds the configs received from the databroker, keyed by
// record ID.
dbConfigs map[string]*configpb.Config
// updaterHash is a hash of the gRPC connection options, used to decide
// whether the updater needs to be restarted.
updaterHash uint64
// cancel stops the currently running updater goroutine, if any.
cancel func()
// serverVersion and recordVersion are the versions last seen on the sync
// stream; they are sent with the next Sync request.
serverVersion string
recordVersion string
// ChangeDispatcher is embedded; OnConfigChange and Trigger are called on
// the ConfigSource in this file and are presumably promoted from it.
config.ChangeDispatcher
}
// NewConfigSource creates a new ConfigSource.
//
// The returned source wraps underlying: every listener passed in is
// registered on the new source, the new source subscribes to changes of
// underlying, and an initial config is computed before returning. Listeners
// are not triggered for that initial computation.
func NewConfigSource(underlying config.Source, listeners ...config.ChangeListener) *ConfigSource {
	src := &ConfigSource{
		dbConfigs: map[string]*configpb.Config{},
	}
	for _, li := range listeners {
		src.OnConfigChange(li)
	}
	underlying.OnConfigChange(func(cfg *config.Config) {
		src.mu.Lock()
		src.underlyingConfig = cfg.Clone()
		src.mu.Unlock()

		src.rebuild(false)
	})

	// The listener registered above writes underlyingConfig under src.mu and
	// may fire from another goroutine at any time, so the initial snapshot is
	// stored under the same lock. GetConfig is called before taking the lock
	// so that no lock of the underlying source is acquired while src.mu is
	// held.
	initial := underlying.GetConfig()
	src.mu.Lock()
	src.underlyingConfig = initial
	src.mu.Unlock()

	src.rebuild(true)

	return src
}
// GetConfig gets the current config, i.e. the result of the most recent
// successful rebuild. The value is read under the source's read lock.
func (src *ConfigSource) GetConfig() *config.Config {
	src.mu.RLock()
	current := src.computedConfig
	src.mu.RUnlock()
	return current
}
// rebuild recomputes the effective config from a clone of the underlying
// config plus every config received from the databroker, and stores it as the
// computed config.
//
// When firstTime is false the change listeners are triggered with the new
// config. If applying a databroker config yields invalid options, the rebuild
// is abandoned: the previously computed config stays in place and no listener
// is triggered.
//
// src.mu is held for the whole call, including while listeners are triggered.
// NOTE(review): if Trigger runs listeners synchronously, a listener calling
// GetConfig on this source would block on src.mu — confirm.
func (src *ConfigSource) rebuild(firstTime bool) {
	_, span := trace.StartSpan(context.Background(), "databroker.config_source.rebuild")
	defer span.End()

	src.mu.Lock()
	defer src.mu.Unlock()

	cfg := src.underlyingConfig.Clone()

	// start the updater
	src.runUpdater(cfg)

	// Route IDs already present in the underlying config; used to drop
	// duplicate routes coming from the databroker.
	seen := make(map[uint64]struct{}, len(cfg.Options.Policies))
	for _, policy := range cfg.Options.Policies {
		seen[policy.RouteID()] = struct{}{}
	}

	var additionalPolicies []config.Policy

	// add all the config policies to the list
	// NOTE(review): map iteration order is unspecified, so when several
	// databroker configs carry conflicting settings or duplicate routes,
	// which one wins can differ between rebuilds.
	for _, cfgpb := range src.dbConfigs {
		cfg.Options.ApplySettings(cfgpb.Settings)
		if err := cfg.Options.Validate(); err != nil {
			log.Warn().Err(err).Msg("databroker: invalid config detected, ignoring")
			return
		}

		for _, routepb := range cfgpb.GetRoutes() {
			policy, err := config.NewPolicyFromProto(routepb)
			if err != nil {
				log.Warn().Err(err).Msg("databroker: error converting protobuf into policy")
				continue
			}

			if err := policy.Validate(); err != nil {
				log.Warn().Err(err).
					Str("policy", policy.String()).
					Msg("databroker: invalid policy, ignoring")
				continue
			}

			routeID := policy.RouteID()
			if _, ok := seen[routeID]; ok {
				// There is no error to attach here: the policy is valid, it
				// merely collides with a route that is already present.
				log.Warn().
					Str("policy", policy.String()).
					Msg("databroker: duplicate policy detected, ignoring")
				continue
			}
			seen[routeID] = struct{}{}

			additionalPolicies = append(additionalPolicies, *policy)
		}
	}

	// add the additional policies here since calling `Validate` will reset them.
	cfg.Options.Policies = append(cfg.Options.Policies, additionalPolicies...)

	src.computedConfig = cfg
	if !firstTime {
		src.Trigger(cfg)
	}
}
// runUpdater (re)starts the background goroutine that syncs config records
// from the databroker described by cfg.
//
// The gRPC connection options derived from cfg are hashed; if the hash equals
// the one recorded for the current updater, nothing happens. Otherwise the
// current updater (if any) is cancelled, a new connection is created and a new
// sync loop is started.
//
// It reads and writes src.updaterHash and src.cancel without locking, so the
// caller must hold src.mu (rebuild does).
func (src *ConfigSource) runUpdater(cfg *config.Config) {
	connectionOptions := &grpc.Options{
		Addr:                    cfg.Options.DataBrokerURL,
		OverrideCertificateName: cfg.Options.OverrideCertificateName,
		CA:                      cfg.Options.CA,
		CAFile:                  cfg.Options.CAFile,
		RequestTimeout:          cfg.Options.GRPCClientTimeout,
		ClientDNSRoundRobin:     cfg.Options.GRPCClientDNSRoundRobin,
		WithInsecure:            cfg.Options.GRPCInsecure,
		ServiceName:             cfg.Options.Services,
	}
	h, err := hashstructure.Hash(connectionOptions, nil)
	if err != nil {
		log.Fatal().Err(err).Send()
	}
	// nothing changed, so don't restart the updater
	if src.updaterHash == h {
		return
	}

	// stop the updater that was using the previous connection options
	if src.cancel != nil {
		src.cancel()
		src.cancel = nil
	}

	cc, err := grpc.NewGRPCClientConn(connectionOptions)
	if err != nil {
		// No updater is running at this point. Forget the recorded hash so
		// that the next rebuild tries again instead of concluding that
		// "nothing changed" and leaving the source without an updater.
		src.updaterHash = 0
		log.Error().Err(err).Msg("databroker: failed to create gRPC connection to data broker")
		return
	}
	// Only remember the hash once an updater for these options will exist.
	src.updaterHash = h

	// NOTE(review): cc is not closed when the updater is cancelled — confirm
	// whether the connection should be released here.
	client := databroker.NewDataBrokerServiceClient(cc)

	ctx, cancel := context.WithCancel(context.Background())
	src.cancel = cancel

	go tryForever(ctx, func(onSuccess func()) error {
		// resume from the versions seen so far
		src.mu.Lock()
		serverVersion, recordVersion := src.serverVersion, src.recordVersion
		src.mu.Unlock()

		stream, err := client.Sync(ctx, &databroker.SyncRequest{
			Type:          configTypeURL,
			ServerVersion: serverVersion,
			RecordVersion: recordVersion,
		})
		if err != nil {
			return err
		}

		for {
			res, err := stream.Recv()
			if err != nil {
				return err
			}
			// a message was received, so reset the retry backoff
			onSuccess()

			if len(res.GetRecords()) > 0 {
				src.onSync(res.GetRecords())
				// keep the version of the last record in the batch
				for _, record := range res.GetRecords() {
					recordVersion = record.GetVersion()
				}
			}

			src.mu.Lock()
			src.serverVersion, src.recordVersion = res.GetServerVersion(), recordVersion
			src.mu.Unlock()
		}
	})
}
// onSync folds a batch of databroker records into the stored databroker
// configs and then rebuilds the effective config, triggering listeners.
//
// A record marked as deleted, or whose payload cannot be decoded as a
// configpb.Config, removes the entry with that record ID; any other record
// replaces it.
func (src *ConfigSource) onSync(records []*databroker.Record) {
	src.mu.Lock()
	for i := range records {
		rec := records[i]
		id := rec.GetId()

		if rec.GetDeletedAt() != nil {
			delete(src.dbConfigs, id)
			continue
		}

		decoded := new(configpb.Config)
		if err := ptypes.UnmarshalAny(rec.GetData(), decoded); err != nil {
			log.Warn().Err(err).Msg("databroker: error decoding config")
			delete(src.dbConfigs, id)
			continue
		}

		src.dbConfigs[id] = decoded
	}
	src.mu.Unlock()

	// rebuild takes the lock itself, so it must be released first
	src.rebuild(false)
}
// tryForever calls callback over and over until ctx is done, waiting an
// exponentially growing interval between attempts.
//
// callback receives an onSuccess function that resets the backoff; it should
// call it whenever it makes progress. Errors returned by callback are logged
// and followed by another attempt. The loop ends only once ctx is done: a
// context.Canceled or context.DeadlineExceeded error returned while ctx is
// still live (for example from a timeout inside callback) is treated like any
// other error and retried.
func tryForever(ctx context.Context, callback func(onSuccess func()) error) {
	bo := backoff.NewExponentialBackOff()
	bo.MaxElapsedTime = 0 // zero disables the elapsed-time limit

	for {
		err := callback(bo.Reset)

		isCtxErr := errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
		if isCtxErr && ctx.Err() != nil {
			// our own context ended; this is a normal shutdown
			return
		}
		if err != nil {
			log.Warn().Err(err).Msg("sync error")
		}

		// Use an explicit timer so it can be released immediately when ctx is
		// done instead of lingering until the backoff interval elapses.
		timer := time.NewTimer(bo.NextBackOff())
		select {
		case <-ctx.Done():
			timer.Stop()
			return
		case <-timer.C:
		}
	}
}