mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 02:16:28 +02:00
internal/config: Switch to using struct scoped viper instance (#332)
* Switch to using struct scoped viper instance * Rename NewXXXOptions * Handle unchecked errors from viper.BindEnv
This commit is contained in:
parent
5df0ff500c
commit
251ab0d527
6 changed files with 98 additions and 52 deletions
|
@ -7,7 +7,7 @@ import (
|
|||
)
|
||||
|
||||
func newTestOptions(t *testing.T) *config.Options {
|
||||
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@ func Test_newAuthenticateService(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -83,7 +83,7 @@ func Test_newAuthorizeService(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
testOpts, err := config.NewOptions("https://some.example", "https://some.example")
|
||||
testOpts, err := config.NewMinimalOptions("https://some.example", "https://some.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -128,7 +128,7 @@ func Test_newProxyeService(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
mux := httputil.NewRouter()
|
||||
testOpts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||
testOpts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
- The healthcheck endpoints (`/ping`) now returns the http status `405` StatusMethodNotAllowed for non-`GET` requests. [GH-319](https://github.com/pomerium/pomerium/issues/319)
|
||||
- Authenticate service no longer uses gRPC.
|
||||
- The global request logger now captures the full array of proxies from `X-Forwarded-For`, in addition to just the client IP.
|
||||
- Options code refactored to eliminate global Viper state. [GH-332](https://github.com/pomerium/pomerium/pull/332/files)
|
||||
|
||||
### Removed
|
||||
|
||||
|
|
|
@ -25,7 +25,8 @@ import (
|
|||
// DisableHeaderKey is the key used to check whether to disable setting header
|
||||
const DisableHeaderKey = "disable"
|
||||
|
||||
// Options are the global environmental flags used to set up pomerium's services.
|
||||
// Options are the global environmental flags used to set up pomerium's services. Use NewXXXOptions() methods
|
||||
// for a safely initialized data structure.
|
||||
type Options struct {
|
||||
// Debug outputs human-readable logs to Stdout.
|
||||
Debug bool `mapstructure:"pomerium_debug"`
|
||||
|
@ -141,6 +142,9 @@ type Options struct {
|
|||
// GRPC Service Settings
|
||||
GRPCClientTimeout time.Duration `mapstructure:"grpc_client_timeout"`
|
||||
GRPCClientDNSRoundRobin bool `mapstructure:"grpc_client_dns_roundrobin"`
|
||||
|
||||
// Scoped viper instance
|
||||
viper *viper.Viper
|
||||
}
|
||||
|
||||
var defaultOptions = Options{
|
||||
|
@ -171,41 +175,63 @@ var defaultOptions = Options{
|
|||
GRPCClientDNSRoundRobin: true,
|
||||
}
|
||||
|
||||
// NewOptions returns a minimal options configuration built from default options.
|
||||
// NewOptions creates a new Options struct with only viper initialized
|
||||
func NewOptions() *Options {
|
||||
o := Options{}
|
||||
o.viper = viper.New()
|
||||
return &o
|
||||
}
|
||||
|
||||
// NewDefaultOptions returns an Options struct with defaults set and viper initialized
|
||||
func NewDefaultOptions() *Options {
|
||||
o := defaultOptions
|
||||
o.viper = viper.New()
|
||||
return &o
|
||||
}
|
||||
|
||||
// NewMinimalOptions returns a minimal options configuration built from default options.
|
||||
// Any modifications to the structure should be followed up by a subsequent
|
||||
// call to validate.
|
||||
func NewOptions(authenticateURL, authorizeURL string) (*Options, error) {
|
||||
o := defaultOptions
|
||||
func NewMinimalOptions(authenticateURL, authorizeURL string) (*Options, error) {
|
||||
o := NewDefaultOptions()
|
||||
o.AuthenticateURLString = authenticateURL
|
||||
o.AuthorizeURLString = authorizeURL
|
||||
if err := o.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("internal/config: validation error %s", err)
|
||||
}
|
||||
return &o, nil
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// OptionsFromViper builds the main binary's configuration
|
||||
// options by parsing environmental variables and config file
|
||||
func OptionsFromViper(configFile string) (*Options, error) {
|
||||
// start a copy of the default options
|
||||
o := defaultOptions
|
||||
o := NewDefaultOptions()
|
||||
// New viper instance to save into Options later
|
||||
v := viper.New()
|
||||
// Load up config
|
||||
o.bindEnvs()
|
||||
err := bindEnvs(o, v)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to bind options to env vars: %w", err)
|
||||
}
|
||||
|
||||
if configFile != "" {
|
||||
viper.SetConfigFile(configFile)
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
v.SetConfigFile(configFile)
|
||||
if err := v.ReadInConfig(); err != nil {
|
||||
return nil, fmt.Errorf("internal/config: failed to read config: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := viper.Unmarshal(&o); err != nil {
|
||||
if err := v.Unmarshal(&o); err != nil {
|
||||
return nil, fmt.Errorf("internal/config: failed to unmarshal config: %s", err)
|
||||
}
|
||||
|
||||
o.viper = v
|
||||
|
||||
if err := o.Validate(); err != nil {
|
||||
return nil, fmt.Errorf("internal/config: validation error %s", err)
|
||||
}
|
||||
return &o, nil
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// Validate ensures the Options fields are properly formed, present, and hydrated.
|
||||
|
@ -270,7 +296,7 @@ func (o *Options) parsePolicy() error {
|
|||
if err := yaml.Unmarshal(policyBytes, &policies); err != nil {
|
||||
return fmt.Errorf("could not unmarshal policy yaml: %s", err)
|
||||
}
|
||||
} else if err := viper.UnmarshalKey("policy", &policies); err != nil {
|
||||
} else if err := o.viper.UnmarshalKey("policy", &policies); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(policies) != 0 {
|
||||
|
@ -291,7 +317,7 @@ func (o *Options) parseHeaders() error {
|
|||
var headers map[string]string
|
||||
if o.HeadersEnv != "" {
|
||||
// Handle JSON by default via viper
|
||||
if headers = viper.GetStringMapString("HeadersEnv"); len(headers) == 0 {
|
||||
if headers = o.viper.GetStringMapString("HeadersEnv"); len(headers) == 0 {
|
||||
// Try to parse "Key1:Value1,Key2:Value2" syntax
|
||||
headerSlice := strings.Split(o.HeadersEnv, ",")
|
||||
for n := range headerSlice {
|
||||
|
@ -307,29 +333,41 @@ func (o *Options) parseHeaders() error {
|
|||
|
||||
}
|
||||
o.Headers = headers
|
||||
} else if viper.IsSet("headers") {
|
||||
if err := viper.UnmarshalKey("headers", &headers); err != nil {
|
||||
return fmt.Errorf("header %s failed to parse: %s", viper.Get("headers"), err)
|
||||
} else if o.viper.IsSet("headers") {
|
||||
if err := o.viper.UnmarshalKey("headers", &headers); err != nil {
|
||||
return fmt.Errorf("header %s failed to parse: %s", o.viper.Get("headers"), err)
|
||||
}
|
||||
o.Headers = headers
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// bindEnvs makes sure viper binds to each env var based on the mapstructure tag
|
||||
func (o *Options) bindEnvs() {
|
||||
// bindEnvs binds a viper instance to each env var of an Options struct based on the mapstructure tag
|
||||
func bindEnvs(o *Options, v *viper.Viper) error {
|
||||
tagName := `mapstructure`
|
||||
t := reflect.TypeOf(*o)
|
||||
|
||||
for i := 0; i < t.NumField(); i++ {
|
||||
field := t.Field(i)
|
||||
envName := field.Tag.Get(tagName)
|
||||
viper.BindEnv(envName)
|
||||
err := v.BindEnv(envName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind field '%s' to env var '%s': %w", field.Name, envName, err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Statically bind fields
|
||||
viper.BindEnv("PolicyEnv", "POLICY")
|
||||
viper.BindEnv("HeadersEnv", "HEADERS")
|
||||
err := v.BindEnv("PolicyEnv", "POLICY")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind field 'PolicyEnv' to env var 'POLICY': %w", err)
|
||||
}
|
||||
err = v.BindEnv("HeadersEnv", "HEADERS")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to bind field 'HeadersEnv' to env var 'HEADERS': %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// OptionsUpdater updates local state based on an Options struct
|
||||
|
|
|
@ -13,12 +13,16 @@ import (
|
|||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{})
|
||||
|
||||
func Test_validate(t *testing.T) {
|
||||
t.Parallel()
|
||||
testOptions := func() Options {
|
||||
o := defaultOptions
|
||||
o := NewDefaultOptions()
|
||||
|
||||
o.SharedKey = "test"
|
||||
o.Services = "all"
|
||||
return o
|
||||
return *o
|
||||
}
|
||||
good := testOptions()
|
||||
badServices := testOptions()
|
||||
|
@ -55,7 +59,8 @@ func Test_validate(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_bindEnvs(t *testing.T) {
|
||||
o := &Options{}
|
||||
o := NewOptions()
|
||||
v := viper.New()
|
||||
os.Clearenv()
|
||||
defer os.Unsetenv("POMERIUM_DEBUG")
|
||||
defer os.Unsetenv("POLICY")
|
||||
|
@ -63,11 +68,15 @@ func Test_bindEnvs(t *testing.T) {
|
|||
os.Setenv("POMERIUM_DEBUG", "true")
|
||||
os.Setenv("POLICY", "mypolicy")
|
||||
os.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`)
|
||||
o.bindEnvs()
|
||||
err := viper.Unmarshal(o)
|
||||
err := bindEnvs(o, v)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to bind options to env vars: %s", err)
|
||||
}
|
||||
err = v.Unmarshal(o)
|
||||
if err != nil {
|
||||
t.Errorf("Could not unmarshal %#v: %s", o, err)
|
||||
}
|
||||
o.viper = v
|
||||
if !o.Debug {
|
||||
t.Errorf("Failed to load POMERIUM_DEBUG from environment")
|
||||
}
|
||||
|
@ -83,6 +92,7 @@ func Test_bindEnvs(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parseHeaders(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
want map[string]string
|
||||
|
@ -100,9 +110,9 @@ func Test_parseHeaders(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := defaultOptions
|
||||
viper.Set("headers", tt.viperHeaders)
|
||||
viper.Set("HeadersEnv", tt.envHeaders)
|
||||
o := NewDefaultOptions()
|
||||
o.viper.Set("headers", tt.viperHeaders)
|
||||
o.viper.Set("HeadersEnv", tt.envHeaders)
|
||||
o.HeadersEnv = tt.envHeaders
|
||||
|
||||
err := o.parseHeaders()
|
||||
|
@ -114,14 +124,12 @@ func Test_parseHeaders(t *testing.T) {
|
|||
if !tt.wantErr && !cmp.Equal(tt.want, o.Headers) {
|
||||
t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.Headers))
|
||||
}
|
||||
viper.Reset()
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func Test_OptionsFromViper(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
testPolicy := Policy{
|
||||
To: "https://httpbin.org",
|
||||
|
@ -135,7 +143,7 @@ func Test_OptionsFromViper(t *testing.T) {
|
|||
}
|
||||
|
||||
goodConfigBytes := []byte(`{"authorize_service_url":"https://authorize.corp.example","authenticate_service_url":"https://authenticate.corp.example","shared_secret":"Setec Astronomy","service":"all","policy":[{"from":"https://pomerium.io","to":"https://httpbin.org"}]}`)
|
||||
goodOptions := defaultOptions
|
||||
goodOptions := *(NewDefaultOptions())
|
||||
goodOptions.SharedKey = "Setec Astronomy"
|
||||
goodOptions.Services = "all"
|
||||
goodOptions.Policies = testPolicies
|
||||
|
@ -168,9 +176,9 @@ func Test_OptionsFromViper(t *testing.T) {
|
|||
{"bad json", badConfigBytes, nil, true},
|
||||
{"bad unmarshal", badUnmarshalConfigBytes, nil, true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Reset()
|
||||
os.Clearenv()
|
||||
os.Setenv("COOKIE_NAME", "oatmeal")
|
||||
defer os.Unsetenv("COOKIE_NAME")
|
||||
|
@ -187,7 +195,7 @@ func Test_OptionsFromViper(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
if diff := cmp.Diff(got, tt.want, cmpOptIgnoreUnexported); diff != "" {
|
||||
t.Errorf("OptionsFromViper() = \n%s\n, \ngot\n%+v\n, want \n%+v", diff, got, tt.want)
|
||||
}
|
||||
|
||||
|
@ -203,7 +211,6 @@ func Test_OptionsFromViper(t *testing.T) {
|
|||
|
||||
func Test_parsePolicyEnv(t *testing.T) {
|
||||
t.Parallel()
|
||||
viper.Reset()
|
||||
|
||||
source := "https://pomerium.io"
|
||||
sourceURL, _ := url.ParseRequestURI(source)
|
||||
|
@ -223,7 +230,7 @@ func Test_parsePolicyEnv(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
o := new(Options)
|
||||
o := NewOptions()
|
||||
|
||||
o.PolicyEnv = base64.StdEncoding.EncodeToString(tt.policyBytes)
|
||||
err := o.parsePolicy()
|
||||
|
@ -238,7 +245,7 @@ func Test_parsePolicyEnv(t *testing.T) {
|
|||
}
|
||||
|
||||
// Catch bad base64
|
||||
o := new(Options)
|
||||
o := NewOptions()
|
||||
o.PolicyEnv = "foo"
|
||||
err := o.parsePolicy()
|
||||
if err == nil {
|
||||
|
@ -247,7 +254,7 @@ func Test_parsePolicyEnv(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parsePolicyFile(t *testing.T) {
|
||||
viper.Reset()
|
||||
t.Parallel()
|
||||
source := "https://pomerium.io"
|
||||
sourceURL, _ := url.ParseRequestURI(source)
|
||||
dest := "https://httpbin.org"
|
||||
|
@ -269,9 +276,9 @@ func Test_parsePolicyFile(t *testing.T) {
|
|||
defer tempFile.Close()
|
||||
defer os.Remove(tempFile.Name())
|
||||
tempFile.Write(tt.policyBytes)
|
||||
o := new(Options)
|
||||
viper.SetConfigFile(tempFile.Name())
|
||||
if err := viper.ReadInConfig(); err != nil {
|
||||
o := NewOptions()
|
||||
o.viper.SetConfigFile(tempFile.Name())
|
||||
if err := o.viper.ReadInConfig(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
err := o.parsePolicy()
|
||||
|
@ -290,7 +297,7 @@ func Test_parsePolicyFile(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_Checksum(t *testing.T) {
|
||||
o := defaultOptions
|
||||
o := NewDefaultOptions()
|
||||
|
||||
oldChecksum := o.Checksum()
|
||||
o.SharedKey = "changemeplease"
|
||||
|
@ -310,7 +317,7 @@ func Test_Checksum(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNewOptions(t *testing.T) {
|
||||
viper.Reset()
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
name string
|
||||
authenticateURL string
|
||||
|
@ -326,7 +333,7 @@ func TestNewOptions(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := NewOptions(tt.authenticateURL, tt.authorizeURL)
|
||||
_, err := NewMinimalOptions(tt.authenticateURL, tt.authorizeURL)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("NewOptions() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -336,9 +343,11 @@ func TestNewOptions(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestOptionsFromViper(t *testing.T) {
|
||||
t.Parallel()
|
||||
opts := []cmp.Option{
|
||||
cmpopts.IgnoreFields(Options{}, "DefaultUpstreamTimeout", "CookieRefresh", "CookieExpire", "Services", "Addr", "RefreshCooldown", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "ReadHeaderTimeout", "IdleTimeout", "GRPCClientTimeout", "GRPCClientDNSRoundRobin"),
|
||||
cmpopts.IgnoreFields(Policy{}, "Source", "Destination"),
|
||||
cmpOptIgnoreUnexported,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
|
@ -394,8 +403,6 @@ func TestOptionsFromViper(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_parseOptions(t *testing.T) {
|
||||
viper.Reset()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
envKey string
|
||||
|
@ -450,7 +457,7 @@ func Test_HandleConfigUpdate(t *testing.T) {
|
|||
os.Setenv("SHARED_SECRET", "foo")
|
||||
defer os.Unsetenv("SHARED_SECRET")
|
||||
|
||||
blankOpts, err := NewOptions("https://authenticate.example", "https://authorize.example")
|
||||
blankOpts, err := NewMinimalOptions("https://authenticate.example", "https://authorize.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -11,7 +11,7 @@ import (
|
|||
)
|
||||
|
||||
func newTestOptions(t *testing.T) *config.Options {
|
||||
opts, err := config.NewOptions("https://authenticate.example", "https://authorize.example")
|
||||
opts, err := config.NewMinimalOptions("https://authenticate.example", "https://authorize.example")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue