mirror of
https://github.com/pomerium/pomerium.git
synced 2025-08-06 10:21:05 +02:00
address code review
This commit is contained in:
parent
be9c14fc04
commit
0429a6ceb2
6 changed files with 16 additions and 151 deletions
|
@ -40,5 +40,4 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||||
decodePPLPolicyHookFunc(),
|
decodePPLPolicyHookFunc(),
|
||||||
decodeSANMatcherHookFunc(),
|
decodeSANMatcherHookFunc(),
|
||||||
decodeStringToMapHookFunc(),
|
decodeStringToMapHookFunc(),
|
||||||
decodeRuntimeFlagsHookFunc(),
|
|
||||||
))
|
))
|
||||||
|
|
|
@ -549,46 +549,6 @@ var (
|
||||||
ErrRuntimeFlagInvalidMapValue = fmt.Errorf("decoding runtime flags: unknown flag value (expected bool)")
|
ErrRuntimeFlagInvalidMapValue = fmt.Errorf("decoding runtime flags: unknown flag value (expected bool)")
|
||||||
)
|
)
|
||||||
|
|
||||||
// only for RuntimeFlags decoding
|
|
||||||
func decodeRuntimeFlagsHookFunc() mapstructure.DecodeHookFunc {
|
|
||||||
return mapstructure.DecodeHookFuncType(func(from, to reflect.Type, data any) (any, error) {
|
|
||||||
if to != reflect.TypeOf(RuntimeFlags{}) {
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
if from.Kind() != reflect.Map {
|
|
||||||
return nil, fmt.Errorf("%w: got %s", ErrRuntimeFlagsInvalidValue, from.Kind())
|
|
||||||
}
|
|
||||||
|
|
||||||
src, ok := data.(map[string]interface{})
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("decoding runtime flags: expected map[string]interface{}, got %T", from.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
dst := make(RuntimeFlags)
|
|
||||||
// copy default flags first
|
|
||||||
// this is to ensure we don't lose any default flags and only override the ones that are provided
|
|
||||||
for k, v := range defaultRuntimeFlags {
|
|
||||||
dst[k] = v
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range src {
|
|
||||||
key := RuntimeFlag(k)
|
|
||||||
if _, ok := dst[key]; !ok {
|
|
||||||
return nil, fmt.Errorf("%s: %w", k, ErrRuntimeFlagUnknown)
|
|
||||||
}
|
|
||||||
|
|
||||||
b, ok := v.(bool)
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("%w: got %T", ErrRuntimeFlagInvalidMapValue, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[key] = b
|
|
||||||
}
|
|
||||||
|
|
||||||
return dst, nil
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON
|
// serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON
|
||||||
func serializable(in interface{}) (interface{}, error) {
|
func serializable(in interface{}) (interface{}, error) {
|
||||||
switch typed := in.(type) {
|
switch typed := in.(type) {
|
||||||
|
|
|
@ -3,7 +3,6 @@ package config
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/mitchellh/mapstructure"
|
"github.com/mitchellh/mapstructure"
|
||||||
|
@ -184,102 +183,3 @@ func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
||||||
},
|
},
|
||||||
}, withPolicy.Policy)
|
}, withPolicy.Policy)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDecodeRuntimeFlagsHookFunc(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
t.Run("valid", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
defaults := DefaultRuntimeFlags()
|
|
||||||
withMap := struct {
|
|
||||||
SomethingElse map[string]bool `mapstructure:"something_else"`
|
|
||||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
|
||||||
}{
|
|
||||||
RuntimeFlags: defaults,
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
|
||||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
|
||||||
Result: &withMap,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
expect := DefaultRuntimeFlags()
|
|
||||||
expect[GRPCDatabrokerKeepalive] = !expect[GRPCDatabrokerKeepalive]
|
|
||||||
|
|
||||||
err = decoder.Decode(map[string]interface{}{
|
|
||||||
"something_else": map[string]bool{
|
|
||||||
"hello": true,
|
|
||||||
},
|
|
||||||
"runtime_flags": map[string]interface{}{
|
|
||||||
string(GRPCDatabrokerKeepalive): expect[GRPCDatabrokerKeepalive],
|
|
||||||
},
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, expect, withMap.RuntimeFlags)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("dont override if unset", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
defaults := DefaultRuntimeFlags()
|
|
||||||
withMap := struct {
|
|
||||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
|
||||||
}{
|
|
||||||
RuntimeFlags: defaults,
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
|
||||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
|
||||||
Result: &withMap,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = decoder.Decode(map[string]interface{}{})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, defaults, withMap.RuntimeFlags)
|
|
||||||
})
|
|
||||||
|
|
||||||
// mapstructure does not correctly wrap errors, so we will have to just search for the error text
|
|
||||||
// https://github.com/mitchellh/mapstructure/blob/ab69d8d93410fce4361f4912bb1ff88110a81311/error.go#L35-L38
|
|
||||||
assertErrorIs := func(t *testing.T, err error, target error) {
|
|
||||||
t.Helper()
|
|
||||||
if assert.Error(t, err) {
|
|
||||||
strings.Contains(err.Error(), target.Error())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Run("invalid input", func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
var withMap struct {
|
|
||||||
RuntimeFlags RuntimeFlags `mapstructure:"runtime_flags"`
|
|
||||||
}
|
|
||||||
|
|
||||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
|
||||||
DecodeHook: decodeRuntimeFlagsHookFunc(),
|
|
||||||
Result: &withMap,
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = decoder.Decode(map[string]interface{}{
|
|
||||||
"runtime_flags": "hello world",
|
|
||||||
})
|
|
||||||
assertErrorIs(t, err, ErrRuntimeFlagsInvalidValue)
|
|
||||||
|
|
||||||
err = decoder.Decode(map[string]interface{}{
|
|
||||||
"runtime_flags": map[string]interface{}{
|
|
||||||
"no_such_flag": true,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
assertErrorIs(t, err, ErrRuntimeFlagUnknown)
|
|
||||||
|
|
||||||
err = decoder.Decode(map[string]interface{}{
|
|
||||||
"runtime_flags": map[string]interface{}{
|
|
||||||
string(GRPCDatabrokerKeepalive): "hello world",
|
|
||||||
},
|
|
||||||
})
|
|
||||||
assertErrorIs(t, err, ErrRuntimeFlagInvalidMapValue)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
|
@ -79,7 +79,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
||||||
}
|
}
|
||||||
|
|
||||||
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.GRPCDatabrokerKeepalive))
|
databrokerKeepalive := Keepalive(cfg.Options.IsRuntimeFlagSet(config.RuntimeFlagGRPCDatabrokerKeepalive))
|
||||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2, databrokerKeepalive)
|
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2, databrokerKeepalive)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -389,7 +389,7 @@ func optionsFromViper(configFile string) (*Options, error) {
|
||||||
if err := v.Unmarshal(o, ViperPolicyHooks, func(c *mapstructure.DecoderConfig) { c.Metadata = &metadata }); err != nil {
|
if err := v.Unmarshal(o, ViperPolicyHooks, func(c *mapstructure.DecoderConfig) { c.Metadata = &metadata }); err != nil {
|
||||||
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
return nil, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||||
}
|
}
|
||||||
if err := checkConfigKeysErrors(configFile, metadata.Unused); err != nil {
|
if err := checkConfigKeysErrors(configFile, o, metadata.Unused); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -402,7 +402,7 @@ func optionsFromViper(configFile string) (*Options, error) {
|
||||||
return o, nil
|
return o, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func checkConfigKeysErrors(configFile string, unused []string) error {
|
func checkConfigKeysErrors(configFile string, o *Options, unused []string) error {
|
||||||
checks := CheckUnknownConfigFields(unused)
|
checks := CheckUnknownConfigFields(unused)
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
errInvalidConfigKeys := errors.New("some configuration options are no longer supported, please check logs for details")
|
errInvalidConfigKeys := errors.New("some configuration options are no longer supported, please check logs for details")
|
||||||
|
@ -423,6 +423,14 @@ func checkConfigKeysErrors(configFile string, unused []string) error {
|
||||||
}
|
}
|
||||||
evt.Msg(string(check.FieldCheckMsg))
|
evt.Msg(string(check.FieldCheckMsg))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// check for unknown runtime flags
|
||||||
|
for flag := range o.RuntimeFlags {
|
||||||
|
if _, ok := defaultRuntimeFlags[flag]; !ok {
|
||||||
|
log.Warn(ctx).Str("config_file", configFile).Str("flag", string(flag)).Msg("unknown runtime flag")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
// GRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
import "golang.org/x/exp/maps"
|
||||||
var GRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
|
||||||
|
// RuntimeFlagGRPCDatabrokerKeepalive enables gRPC keepalive to the databroker service
|
||||||
|
var RuntimeFlagGRPCDatabrokerKeepalive = runtimeFlag("grpc_databroker_keepalive", false)
|
||||||
|
|
||||||
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
// RuntimeFlag is a runtime flag that can flip on/off certain features
|
||||||
type RuntimeFlag string
|
type RuntimeFlag string
|
||||||
|
@ -18,9 +20,5 @@ func runtimeFlag(txt string, def bool) RuntimeFlag {
|
||||||
var defaultRuntimeFlags = map[RuntimeFlag]bool{}
|
var defaultRuntimeFlags = map[RuntimeFlag]bool{}
|
||||||
|
|
||||||
func DefaultRuntimeFlags() RuntimeFlags {
|
func DefaultRuntimeFlags() RuntimeFlags {
|
||||||
out := make(RuntimeFlags, len(defaultRuntimeFlags))
|
return maps.Clone(defaultRuntimeFlags)
|
||||||
for k, v := range defaultRuntimeFlags {
|
|
||||||
out[k] = v
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue