diff --git a/config/constants.go b/config/constants.go index 93fdd62db..362ab380e 100644 --- a/config/constants.go +++ b/config/constants.go @@ -38,4 +38,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( decodeJWTClaimHeadersHookFunc(), decodeCodecTypeHookFunc(), decodePPLPolicyHookFunc(), + decodeSANMatcherHookFunc(), )) diff --git a/config/custom.go b/config/custom.go index 812a8f423..1cd8a6902 100644 --- a/config/custom.go +++ b/config/custom.go @@ -508,6 +508,26 @@ func parseJSONPB(src map[string]interface{}, dst proto.Message, opts protojson.U return opts.Unmarshal(data, dst) } +// decodeSANMatcherHookFunc returns a decode hook for the SANMatcher type. +func decodeSANMatcherHookFunc() mapstructure.DecodeHookFunc { + return func(f, t reflect.Type, data interface{}) (interface{}, error) { + if t != reflect.TypeOf(SANMatcher{}) { + return data, nil + } + + b, err := json.Marshal(data) + if err != nil { + return nil, err + } + + var m SANMatcher + if err := json.Unmarshal(b, &m); err != nil { + return nil, err + } + return m, nil + } +} + // serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON func serializable(in interface{}) (interface{}, error) { switch typed := in.(type) { diff --git a/config/options_test.go b/config/options_test.go index 3e88e3e85..aa5ee9357 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -341,6 +341,27 @@ func Test_parsePolicyFile(t *testing.T) { } } +func Test_decodeSANMatcher(t *testing.T) { + // Verify that config file parsing will decode the SANMatcher type. + const yaml = ` +downstream_mtls: + match_subject_alt_names: + - dns: 'example-1\..*' + - dns: '.*\.example-2' +` + cfg := filepath.Join(t.TempDir(), "config.yaml") + err := os.WriteFile(cfg, []byte(yaml), 0644) + require.NoError(t, err) + + o, err := optionsFromViper(cfg) + require.NoError(t, err) + + assert.Equal(t, []SANMatcher{ + {Type: SANTypeDNS, Pattern: `example-1\..*`}, + {Type: SANTypeDNS, Pattern: `.*\.example-2`}, + }, o.DownstreamMTLS.MatchSubjectAltNames) +} + func Test_Checksum(t *testing.T) { o := NewDefaultOptions()