diff --git a/config/options.go b/config/options.go index 48d23d76a..6e425e44b 100644 --- a/config/options.go +++ b/config/options.go @@ -371,7 +371,7 @@ func optionsFromViper(configFile string) (*Options, error) { o := NewDefaultOptions() v := o.viper // Load up config - err := bindEnvs(o, v) + err := bindEnvs(v) if err != nil { return nil, fmt.Errorf("failed to bind options to env vars: %w", err) } @@ -506,20 +506,11 @@ func (o *Options) parseHeaders(_ context.Context) error { return nil } -// bindEnvs binds a viper instance to each env var of an Options struct based -// on the mapstructure tag -func bindEnvs(o *Options, v *viper.Viper) error { - tagName := `mapstructure` - t := reflect.TypeOf(*o) - - for i := 0; i < t.NumField(); i++ { - field := t.Field(i) - envName := field.Tag.Get(tagName) - err := v.BindEnv(envName) - if err != nil { - return fmt.Errorf("failed to bind field '%s' to env var '%s': %w", field.Name, envName, err) - } - +// bindEnvs adds a Viper environment variable binding for each field in the +// Options struct (including nested structs), based on the mapstructure tag. +func bindEnvs(v *viper.Viper) error { + if _, err := bindEnvsRecursive(reflect.TypeOf(Options{}), v, "", ""); err != nil { + return err } // Statically bind fields @@ -531,20 +522,56 @@ func bindEnvs(o *Options, v *viper.Viper) error { if err != nil { return fmt.Errorf("failed to bind field 'HeadersEnv' to env var 'HEADERS': %w", err) } - // autocert options - ao := reflect.TypeOf(o.AutocertOptions) - for i := 0; i < ao.NumField(); i++ { - field := ao.Field(i) - envName := field.Tag.Get(tagName) - err := v.BindEnv(envName) - if err != nil { - return fmt.Errorf("failed to bind field '%s' to env var '%s': %w", field.Name, envName, err) - } - } return nil } +// bindEnvsRecursive binds all fields of the provided struct type that have a +// "mapstructure" tag to corresponding environment variables, recursively. If a +// nested struct contains no fields with a "mapstructure" tag, a binding will +// be added for the struct itself (e.g. null.Bool). +func bindEnvsRecursive(t reflect.Type, v *viper.Viper, keyPrefix, envPrefix string) (bool, error) { + anyFieldHasMapstructureTag := false + for i := 0; i < t.NumField(); i++ { + field := t.Field(i) + tag, hasTag := field.Tag.Lookup("mapstructure") + if !hasTag || tag == "-" { + continue + } + + anyFieldHasMapstructureTag = true + + key, _, _ := strings.Cut(tag, ",") + keyPath := keyPrefix + key + envName := envPrefix + strings.ToUpper(key) + + if field.Type.Kind() == reflect.Struct { + newKeyPrefix := keyPath + newEnvPrefix := envName + if key != "" { + newKeyPrefix += "." + newEnvPrefix += "_" + } + nestedMapstructure, err := bindEnvsRecursive(field.Type, v, newKeyPrefix, newEnvPrefix) + if err != nil { + return false, err + } else if nestedMapstructure { + // If we've bound any nested fields from this struct, do not + // also bind this struct itself. + continue + } + } + + if key != "" { + if err := v.BindEnv(keyPath, envName); err != nil { + return false, fmt.Errorf("failed to bind field '%s' to env var '%s': %w", + field.Name, envName, err) + } + } + } + return anyFieldHasMapstructureTag, nil +} + // Validate ensures the Options fields are valid, and hydrated. func (o *Options) Validate() error { ctx := context.TODO() diff --git a/config/options_test.go b/config/options_test.go index 4748ab046..b0380b613 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -11,6 +11,8 @@ import ( "net/url" "os" "path/filepath" + "reflect" + "strings" "sync" "testing" "time" @@ -94,7 +96,7 @@ func Test_bindEnvs(t *testing.T) { t.Setenv("POMERIUM_DEBUG", "true") t.Setenv("POLICY", "LSBmcm9tOiBodHRwczovL2h0dHBiaW4ubG9jYWxob3N0LnBvbWVyaXVtLmlvCiAgdG86IAogICAgLSBodHRwOi8vbG9jYWxob3N0OjgwODEsMQo=") t.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`) - err := bindEnvs(o, v) + err := bindEnvs(v) if err != nil { t.Fatalf("failed to bind options to env vars: %s", err) } @@ -117,6 +119,83 @@ func Test_bindEnvs(t *testing.T) { } } +type Foo struct { + FieldOne Bar `mapstructure:"field_one"` + FieldTwo string `mapstructure:"field_two"` +} +type Bar struct { + Baz int `mapstructure:"baz"` + Quux string `mapstructure:"quux"` +} + +func Test_bindEnvsRecursive(t *testing.T) { + v := viper.New() + _, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "") + require.NoError(t, err) + + t.Setenv("FIELD_ONE_BAZ", "123") + t.Setenv("FIELD_ONE_QUUX", "hello") + t.Setenv("FIELD_TWO", "world") + + var foo Foo + v.Unmarshal(&foo) + assert.Equal(t, Foo{ + FieldOne: Bar{ + Baz: 123, + Quux: "hello", + }, + FieldTwo: "world", + }, foo) +} + +func Test_bindEnvsRecursive_Override(t *testing.T) { + v := viper.New() + v.SetConfigType("yaml") + v.ReadConfig(strings.NewReader(` +field_one: + baz: 10 + quux: abc +field_two: hello +`)) + + // Baseline: values populated from config file. + var foo1 Foo + v.Unmarshal(&foo1) + assert.Equal(t, Foo{ + FieldOne: Bar{ + Baz: 10, + Quux: "abc", + }, + FieldTwo: "hello", + }, foo1) + + _, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "") + require.NoError(t, err) + + // Environment variables should selectively override config file keys. + t.Setenv("FIELD_ONE_QUUX", "def") + var foo2 Foo + v.Unmarshal(&foo2) + assert.Equal(t, Foo{ + FieldOne: Bar{ + Baz: 10, + Quux: "def", + }, + FieldTwo: "hello", + }, foo2) + + t.Setenv("FIELD_TWO", "world") + var foo3 Foo + v.Unmarshal(&foo3) + assert.Equal(t, Foo{ + FieldOne: Bar{ + Baz: 10, + Quux: "def", + }, + FieldTwo: "world", + }, foo3) +} + func Test_parseHeaders(t *testing.T) { // t.Parallel() tests := []struct {