diff --git a/config/constants.go b/config/constants.go index fb2793c52..17196cb3b 100644 --- a/config/constants.go +++ b/config/constants.go @@ -40,5 +40,4 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( decodePPLPolicyHookFunc(), decodeSANMatcherHookFunc(), decodeStringToMapHookFunc(), - decodeRuntimeFlagsHookFunc(), )) diff --git a/config/custom.go b/config/custom.go index 4cefc6038..15124e34a 100644 --- a/config/custom.go +++ b/config/custom.go @@ -549,46 +549,6 @@ var ( ErrRuntimeFlagInvalidMapValue = fmt.Errorf("decoding runtime flags: unknown flag value (expected bool)") ) -// only for RuntimeFlags decoding -func decodeRuntimeFlagsHookFunc() mapstructure.DecodeHookFunc { - return mapstructure.DecodeHookFuncType(func(from, to reflect.Type, data any) (any, error) { - if to != reflect.TypeOf(RuntimeFlags{}) { - return data, nil - } - if from.Kind() != reflect.Map { - return nil, fmt.Errorf("%w: got %s", ErrRuntimeFlagsInvalidValue, from.Kind()) - } - - src, ok := data.(map[string]interface{}) - if !ok { - return nil, fmt.Errorf("decoding runtime flags: expected map[string]interface{}, got %T", from.String()) - } - - dst := make(RuntimeFlags) - // copy default flags first - // this is to ensure we don't lose any default flags and only override the ones that are provided - for k, v := range defaultRuntimeFlags { - dst[k] = v - } - - for k, v := range src { - key := RuntimeFlag(k) - if _, ok := dst[key]; !ok { - return nil, fmt.Errorf("%s: %w", k, ErrRuntimeFlagUnknown) - } - - b, ok := v.(bool) - if !ok { - return nil, fmt.Errorf("%w: got %T", ErrRuntimeFlagInvalidMapValue, v) - } - - dst[key] = b - } - - return dst, nil - }) -} - // serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON func serializable(in interface{}) (interface{}, error) { switch typed := in.(type) { diff --git a/config/custom_test.go b/config/custom_test.go index 74d0d8a05..e556c0d65 100644 --- a/config/custom_test.go +++ b/config/custom_test.go @@ -3,7 +3,6 @@ package config import ( "encoding/base64" "encoding/json" - "strings" "testing" "github.com/mitchellh/mapstructure" @@ -184,102 +183,3 @@ func TestDecodePPLPolicyHookFunc(t *testing.T) { }, }, withPolicy.Policy) } - -func TestDecodeRuntimeFlagsHookFunc(t *testing.T) { - t.Parallel() - - t.Run("valid", func(t *testing.T) { - t.Parallel() - - defaults := DefaultRuntimeFlags() - withMap := struct { - SomethingElse map[string]bool `mapstructure:"something_else"` - RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"` - }{ - RuntimeFlags: defaults, - } - - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - DecodeHook: decodeRuntimeFlagsHookFunc(), - Result: &withMap, - }) - require.NoError(t, err) - - expect := DefaultRuntimeFlags() - expect[GRPCDatabrokerKeepalive] = !expect[GRPCDatabrokerKeepalive] - - err = decoder.Decode(map[string]interface{}{ - "something_else": map[string]bool{ - "hello": true, - }, - "runtime_flags": map[string]interface{}{ - string(GRPCDatabrokerKeepalive): expect[GRPCDatabrokerKeepalive], - }, - }) - assert.NoError(t, err) - assert.Equal(t, expect, withMap.RuntimeFlags) - }) - - t.Run("dont override if unset", func(t *testing.T) { - t.Parallel() - - defaults := DefaultRuntimeFlags() - withMap := struct { - RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"` - }{ - RuntimeFlags: defaults, - } - - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - DecodeHook: decodeRuntimeFlagsHookFunc(), - Result: &withMap, - }) - require.NoError(t, err) - - err = decoder.Decode(map[string]interface{}{}) - assert.NoError(t, err) - assert.Equal(t, defaults, withMap.RuntimeFlags) - }) - - // mapstructure does not correctly wrap errors, so we will have to just search for the error text - // https://github.com/mitchellh/mapstructure/blob/ab69d8d93410fce4361f4912bb1ff88110a81311/error.go#L35-L38 - assertErrorIs := func(t *testing.T, err error, target error) { - t.Helper() - if assert.Error(t, err) { - strings.Contains(err.Error(), target.Error()) - } - } - - t.Run("invalid input", func(t *testing.T) { - t.Parallel() - - var withMap struct { - RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"` - } - - decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ - DecodeHook: decodeRuntimeFlagsHookFunc(), - Result: &withMap, - }) - require.NoError(t, err) - - err = decoder.Decode(map[string]interface{}{ - "runtime_flags": "hello world", - }) - assertErrorIs(t, err, ErrRuntimeFlagsInvalidValue) - - err = decoder.Decode(map[string]interface{}{ - "runtime_flags": map[string]interface{}{ - "no_such_flag": true, - }, - }) - assertErrorIs(t, err, ErrRuntimeFlagUnknown) - - err = decoder.Decode(map[string]interface{}{ - "runtime_flags": map[string]interface{}{ - string(GRPCDatabrokerKeepalive): "hello world", - }, - }) - assertErrorIs(t, err, ErrRuntimeFlagInvalidMapValue) - }) -} diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index 799760985..d6b0fb02d 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -79,7 +79,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env authorizeCluster.OutlierDetection = grpcOutlierDetection() } - databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.GRPCDatabrokerKeepalive)) + databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagGRPCDatabrokerKeepalive)) databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2, databrokerKeepalive) if err != nil { return nil, err diff --git a/config/options.go b/config/options.go index c2afaf1dd..beddece94 100644 --- a/config/options.go +++ b/config/options.go @@ -389,7 +389,7 @@ func optionsFromViper(configFile string) (*Options, error) { if err := v.Unmarshal(o, ViperPolicyHooks, func(c *mapstructure.DecoderConfig) { c.Metadata = &metadata }); err != nil { return nil, fmt.Errorf("failed to unmarshal config: %w", err) } - if err := checkConfigKeysErrors(configFile, metadata.Unused); err != nil { + if err := checkConfigKeysErrors(configFile, o, metadata.Unused); err != nil { return nil, err } @@ -402,7 +402,7 @@ func optionsFromViper(configFile string) (*Options, error) { return o, nil } -func checkConfigKeysErrors(configFile string, unused []string) error { +func checkConfigKeysErrors(configFile string, o *Options, unused []string) error { checks := CheckUnknownConfigFields(unused) ctx := context.Background() errInvalidConfigKeys := errors.New("some configuration options are no longer supported, please check logs for details") @@ -423,6 +423,14 @@ func checkConfigKeysErrors(configFile string, unused []string) error { } evt.Msg(string(check.FieldCheckMsg)) } + + // check for unknown runtime flags + for flag := range o.RuntimeFlags { + if _, ok := defaultRuntimeFlags[flag]; !ok { + log.Warn(ctx).Str("config_file", configFile).Str("flag", string(flag)).Msg("unknown runtime flag") + } + } + return err } diff --git a/config/runtime_flags.go b/config/runtime_flags.go index 9b1b8f302..fef2ee66b 100644 --- a/config/runtime_flags.go +++ b/config/runtime_flags.go @@ -1,7 +1,9 @@ package config -// GRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service -var GRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false) +import "golang.org/x/exp/maps" + +// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service +var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false) // RuntimeFlag is a runtime flag that can flip on/off certain features type RuntimeFlag string @@ -18,9 +20,5 @@ func runtimeFlag(txt string, def bool) RuntimeFlag { var defaultRuntimeFlags = map[RuntimeFlag]bool{} func DefaultRuntimeFlags() RuntimeFlags { - out := make(RuntimeFlags, len(defaultRuntimeFlags)) - for k, v := range defaultRuntimeFlags { - out[k] = v - } - return out + return maps.Clone(defaultRuntimeFlags) }