mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-01 16:01:26 +02:00
address code review
This commit is contained in:
parent
be9c14fc04
commit
0429a6ceb2
6 changed files with 16 additions and 151 deletions
|
@ -40,5 +40,4 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
|||
decodePPLPolicyHookFunc(),
|
||||
decodeSANMatcherHookFunc(),
|
||||
decodeStringToMapHookFunc(),
|
||||
decodeRuntimeFlagsHookFunc(),
|
||||
))
|
||||
|
|
|
@ -549,46 +549,6 @@ var (
|
|||
ErrRuntimeFlagInvalidMapValue = fmt.Errorf("decoding runtime flags: unknown flag value (expected bool)")
|
||||
)
|
||||
|
||||
// only for RuntimeFlags decoding
|
||||
func decodeRuntimeFlagsHookFunc() mapstructure.DecodeHookFunc {
|
||||
return mapstructure.DecodeHookFuncType(func(from, to reflect.Type, data any) (any, error) {
|
||||
if to != reflect.TypeOf(RuntimeFlags{}) {
|
||||
return data, nil
|
||||
}
|
||||
if from.Kind() != reflect.Map {
|
||||
return nil, fmt.Errorf("%w: got %s", ErrRuntimeFlagsInvalidValue, from.Kind())
|
||||
}
|
||||
|
||||
src, ok := data.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("decoding runtime flags: expected map[string]interface{}, got %T", from.String())
|
||||
}
|
||||
|
||||
dst := make(RuntimeFlags)
|
||||
// copy default flags first
|
||||
// this is to ensure we don't lose any default flags and only override the ones that are provided
|
||||
for k, v := range defaultRuntimeFlags {
|
||||
dst[k] = v
|
||||
}
|
||||
|
||||
for k, v := range src {
|
||||
key := RuntimeFlag(k)
|
||||
if _, ok := dst[key]; !ok {
|
||||
return nil, fmt.Errorf("%s: %w", k, ErrRuntimeFlagUnknown)
|
||||
}
|
||||
|
||||
b, ok := v.(bool)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("%w: got %T", ErrRuntimeFlagInvalidMapValue, v)
|
||||
}
|
||||
|
||||
dst[key] = b
|
||||
}
|
||||
|
||||
return dst, nil
|
||||
})
|
||||
}
|
||||
|
||||
// serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON
|
||||
func serializable(in interface{}) (interface{}, error) {
|
||||
switch typed := in.(type) {
|
||||
|
|
|
@ -3,7 +3,6 @@ package config
|
|||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
@ -184,102 +183,3 @@ func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
|||
},
|
||||
}, withPolicy.Policy)
|
||||
}
|
||||
|
||||
func TestDecodeRuntimeFlagsHookFunc(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defaults := DefaultRuntimeFlags()
|
||||
withMap := struct {
|
||||
SomethingElse map[string]bool `mapstructure:"something_else"`
|
||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
||||
}{
|
||||
RuntimeFlags: defaults,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
||||
Result: &withMap,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
expect := DefaultRuntimeFlags()
|
||||
expect[GRPCDatabrokerKeepalive] = !expect[GRPCDatabrokerKeepalive]
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{
|
||||
"something_else": map[string]bool{
|
||||
"hello": true,
|
||||
},
|
||||
"runtime_flags": map[string]interface{}{
|
||||
string(GRPCDatabrokerKeepalive): expect[GRPCDatabrokerKeepalive],
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, expect, withMap.RuntimeFlags)
|
||||
})
|
||||
|
||||
t.Run("dont override if unset", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
defaults := DefaultRuntimeFlags()
|
||||
withMap := struct {
|
||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
||||
}{
|
||||
RuntimeFlags: defaults,
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
||||
Result: &withMap,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, defaults, withMap.RuntimeFlags)
|
||||
})
|
||||
|
||||
// mapstructure does not correctly wrap errors, so we will have to just search for the error text
|
||||
// https://github.com/mitchellh/mapstructure/blob/ab69d8d93410fce4361f4912bb1ff88110a81311/error.go#L35-L38
|
||||
assertErrorIs := func(t *testing.T, err error, target error) {
|
||||
t.Helper()
|
||||
if assert.Error(t, err) {
|
||||
strings.Contains(err.Error(), target.Error())
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("invalid input", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
var withMap struct {
|
||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
||||
}
|
||||
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
||||
Result: &withMap,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{
|
||||
"runtime_flags": "hello world",
|
||||
})
|
||||
assertErrorIs(t, err, ErrRuntimeFlagsInvalidValue)
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{
|
||||
"runtime_flags": map[string]interface{}{
|
||||
"no_such_flag": true,
|
||||
},
|
||||
})
|
||||
assertErrorIs(t, err, ErrRuntimeFlagUnknown)
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{
|
||||
"runtime_flags": map[string]interface{}{
|
||||
string(GRPCDatabrokerKeepalive): "hello world",
|
||||
},
|
||||
})
|
||||
assertErrorIs(t, err, ErrRuntimeFlagInvalidMapValue)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -79,7 +79,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
||||
}
|
||||
|
||||
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.GRPCDatabrokerKeepalive))
|
||||
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagGRPCDatabrokerKeepalive))
|
||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2, databrokerKeepalive)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -389,7 +389,7 @@ func optionsFromViper(configFile string) (*Options, error) {
|
|||
if err := v.Unmarshal(o, ViperPolicyHooks, func(c *mapstructure.DecoderConfig) { c.Metadata = &metadata }); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
if err := checkConfigKeysErrors(configFile, metadata.Unused); err != nil {
|
||||
if err := checkConfigKeysErrors(configFile, o, metadata.Unused); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
@ -402,7 +402,7 @@ func optionsFromViper(configFile string) (*Options, error) {
|
|||
return o, nil
|
||||
}
|
||||
|
||||
func checkConfigKeysErrors(configFile string, unused []string) error {
|
||||
func checkConfigKeysErrors(configFile string, o *Options, unused []string) error {
|
||||
checks := CheckUnknownConfigFields(unused)
|
||||
ctx := context.Background()
|
||||
errInvalidConfigKeys := errors.New("some configuration options are no longer supported, please check logs for details")
|
||||
|
@ -423,6 +423,14 @@ func checkConfigKeysErrors(configFile string, unused []string) error {
|
|||
}
|
||||
evt.Msg(string(check.FieldCheckMsg))
|
||||
}
|
||||
|
||||
// check for unknown runtime flags
|
||||
for flag := range o.RuntimeFlags {
|
||||
if _, ok := defaultRuntimeFlags[flag]; !ok {
|
||||
log.Warn(ctx).Str("config_file", configFile).Str("flag", string(flag)).Msg("unknown runtime flag")
|
||||
}
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
package config
|
||||
|
||||
// GRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
||||
var GRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
||||
import "golang.org/x/exp/maps"
|
||||
|
||||
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
||||
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
||||
|
||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||
type RuntimeFlag string
|
||||
|
@ -18,9 +20,5 @@ func runtimeFlag(txt string, def bool) RuntimeFlag {
|
|||
var defaultRuntimeFlags = map[RuntimeFlag]bool{}
|
||||
|
||||
func DefaultRuntimeFlags() RuntimeFlags {
|
||||
out := make(RuntimeFlags, len(defaultRuntimeFlags))
|
||||
for k, v := range defaultRuntimeFlags {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
return maps.Clone(defaultRuntimeFlags)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue