From 251ab0d527363ae67790e2f29925925b7d2f5194 Mon Sep 17 00:00:00 2001 From: Travis Groth Date: Tue, 1 Oct 2019 18:16:36 -0400 Subject: [PATCH] internal/config: Switch to using struct scoped viper instance (#332) * Switch to using struct scoped viper instance * Rename NewXXXOptions * Handle unchecked errors from viper.BindEnv --- authenticate/authenticate_test.go | 2 +- cmd/pomerium/main_test.go | 6 +-- docs/docs/CHANGELOG.md | 1 + internal/config/options.go | 80 +++++++++++++++++++++++-------- internal/config/options_test.go | 59 +++++++++++++---------- proxy/proxy_test.go | 2 +- 6 files changed, 98 insertions(+), 52 deletions(-) diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 4567fe3ff..be181b6c1 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -7,7 +7,7 @@ import ( ) func newTestOptions(t *testing.T) *config.Options { - opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") + opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) } diff --git a/cmd/pomerium/main_test.go b/cmd/pomerium/main_test.go index 52ed82511..86b7a6513 100644 --- a/cmd/pomerium/main_test.go +++ b/cmd/pomerium/main_test.go @@ -38,7 +38,7 @@ func Test_newAuthenticateService(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") + testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) } @@ -83,7 +83,7 @@ func Test_newAuthorizeService(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - testOpts, err := config.NewOptions("https://some.example", "https://some.example") + testOpts, err := config.NewMinimalOptions("https://some.example", "https://some.example") if err != nil { t.Fatal(err) } @@ -128,7 +128,7 @@ func Test_newProxyeService(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mux := httputil.NewRouter() - testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") + testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) } diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index baec9fcb3..82b31abfa 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -23,6 +23,7 @@ - The healthcheck endpoints (`/ping`) now returns the http status `405` StatusMethodNotAllowed for non-`GET` requests. [GH-319](https://github.com/pomerium/pomerium/issues/319) - Authenticate service no longer uses gRPC. - The global request logger now captures the full array of proxies from `X-Forwarded-For`, in addition to just the client IP. +- Options code refactored to eliminate global Viper state. [GH-332](https://github.com/pomerium/pomerium/pull/332/files) ### Removed diff --git a/internal/config/options.go b/internal/config/options.go index e2d8e0624..4cda399a7 100644 --- a/internal/config/options.go +++ b/internal/config/options.go @@ -25,7 +25,8 @@ import ( // DisableHeaderKey is the key used to check whether to disable setting header const DisableHeaderKey = "disable" -// Options are the global environmental flags used to set up pomerium's services. +// Options are the global environmental flags used to set up pomerium's services. Use NewXXXOptions() methods +// for a safely initialized data structure. type Options struct { // Debug outputs human-readable logs to Stdout. Debug bool `mapstructure:"pomerium_debug"` @@ -141,6 +142,9 @@ type Options struct { // GRPC Service Settings GRPCClientTimeout time.Duration `mapstructure:"grpc_client_timeout"` GRPCClientDNSRoundRobin bool `mapstructure:"grpc_client_dns_roundrobin"` + + // Scoped viper instance + viper *viper.Viper } var defaultOptions = Options{ @@ -171,41 +175,63 @@ var defaultOptions = Options{ GRPCClientDNSRoundRobin: true, } -// NewOptions returns a minimal options configuration built from default options. +// NewOptions creates a new Options struct with only viper initialized +func NewOptions() *Options { + o := Options{} + o.viper = viper.New() + return &o +} + +// NewDefaultOptions returns an Options struct with defaults set and viper initialized +func NewDefaultOptions() *Options { + o := defaultOptions + o.viper = viper.New() + return &o +} + +// NewMinimalOptions returns a minimal options configuration built from default options. // Any modifications to the structure should be followed up by a subsequent // call to validate. -func NewOptions(authenticateURL, authorizeURL string) (*Options, error) { - o := defaultOptions +func NewMinimalOptions(authenticateURL, authorizeURL string) (*Options, error) { + o := NewDefaultOptions() o.AuthenticateURLString = authenticateURL o.AuthorizeURLString = authorizeURL if err := o.Validate(); err != nil { return nil, fmt.Errorf("internal/config: validation error %s", err) } - return &o, nil + return o, nil } // OptionsFromViper builds the main binary's configuration // options by parsing environmental variables and config file func OptionsFromViper(configFile string) (*Options, error) { // start a copy of the default options - o := defaultOptions + o := NewDefaultOptions() + // New viper instance to save into Options later + v := viper.New() // Load up config - o.bindEnvs() + err := bindEnvs(o, v) + if err != nil { + return nil, fmt.Errorf("failed to bind options to env vars: %w", err) + } + if configFile != "" { - viper.SetConfigFile(configFile) - if err := viper.ReadInConfig(); err != nil { + v.SetConfigFile(configFile) + if err := v.ReadInConfig(); err != nil { return nil, fmt.Errorf("internal/config: failed to read config: %s", err) } } - if err := viper.Unmarshal(&o); err != nil { + if err := v.Unmarshal(&o); err != nil { return nil, fmt.Errorf("internal/config: failed to unmarshal config: %s", err) } + o.viper = v + if err := o.Validate(); err != nil { return nil, fmt.Errorf("internal/config: validation error %s", err) } - return &o, nil + return o, nil } // Validate ensures the Options fields are properly formed, present, and hydrated. @@ -270,7 +296,7 @@ func (o *Options) parsePolicy() error { if err := yaml.Unmarshal(policyBytes, &policies); err != nil { return fmt.Errorf("could not unmarshal policy yaml: %s", err) } - } else if err := viper.UnmarshalKey("policy", &policies); err != nil { + } else if err := o.viper.UnmarshalKey("policy", &policies); err != nil { return err } if len(policies) != 0 { @@ -291,7 +317,7 @@ func (o *Options) parseHeaders() error { var headers map[string]string if o.HeadersEnv != "" { // Handle JSON by default via viper - if headers = viper.GetStringMapString("HeadersEnv"); len(headers) == 0 { + if headers = o.viper.GetStringMapString("HeadersEnv"); len(headers) == 0 { // Try to parse "Key1:Value1,Key2:Value2" syntax headerSlice := strings.Split(o.HeadersEnv, ",") for n := range headerSlice { @@ -307,29 +333,41 @@ func (o *Options) parseHeaders() error { } o.Headers = headers - } else if viper.IsSet("headers") { - if err := viper.UnmarshalKey("headers", &headers); err != nil { - return fmt.Errorf("header %s failed to parse: %s", viper.Get("headers"), err) + } else if o.viper.IsSet("headers") { + if err := o.viper.UnmarshalKey("headers", &headers); err != nil { + return fmt.Errorf("header %s failed to parse: %s", o.viper.Get("headers"), err) } o.Headers = headers } return nil } -// bindEnvs makes sure viper binds to each env var based on the mapstructure tag -func (o *Options) bindEnvs() { +// bindEnvs binds a viper instance to each env var of an Options struct based on the mapstructure tag +func bindEnvs(o *Options, v *viper.Viper) error { tagName := `mapstructure` t := reflect.TypeOf(*o) for i := 0; i < t.NumField(); i++ { field := t.Field(i) envName := field.Tag.Get(tagName) - viper.BindEnv(envName) + err := v.BindEnv(envName) + if err != nil { + return fmt.Errorf("failed to bind field '%s' to env var '%s': %w", field.Name, envName, err) + } + } // Statically bind fields - viper.BindEnv("PolicyEnv", "POLICY") - viper.BindEnv("HeadersEnv", "HEADERS") + err := v.BindEnv("PolicyEnv", "POLICY") + if err != nil { + return fmt.Errorf("failed to bind field 'PolicyEnv' to env var 'POLICY': %w", err) + } + err = v.BindEnv("HeadersEnv", "HEADERS") + if err != nil { + return fmt.Errorf("failed to bind field 'HeadersEnv' to env var 'HEADERS': %w", err) + } + + return nil } // OptionsUpdater updates local state based on an Options struct diff --git a/internal/config/options_test.go b/internal/config/options_test.go index d8eddc98e..ad57f8755 100644 --- a/internal/config/options_test.go +++ b/internal/config/options_test.go @@ -13,12 +13,16 @@ import ( "github.com/spf13/viper" ) +var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}) + func Test_validate(t *testing.T) { + t.Parallel() testOptions := func() Options { - o := defaultOptions + o := NewDefaultOptions() + o.SharedKey = "test" o.Services = "all" - return o + return *o } good := testOptions() badServices := testOptions() @@ -55,7 +59,8 @@ func Test_validate(t *testing.T) { } func Test_bindEnvs(t *testing.T) { - o := &Options{} + o := NewOptions() + v := viper.New() os.Clearenv() defer os.Unsetenv("POMERIUM_DEBUG") defer os.Unsetenv("POLICY") @@ -63,11 +68,15 @@ func Test_bindEnvs(t *testing.T) { os.Setenv("POMERIUM_DEBUG", "true") os.Setenv("POLICY", "mypolicy") os.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`) - o.bindEnvs() - err := viper.Unmarshal(o) + err := bindEnvs(o, v) + if err != nil { + t.Fatalf("failed to bind options to env vars: %s", err) + } + err = v.Unmarshal(o) if err != nil { t.Errorf("Could not unmarshal %#v: %s", o, err) } + o.viper = v if !o.Debug { t.Errorf("Failed to load POMERIUM_DEBUG from environment") } @@ -83,6 +92,7 @@ func Test_bindEnvs(t *testing.T) { } func Test_parseHeaders(t *testing.T) { + t.Parallel() tests := []struct { name string want map[string]string @@ -100,9 +110,9 @@ func Test_parseHeaders(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o := defaultOptions - viper.Set("headers", tt.viperHeaders) - viper.Set("HeadersEnv", tt.envHeaders) + o := NewDefaultOptions() + o.viper.Set("headers", tt.viperHeaders) + o.viper.Set("HeadersEnv", tt.envHeaders) o.HeadersEnv = tt.envHeaders err := o.parseHeaders() @@ -114,14 +124,12 @@ func Test_parseHeaders(t *testing.T) { if !tt.wantErr && !cmp.Equal(tt.want, o.Headers) { t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.Headers)) } - viper.Reset() }) } } func Test_OptionsFromViper(t *testing.T) { - viper.Reset() testPolicy := Policy{ To: "https://httpbin.org", @@ -135,7 +143,7 @@ func Test_OptionsFromViper(t *testing.T) { } goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`) - goodOptions := defaultOptions + goodOptions := *(NewDefaultOptions()) goodOptions.SharedKey = "Setec Astronomy" goodOptions.Services = "all" goodOptions.Policies = testPolicies @@ -168,9 +176,9 @@ func Test_OptionsFromViper(t *testing.T) { {"bad json", badConfigBytes, nil, true}, {"bad unmarshal", badUnmarshalConfigBytes, nil, true}, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - viper.Reset() os.Clearenv() os.Setenv("COOKIE_NAME", "oatmeal") defer os.Unsetenv("COOKIE_NAME") @@ -187,7 +195,7 @@ func Test_OptionsFromViper(t *testing.T) { t.Fatal(err) } } - if diff := cmp.Diff(got, tt.want); diff != "" { + if diff := cmp.Diff(got, tt.want, cmpOptIgnoreUnexported); diff != "" { t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%+v\n, want \n%+v", diff, got, tt.want) } @@ -203,7 +211,6 @@ func Test_OptionsFromViper(t *testing.T) { func Test_parsePolicyEnv(t *testing.T) { t.Parallel() - viper.Reset() source := "https://pomerium.io" sourceURL, _ := url.ParseRequestURI(source) @@ -223,7 +230,7 @@ func Test_parsePolicyEnv(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o := new(Options) + o := NewOptions() o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes) err := o.parsePolicy() @@ -238,7 +245,7 @@ func Test_parsePolicyEnv(t *testing.T) { } // Catch bad base64 - o := new(Options) + o := NewOptions() o.PolicyEnv = "foo" err := o.parsePolicy() if err == nil { @@ -247,7 +254,7 @@ func Test_parsePolicyEnv(t *testing.T) { } func Test_parsePolicyFile(t *testing.T) { - viper.Reset() + t.Parallel() source := "https://pomerium.io" sourceURL, _ := url.ParseRequestURI(source) dest := "https://httpbin.org" @@ -269,9 +276,9 @@ func Test_parsePolicyFile(t *testing.T) { defer tempFile.Close() defer os.Remove(tempFile.Name()) tempFile.Write(tt.policyBytes) - o := new(Options) - viper.SetConfigFile(tempFile.Name()) - if err := viper.ReadInConfig(); err != nil { + o := NewOptions() + o.viper.SetConfigFile(tempFile.Name()) + if err := o.viper.ReadInConfig(); err != nil { t.Fatal(err) } err := o.parsePolicy() @@ -290,7 +297,7 @@ func Test_parsePolicyFile(t *testing.T) { } func Test_Checksum(t *testing.T) { - o := defaultOptions + o := NewDefaultOptions() oldChecksum := o.Checksum() o.SharedKey = "changemeplease" @@ -310,7 +317,7 @@ func Test_Checksum(t *testing.T) { } func TestNewOptions(t *testing.T) { - viper.Reset() + t.Parallel() tests := []struct { name string authenticateURL string @@ -326,7 +333,7 @@ func TestNewOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := NewOptions(tt.authenticateURL, tt.authorizeURL) + _, err := NewMinimalOptions(tt.authenticateURL, tt.authorizeURL) if (err != nil) != tt.wantErr { t.Errorf("NewOptions() error = %v, wantErr %v", err, tt.wantErr) return @@ -336,9 +343,11 @@ func TestNewOptions(t *testing.T) { } func TestOptionsFromViper(t *testing.T) { + t.Parallel() opts := []cmp.Option{ cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"), cmpopts.IgnoreFields(Policy{}, "Source", "Destination"), + cmpOptIgnoreUnexported, } tests := []struct { @@ -394,8 +403,6 @@ func TestOptionsFromViper(t *testing.T) { } func Test_parseOptions(t *testing.T) { - viper.Reset() - tests := []struct { name string envKey string @@ -450,7 +457,7 @@ func Test_HandleConfigUpdate(t *testing.T) { os.Setenv("SHARED_SECRET", "foo") defer os.Unsetenv("SHARED_SECRET") - blankOpts, err := NewOptions("https://authenticate.example", "https://authorize.example") + blankOpts, err := NewMinimalOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 100a4e118..609638cf6 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -11,7 +11,7 @@ import ( ) func newTestOptions(t *testing.T) *config.Options { - opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example") + opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example") if err != nil { t.Fatal(err) }