address code review

This commit is contained in:
Denis Mishin 2024-04-02 17:12:19 -04:00
parent be9c14fc04
commit 0429a6ceb2
6 changed files with 16 additions and 151 deletions

View file

@ -40,5 +40,4 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
decodePPLPolicyHookFunc(),
decodeSANMatcherHookFunc(),
decodeStringToMapHookFunc(),
decodeRuntimeFlagsHookFunc(),
))

View file

@ -549,46 +549,6 @@ var (
ErrRuntimeFlagInvalidMapValue = fmt.Errorf("decoding runtime flags: unknown flag value (expected bool)")
)
// only for RuntimeFlags decoding
func decodeRuntimeFlagsHookFunc() mapstructure.DecodeHookFunc {
return mapstructure.DecodeHookFuncType(func(from, to reflect.Type, data any) (any, error) {
if to != reflect.TypeOf(RuntimeFlags{}) {
return data, nil
}
if from.Kind() != reflect.Map {
return nil, fmt.Errorf("%w: got %s", ErrRuntimeFlagsInvalidValue, from.Kind())
}
src, ok := data.(map[string]interface{})
if !ok {
return nil, fmt.Errorf("decoding runtime flags: expected map[string]interface{}, got %T", from.String())
}
dst := make(RuntimeFlags)
// copy default flags first
// this is to ensure we don't lose any default flags and only override the ones that are provided
for k, v := range defaultRuntimeFlags {
dst[k] = v
}
for k, v := range src {
key := RuntimeFlag(k)
if _, ok := dst[key]; !ok {
return nil, fmt.Errorf("%s: %w", k, ErrRuntimeFlagUnknown)
}
b, ok := v.(bool)
if !ok {
return nil, fmt.Errorf("%w: got %T", ErrRuntimeFlagInvalidMapValue, v)
}
dst[key] = b
}
return dst, nil
})
}
// serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON
func serializable(in interface{}) (interface{}, error) {
switch typed := in.(type) {

View file

@ -3,7 +3,6 @@ package config
import (
"encoding/base64"
"encoding/json"
"strings"
"testing"
"github.com/mitchellh/mapstructure"
@ -184,102 +183,3 @@ func TestDecodePPLPolicyHookFunc(t *testing.T) {
},
}, withPolicy.Policy)
}
func TestDecodeRuntimeFlagsHookFunc(t *testing.T) {
t.Parallel()
t.Run("valid", func(t *testing.T) {
t.Parallel()
defaults := DefaultRuntimeFlags()
withMap := struct {
SomethingElse map[string]bool `mapstructure:"something_else"`
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
}{
RuntimeFlags: defaults,
}
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: decodeRuntimeFlagsHookFunc(),
Result: &withMap,
})
require.NoError(t, err)
expect := DefaultRuntimeFlags()
expect[GRPCDatabrokerKeepalive] = !expect[GRPCDatabrokerKeepalive]
err = decoder.Decode(map[string]interface{}{
"something_else": map[string]bool{
"hello": true,
},
"runtime_flags": map[string]interface{}{
string(GRPCDatabrokerKeepalive): expect[GRPCDatabrokerKeepalive],
},
})
assert.NoError(t, err)
assert.Equal(t, expect, withMap.RuntimeFlags)
})
t.Run("dont override if unset", func(t *testing.T) {
t.Parallel()
defaults := DefaultRuntimeFlags()
withMap := struct {
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
}{
RuntimeFlags: defaults,
}
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: decodeRuntimeFlagsHookFunc(),
Result: &withMap,
})
require.NoError(t, err)
err = decoder.Decode(map[string]interface{}{})
assert.NoError(t, err)
assert.Equal(t, defaults, withMap.RuntimeFlags)
})
// mapstructure does not correctly wrap errors, so we will have to just search for the error text
// https://github.com/mitchellh/mapstructure/blob/ab69d8d93410fce4361f4912bb1ff88110a81311/error.go#L35-L38
assertErrorIs := func(t *testing.T, err error, target error) {
t.Helper()
if assert.Error(t, err) {
strings.Contains(err.Error(), target.Error())
}
}
t.Run("invalid input", func(t *testing.T) {
t.Parallel()
var withMap struct {
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
}
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: decodeRuntimeFlagsHookFunc(),
Result: &withMap,
})
require.NoError(t, err)
err = decoder.Decode(map[string]interface{}{
"runtime_flags": "hello world",
})
assertErrorIs(t, err, ErrRuntimeFlagsInvalidValue)
err = decoder.Decode(map[string]interface{}{
"runtime_flags": map[string]interface{}{
"no_such_flag": true,
},
})
assertErrorIs(t, err, ErrRuntimeFlagUnknown)
err = decoder.Decode(map[string]interface{}{
"runtime_flags": map[string]interface{}{
string(GRPCDatabrokerKeepalive): "hello world",
},
})
assertErrorIs(t, err, ErrRuntimeFlagInvalidMapValue)
})
}

View file

@ -79,7 +79,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
authorizeCluster.OutlierDetection = grpcOutlierDetection()
}
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.GRPCDatabrokerKeepalive))
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagGRPCDatabrokerKeepalive))
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2, databrokerKeepalive)
if err != nil {
return nil, err

View file

@ -389,7 +389,7 @@ func optionsFromViper(configFile string) (*Options, error) {
if err := v.Unmarshal(o, ViperPolicyHooks, func(c *mapstructure.DecoderConfig) { c.Metadata = &metadata }); err != nil {
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
}
if err := checkConfigKeysErrors(configFile, metadata.Unused); err != nil {
if err := checkConfigKeysErrors(configFile, o, metadata.Unused); err != nil {
return nil, err
}
@ -402,7 +402,7 @@ func optionsFromViper(configFile string) (*Options, error) {
return o, nil
}
func checkConfigKeysErrors(configFile string, unused []string) error {
func checkConfigKeysErrors(configFile string, o *Options, unused []string) error {
checks := CheckUnknownConfigFields(unused)
ctx := context.Background()
errInvalidConfigKeys := errors.New("some configuration options are no longer supported, please check logs for details")
@ -423,6 +423,14 @@ func checkConfigKeysErrors(configFile string, unused []string) error {
}
evt.Msg(string(check.FieldCheckMsg))
}
// check for unknown runtime flags
for flag := range o.RuntimeFlags {
if _, ok := defaultRuntimeFlags[flag]; !ok {
log.Warn(ctx).Str("config_file", configFile).Str("flag", string(flag)).Msg("unknown runtime flag")
}
}
return err
}

View file

@ -1,7 +1,9 @@
package config
// GRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
var GRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
import "golang.org/x/exp/maps"
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
// RuntimeFlag is a runtime flag that can flip on/off certain features
type RuntimeFlag string
@ -18,9 +20,5 @@ func runtimeFlag(txt string, def bool) RuntimeFlag {
var defaultRuntimeFlags = map[RuntimeFlag]bool{}
func DefaultRuntimeFlags() RuntimeFlags {
out := make(RuntimeFlags, len(defaultRuntimeFlags))
for k, v := range defaultRuntimeFlags {
out[k] = v
}
return out
return maps.Clone(defaultRuntimeFlags)
}