internal/config: Switch to using struct scoped viper instance (#332)

* Switch to using struct scoped viper instance

* Rename NewXXXOptions

* Handle unchecked errors from viper.BindEnv
This commit is contained in:
Travis Groth 2019-10-01 18:16:36 -04:00 committed by GitHub
parent 5df0ff500c
commit 251ab0d527
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 98 additions and 52 deletions

View file

@ -7,7 +7,7 @@ import (
)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}

View file

@ -38,7 +38,7 @@ func Test_newAuthenticateService(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}
@ -83,7 +83,7 @@ func Test_newAuthorizeService(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
testOpts, err := config.NewOptions("https://some.example", "https://some.example")
testOpts, err := config.NewMinimalOptions("https://some.example", "https://some.example")
if err != nil {
t.Fatal(err)
}
@ -128,7 +128,7 @@ func Test_newProxyeService(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mux := httputil.NewRouter()
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}

View file

@ -23,6 +23,7 @@
- The healthcheck endpoints (`/ping`) now returns the http status `405` StatusMethodNotAllowed for non-`GET` requests. [GH-319](https://github.com/pomerium/pomerium/issues/319)
- Authenticate service no longer uses gRPC.
- The global request logger now captures the full array of proxies from `X-Forwarded-For`, in addition to just the client IP.
- Options code refactored to eliminate global Viper state. [GH-332](https://github.com/pomerium/pomerium/pull/332/files)
### Removed

View file

@ -25,7 +25,8 @@ import (
// DisableHeaderKey is the key used to check whether to disable setting header
const DisableHeaderKey = "disable"
// Options are the global environmental flags used to set up pomerium's services.
// Options are the global environmental flags used to set up pomerium's services. Use NewXXXOptions() methods
// for a safely initialized data structure.
type Options struct {
// Debug outputs human-readable logs to Stdout.
Debug bool `mapstructure:"pomerium_debug"`
@ -141,6 +142,9 @@ type Options struct {
// GRPC Service Settings
GRPCClientTimeout time.Duration `mapstructure:"grpc_client_timeout"`
GRPCClientDNSRoundRobin bool `mapstructure:"grpc_client_dns_roundrobin"`
// Scoped viper instance
viper *viper.Viper
}
var defaultOptions = Options{
@ -171,41 +175,63 @@ var defaultOptions = Options{
GRPCClientDNSRoundRobin: true,
}
// NewOptions returns a minimal options configuration built from default options.
// NewOptions creates a new Options struct with only viper initialized
func NewOptions() *Options {
o := Options{}
o.viper = viper.New()
return &o
}
// NewDefaultOptions returns an Options struct with defaults set and viper initialized
func NewDefaultOptions() *Options {
o := defaultOptions
o.viper = viper.New()
return &o
}
// NewMinimalOptions returns a minimal options configuration built from default options.
// Any modifications to the structure should be followed up by a subsequent
// call to validate.
func NewOptions(authenticateURL, authorizeURL string) (*Options, error) {
o := defaultOptions
func NewMinimalOptions(authenticateURL, authorizeURL string) (*Options, error) {
o := NewDefaultOptions()
o.AuthenticateURLString = authenticateURL
o.AuthorizeURLString = authorizeURL
if err := o.Validate(); err != nil {
return nil, fmt.Errorf("internal/config: validation error %s", err)
}
return &o, nil
return o, nil
}
// OptionsFromViper builds the main binary's configuration
// options by parsing environmental variables and config file
func OptionsFromViper(configFile string) (*Options, error) {
// start a copy of the default options
o := defaultOptions
o := NewDefaultOptions()
// New viper instance to save into Options later
v := viper.New()
// Load up config
o.bindEnvs()
err := bindEnvs(o, v)
if err != nil {
return nil, fmt.Errorf("failed to bind options to env vars: %w", err)
}
if configFile != "" {
viper.SetConfigFile(configFile)
if err := viper.ReadInConfig(); err != nil {
v.SetConfigFile(configFile)
if err := v.ReadInConfig(); err != nil {
return nil, fmt.Errorf("internal/config: failed to read config: %s", err)
}
}
if err := viper.Unmarshal(&o); err != nil {
if err := v.Unmarshal(&o); err != nil {
return nil, fmt.Errorf("internal/config: failed to unmarshal config: %s", err)
}
o.viper = v
if err := o.Validate(); err != nil {
return nil, fmt.Errorf("internal/config: validation error %s", err)
}
return &o, nil
return o, nil
}
// Validate ensures the Options fields are properly formed, present, and hydrated.
@ -270,7 +296,7 @@ func (o *Options) parsePolicy() error {
if err := yaml.Unmarshal(policyBytes, &policies); err != nil {
return fmt.Errorf("could not unmarshal policy yaml: %s", err)
}
} else if err := viper.UnmarshalKey("policy", &policies); err != nil {
} else if err := o.viper.UnmarshalKey("policy", &policies); err != nil {
return err
}
if len(policies) != 0 {
@ -291,7 +317,7 @@ func (o *Options) parseHeaders() error {
var headers map[string]string
if o.HeadersEnv != "" {
// Handle JSON by default via viper
if headers = viper.GetStringMapString("HeadersEnv"); len(headers) == 0 {
if headers = o.viper.GetStringMapString("HeadersEnv"); len(headers) == 0 {
// Try to parse "Key1:Value1,Key2:Value2" syntax
headerSlice := strings.Split(o.HeadersEnv, ",")
for n := range headerSlice {
@ -307,29 +333,41 @@ func (o *Options) parseHeaders() error {
}
o.Headers = headers
} else if viper.IsSet("headers") {
if err := viper.UnmarshalKey("headers", &headers); err != nil {
return fmt.Errorf("header %s failed to parse: %s", viper.Get("headers"), err)
} else if o.viper.IsSet("headers") {
if err := o.viper.UnmarshalKey("headers", &headers); err != nil {
return fmt.Errorf("header %s failed to parse: %s", o.viper.Get("headers"), err)
}
o.Headers = headers
}
return nil
}
// bindEnvs makes sure viper binds to each env var based on the mapstructure tag
func (o *Options) bindEnvs() {
// bindEnvs binds a viper instance to each env var of an Options struct based on the mapstructure tag
func bindEnvs(o *Options, v *viper.Viper) error {
tagName := `mapstructure`
t := reflect.TypeOf(*o)
for i := 0; i < t.NumField(); i++ {
field := t.Field(i)
envName := field.Tag.Get(tagName)
viper.BindEnv(envName)
err := v.BindEnv(envName)
if err != nil {
return fmt.Errorf("failed to bind field '%s' to env var '%s': %w", field.Name, envName, err)
}
}
// Statically bind fields
viper.BindEnv("PolicyEnv", "POLICY")
viper.BindEnv("HeadersEnv", "HEADERS")
err := v.BindEnv("PolicyEnv", "POLICY")
if err != nil {
return fmt.Errorf("failed to bind field 'PolicyEnv' to env var 'POLICY': %w", err)
}
err = v.BindEnv("HeadersEnv", "HEADERS")
if err != nil {
return fmt.Errorf("failed to bind field 'HeadersEnv' to env var 'HEADERS': %w", err)
}
return nil
}
// OptionsUpdater updates local state based on an Options struct

View file

@ -13,12 +13,16 @@ import (
"github.com/spf13/viper"
)
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{})
func Test_validate(t *testing.T) {
t.Parallel()
testOptions := func() Options {
o := defaultOptions
o := NewDefaultOptions()
o.SharedKey = "test"
o.Services = "all"
return o
return *o
}
good := testOptions()
badServices := testOptions()
@ -55,7 +59,8 @@ func Test_validate(t *testing.T) {
}
func Test_bindEnvs(t *testing.T) {
o := &Options{}
o := NewOptions()
v := viper.New()
os.Clearenv()
defer os.Unsetenv("POMERIUM_DEBUG")
defer os.Unsetenv("POLICY")
@ -63,11 +68,15 @@ func Test_bindEnvs(t *testing.T) {
os.Setenv("POMERIUM_DEBUG", "true")
os.Setenv("POLICY", "mypolicy")
os.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`)
o.bindEnvs()
err := viper.Unmarshal(o)
err := bindEnvs(o, v)
if err != nil {
t.Fatalf("failed to bind options to env vars: %s", err)
}
err = v.Unmarshal(o)
if err != nil {
t.Errorf("Could not unmarshal %#v: %s", o, err)
}
o.viper = v
if !o.Debug {
t.Errorf("Failed to load POMERIUM_DEBUG from environment")
}
@ -83,6 +92,7 @@ func Test_bindEnvs(t *testing.T) {
}
func Test_parseHeaders(t *testing.T) {
t.Parallel()
tests := []struct {
name string
want map[string]string
@ -100,9 +110,9 @@ func Test_parseHeaders(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := defaultOptions
viper.Set("headers", tt.viperHeaders)
viper.Set("HeadersEnv", tt.envHeaders)
o := NewDefaultOptions()
o.viper.Set("headers", tt.viperHeaders)
o.viper.Set("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders
err := o.parseHeaders()
@ -114,14 +124,12 @@ func Test_parseHeaders(t *testing.T) {
if !tt.wantErr && !cmp.Equal(tt.want, o.Headers) {
t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.Headers))
}
viper.Reset()
})
}
}
func Test_OptionsFromViper(t *testing.T) {
viper.Reset()
testPolicy := Policy{
To: "https://httpbin.org",
@ -135,7 +143,7 @@ func Test_OptionsFromViper(t *testing.T) {
}
goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`)
goodOptions := defaultOptions
goodOptions := *(NewDefaultOptions())
goodOptions.SharedKey = "Setec Astronomy"
goodOptions.Services = "all"
goodOptions.Policies = testPolicies
@ -168,9 +176,9 @@ func Test_OptionsFromViper(t *testing.T) {
{"bad json", badConfigBytes, nil, true},
{"bad unmarshal", badUnmarshalConfigBytes, nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Reset()
os.Clearenv()
os.Setenv("COOKIE_NAME", "oatmeal")
defer os.Unsetenv("COOKIE_NAME")
@ -187,7 +195,7 @@ func Test_OptionsFromViper(t *testing.T) {
t.Fatal(err)
}
}
if diff := cmp.Diff(got, tt.want); diff != "" {
if diff := cmp.Diff(got, tt.want, cmpOptIgnoreUnexported); diff != "" {
t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%+v\n, want \n%+v", diff, got, tt.want)
}
@ -203,7 +211,6 @@ func Test_OptionsFromViper(t *testing.T) {
func Test_parsePolicyEnv(t *testing.T) {
t.Parallel()
viper.Reset()
source := "https://pomerium.io"
sourceURL, _ := url.ParseRequestURI(source)
@ -223,7 +230,7 @@ func Test_parsePolicyEnv(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
o := new(Options)
o := NewOptions()
o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes)
err := o.parsePolicy()
@ -238,7 +245,7 @@ func Test_parsePolicyEnv(t *testing.T) {
}
// Catch bad base64
o := new(Options)
o := NewOptions()
o.PolicyEnv = "foo"
err := o.parsePolicy()
if err == nil {
@ -247,7 +254,7 @@ func Test_parsePolicyEnv(t *testing.T) {
}
func Test_parsePolicyFile(t *testing.T) {
viper.Reset()
t.Parallel()
source := "https://pomerium.io"
sourceURL, _ := url.ParseRequestURI(source)
dest := "https://httpbin.org"
@ -269,9 +276,9 @@ func Test_parsePolicyFile(t *testing.T) {
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.policyBytes)
o := new(Options)
viper.SetConfigFile(tempFile.Name())
if err := viper.ReadInConfig(); err != nil {
o := NewOptions()
o.viper.SetConfigFile(tempFile.Name())
if err := o.viper.ReadInConfig(); err != nil {
t.Fatal(err)
}
err := o.parsePolicy()
@ -290,7 +297,7 @@ func Test_parsePolicyFile(t *testing.T) {
}
func Test_Checksum(t *testing.T) {
o := defaultOptions
o := NewDefaultOptions()
oldChecksum := o.Checksum()
o.SharedKey = "changemeplease"
@ -310,7 +317,7 @@ func Test_Checksum(t *testing.T) {
}
func TestNewOptions(t *testing.T) {
viper.Reset()
t.Parallel()
tests := []struct {
name string
authenticateURL string
@ -326,7 +333,7 @@ func TestNewOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
_, err := NewOptions(tt.authenticateURL, tt.authorizeURL)
_, err := NewMinimalOptions(tt.authenticateURL, tt.authorizeURL)
if (err != nil) != tt.wantErr {
t.Errorf("NewOptions() error = %v, wantErr %v", err, tt.wantErr)
return
@ -336,9 +343,11 @@ func TestNewOptions(t *testing.T) {
}
func TestOptionsFromViper(t *testing.T) {
t.Parallel()
opts := []cmp.Option{
cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
cmpOptIgnoreUnexported,
}
tests := []struct {
@ -394,8 +403,6 @@ func TestOptionsFromViper(t *testing.T) {
}
func Test_parseOptions(t *testing.T) {
viper.Reset()
tests := []struct {
name string
envKey string
@ -450,7 +457,7 @@ func Test_HandleConfigUpdate(t *testing.T) {
os.Setenv("SHARED_SECRET", "foo")
defer os.Unsetenv("SHARED_SECRET")
blankOpts, err := NewOptions("https://authenticate.example", "https://authorize.example")
blankOpts, err := NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}

View file

@ -11,7 +11,7 @@ import (
)
func newTestOptions(t *testing.T) *config.Options {
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
if err != nil {
t.Fatal(err)
}