diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 850fd3ec9..f1d9a06c9 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -20,7 +20,11 @@ import ( // ValidateOptions checks that configuration are complete and valid. // Returns on first error found. func ValidateOptions(o *config.Options) error { - if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { + sharedKey, err := o.GetSharedKey() + if err != nil { + return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err) + } + if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil { return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err) } if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil { diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 4bc418813..918882ca0 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -39,6 +39,7 @@ func TestOptions_Validate(t *testing.T) { shortCookieLength := newTestOptions(t) shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" badSharedKey := newTestOptions(t) + badSharedKey.Services = "authenticate" badSharedKey.SharedKey = "" badAuthenticateURL := newTestOptions(t) badAuthenticateURL.AuthenticateURLString = "BAD_URL" diff --git a/authenticate/state.go b/authenticate/state.go index 3a5572de1..6686f2146 100644 --- a/authenticate/state.go +++ b/authenticate/state.go @@ -78,7 +78,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err state.redirectURL.Path = cfg.Options.AuthenticateCallbackPath // shared cipher to encrypt data before passing data between services - state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } @@ -140,7 +140,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err state.jwk.Keys = append(state.jwk.Keys, *jwk) } - sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + sharedKey, err := cfg.Options.GetSharedKey() if err != nil { return nil, err } @@ -157,7 +157,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err CAFile: cfg.Options.CAFile, RequestTimeout: cfg.Options.GRPCClientTimeout, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, - WithInsecure: cfg.Options.GRPCInsecure, + WithInsecure: cfg.Options.GetGRPCInsecure(), InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, SignedJWTKey: sharedKey, diff --git a/authorize/authorize.go b/authorize/authorize.go index b7b49e81b..188ab6a5c 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -61,7 +61,11 @@ func (a *Authorize) WaitForInitialSync(ctx context.Context) error { } func validateOptions(o *config.Options) error { - if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { + sharedKey, err := o.GetSharedKey() + if err != nil { + return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err) + } + if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil { return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err) } if _, err := o.GetAuthenticateURL(); err != nil { diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index c36031f18..03ce65243 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -48,7 +48,9 @@ func New(options *config.Options, store *Store) (*Evaluator, error) { } store.UpdateIssuer(authenticateURL.Host) - store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(options.GoogleCloudServerlessAuthenticationServiceAccount) + store.UpdateGoogleCloudServerlessAuthenticationServiceAccount( + options.GetGoogleCloudServerlessAuthenticationServiceAccount(), + ) store.UpdateJWTClaimHeaders(options.JWTClaimsHeaders) store.UpdateRoutePolicies(options.GetAllPolicies()) store.UpdateSigningKey(jwk) diff --git a/authorize/state.go b/authorize/state.go index 3aebf6253..2c5985e28 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -1,7 +1,6 @@ package authorize import ( - "encoding/base64" "fmt" "sync/atomic" @@ -36,7 +35,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) } - state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } @@ -46,7 +45,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a return nil, err } - sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + sharedKey, err := cfg.Options.GetSharedKey() if err != nil { return nil, err } @@ -63,7 +62,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a CAFile: cfg.Options.CAFile, RequestTimeout: cfg.Options.GRPCClientTimeout, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, - WithInsecure: cfg.Options.GRPCInsecure, + WithInsecure: cfg.Options.GetGRPCInsecure(), InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, SignedJWTKey: sharedKey, diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 82a220124..a702b62f5 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -307,7 +307,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( Domains: []string{domain}, } - if options.Addr == options.GRPCAddr { + if options.Addr == options.GetGRPCAddr() { // if this is a gRPC service domain and we're supposed to handle that, add those routes if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || (config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { @@ -337,7 +337,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( // if we're the proxy or authenticate service, add our global headers if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { - vh.ResponseHeadersToAdd = toEnvoyHeaders(options.SetResponseHeaders) + vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders()) } if len(vh.Routes) > 0 { @@ -531,10 +531,10 @@ func (b *Builder) buildGRPCListener(cfg *config.Config) (*envoy_config_listener_ return nil, err } - if cfg.Options.GRPCInsecure { + if cfg.Options.GetGRPCInsecure() { return &envoy_config_listener_v3.Listener{ Name: "grpc-ingress", - Address: buildAddress(cfg.Options.GRPCAddr, 80), + Address: buildAddress(cfg.Options.GetGRPCAddr(), 80), FilterChains: []*envoy_config_listener_v3.FilterChain{{ Filters: []*envoy_config_listener_v3.Filter{ filter, @@ -572,7 +572,7 @@ func (b *Builder) buildGRPCListener(cfg *config.Config) (*envoy_config_listener_ tlsInspectorCfg := marshalAny(new(emptypb.Empty)) li := &envoy_config_listener_v3.Listener{ Name: "grpc-ingress", - Address: buildAddress(cfg.Options.GRPCAddr, 443), + Address: buildAddress(cfg.Options.GetGRPCAddr(), 443), ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{ Name: "envoy.filters.listener.tls_inspector", ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{ @@ -713,14 +713,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err lookup[h] = struct{}{} } } - if config.IsAuthorize(options.Services) && addr == options.GRPCAddr { + if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() { for _, u := range authorizeURLs { for _, h := range urlutil.GetDomainsForURL(*u) { lookup[h] = struct{}{} } } } - if config.IsDataBroker(options.Services) && addr == options.GRPCAddr { + if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() { for _, u := range dataBrokerURLs { for _, h := range urlutil.GetDomainsForURL(*u) { lookup[h] = struct{}{} diff --git a/config/options.go b/config/options.go index f191bd70e..38a9279f2 100644 --- a/config/options.go +++ b/config/options.go @@ -50,6 +50,9 @@ const DefaultAlternativeAddr = ":5443" // EnvoyAdminURL indicates where the envoy control plane is listening var EnvoyAdminURL = &url.URL{Host: "127.0.0.1:9901", Scheme: "http"} +// The randomSharedKey is used if no shared key is supplied in all-in-one mode. +var randomSharedKey = cryptutil.NewBase64Key() + // Options are the global environmental flags used to set up pomerium's services. // Use NewXXXOptions() methods for a safely initialized data structure. type Options struct { @@ -513,26 +516,6 @@ func (o *Options) Validate() error { return fmt.Errorf("config: %s is an invalid service type", o.Services) } - if IsAll(o.Services) { - // mutual auth between services on the same host can be generated at runtime - if o.SharedKey == "" && o.DataBrokerStorageType == StorageInMemoryName { - o.SharedKey = cryptutil.NewBase64Key() - } - // in all in one mode we are running just over the local socket - o.GRPCInsecure = true - // to avoid port collision when running on localhost - if o.GRPCAddr == defaultOptions.GRPCAddr { - o.GRPCAddr = DefaultAlternativeAddr - } - // and we can set the corresponding client - if o.AuthorizeURLString == "" && len(o.AuthorizeURLStrings) == 0 { - o.AuthorizeURLString = "http://127.0.0.1" + DefaultAlternativeAddr - } - if o.DataBrokerURLString == "" && len(o.DataBrokerURLStrings) == 0 { - o.DataBrokerURLString = "http://127.0.0.1" + DefaultAlternativeAddr - } - } - switch o.DataBrokerStorageType { case StorageInMemoryName: case StorageRedisName: @@ -543,12 +526,9 @@ func (o *Options) Validate() error { return errors.New("config: unknown databroker storage backend type") } - if o.SharedKey == "" { - return errors.New("config: shared-key cannot be empty") - } - - if o.SharedKey != strings.TrimSpace(o.SharedKey) { - return errors.New("config: shared-key contains whitespace") + _, err := o.GetSharedKey() + if err != nil { + return fmt.Errorf("config: invalid shared-key: %w", err) } if o.AuthenticateURLString != "" { @@ -597,10 +577,6 @@ func (o *Options) Validate() error { return fmt.Errorf("config: failed to parse headers: %w", err) } - if _, disable := o.SetResponseHeaders[DisableHeaderKey]; disable { - o.SetResponseHeaders = make(map[string]string) - } - hasCert := false if o.Cert != "" || o.Key != "" { @@ -667,13 +643,6 @@ func (o *Options) Validate() error { } } - // if we are using google provider, default to using ServiceAccount for - // GoogleCloudServerlessAuthenticationServiceAccount - if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" { - o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount - log.Info(ctx).Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account") - } - // strip quotes from redirect address (#811) o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`) @@ -690,10 +659,6 @@ func (o *Options) Validate() error { default: } - if o.QPS < 1.0 { - o.QPS = 1.0 - } - if err := ValidateDNSLookupFamily(o.DNSLookupFamily); err != nil { return fmt.Errorf("config: %w", err) } @@ -744,11 +709,25 @@ func (o *Options) GetAuthenticateURL() (*url.URL, error) { // GetAuthorizeURLs returns the AuthorizeURLs in the options or 127.0.0.1:5443. func (o *Options) GetAuthorizeURLs() ([]*url.URL, error) { + if IsAll(o.Services) && o.AuthorizeURLString == "" && len(o.AuthorizeURLStrings) == 0 { + u, err := urlutil.ParseAndValidateURL("http://127.0.0.1" + DefaultAlternativeAddr) + if err != nil { + return nil, err + } + return []*url.URL{u}, nil + } return o.getURLs(append([]string{o.AuthorizeURLString}, o.AuthorizeURLStrings...)...) } // GetDataBrokerURLs returns the DataBrokerURLs in the options or 127.0.0.1:5443. func (o *Options) GetDataBrokerURLs() ([]*url.URL, error) { + if IsAll(o.Services) && o.DataBrokerURLString == "" && len(o.DataBrokerURLStrings) == 0 { + u, err := urlutil.ParseAndValidateURL("http://127.0.0.1" + DefaultAlternativeAddr) + if err != nil { + return nil, err + } + return []*url.URL{u}, nil + } return o.getURLs(append([]string{o.DataBrokerURLString}, o.DataBrokerURLStrings...)...) } @@ -782,6 +761,23 @@ func (o *Options) GetForwardAuthURL() (*url.URL, error) { return urlutil.ParseAndValidateURL(rawurl) } +// GetGRPCAddr gets the gRPC address. +func (o *Options) GetGRPCAddr() string { + // to avoid port collision when running on localhost + if IsAll(o.Services) && o.GRPCAddr == defaultOptions.GRPCAddr { + return DefaultAlternativeAddr + } + return o.GRPCAddr +} + +// GetGRPCInsecure gets whether or not gRPC is insecure. +func (o *Options) GetGRPCInsecure() bool { + if IsAll(o.Services) { + return true + } + return o.GRPCInsecure +} + // GetSignOutRedirectURL gets the SignOutRedirectURL. func (o *Options) GetSignOutRedirectURL() (*url.URL, error) { rawurl := o.SignOutRedirectURLString @@ -904,6 +900,46 @@ func (o *Options) GetCertificates() ([]tls.Certificate, error) { return certs, nil } +// GetSharedKey gets the decoded shared key. +func (o *Options) GetSharedKey() ([]byte, error) { + sharedKey := o.SharedKey + // mutual auth between services on the same host can be generated at runtime + if IsAll(o.Services) && o.SharedKey == "" && o.DataBrokerStorageType == StorageInMemoryName { + sharedKey = randomSharedKey + } + if sharedKey == "" { + return nil, errors.New("empty shared-key") + } + if strings.TrimSpace(sharedKey) != sharedKey { + return nil, errors.New("shared-key contains whitespace") + } + return base64.StdEncoding.DecodeString(sharedKey) +} + +// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount. +func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string { + if o.GoogleCloudServerlessAuthenticationServiceAccount == "" && o.Provider == "google" { + return o.ServiceAccount + } + return o.GoogleCloudServerlessAuthenticationServiceAccount +} + +// GetSetResponseHeaders gets the SetResponseHeaders. +func (o *Options) GetSetResponseHeaders() map[string]string { + if _, ok := o.SetResponseHeaders[DisableHeaderKey]; ok { + return map[string]string{} + } + return o.SetResponseHeaders +} + +// GetQPS gets the QPS. +func (o *Options) GetQPS() float64 { + if o.QPS < 1 { + return 1 + } + return o.QPS +} + // Checksum returns the checksum of the current options struct func (o *Options) Checksum() uint64 { return hashutil.MustHash(o) diff --git a/config/options_test.go b/config/options_test.go index 9d3658a44..059fbbdfd 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -338,7 +338,7 @@ func TestOptionsFromViper(t *testing.T) { InsecureServer: true, GRPCServerMaxConnectionAge: 5 * time.Minute, GRPCServerMaxConnectionAgeGrace: 5 * time.Minute, - SetResponseHeaders: map[string]string{}, + SetResponseHeaders: map[string]string{"disable": "true"}, RefreshDirectoryTimeout: 1 * time.Minute, RefreshDirectoryInterval: 10 * time.Minute, QPS: 1.0, diff --git a/databroker/cache.go b/databroker/cache.go index 80852f3b8..71262943b 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -5,7 +5,6 @@ package databroker import ( "context" - "encoding/base64" "fmt" "net" "sync" @@ -49,7 +48,7 @@ func New(cfg *config.Config) (*DataBroker, error) { return nil, err } - sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + sharedKey, _ := cfg.Options.GetSharedKey() ui, si := grpcutil.AttachMetadataInterceptors( metadata.Pairs(grpcutil.MetadataKeyPomeriumVersion, version.FullVersion()), @@ -155,7 +154,7 @@ func (c *DataBroker) update(cfg *config.Config) error { ServiceAccount: cfg.Options.ServiceAccount, Provider: cfg.Options.Provider, ProviderURL: cfg.Options.ProviderURL, - QPS: cfg.Options.QPS, + QPS: cfg.Options.GetQPS(), ClientID: cfg.Options.ClientID, ClientSecret: cfg.Options.ClientSecret, }) @@ -185,7 +184,11 @@ func (c *DataBroker) update(cfg *config.Config) error { // validate checks that proper configuration settings are set to create // a databroker instance func validate(o *config.Options) error { - if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { + sharedKey, err := o.GetSharedKey() + if err != nil { + return fmt.Errorf("invalid 'SHARED_SECRET': %w", err) + } + if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil { return fmt.Errorf("invalid 'SHARED_SECRET': %w", err) } return nil diff --git a/databroker/databroker.go b/databroker/databroker.go index caf4a98e7..263250d6a 100644 --- a/databroker/databroker.go +++ b/databroker/databroker.go @@ -3,7 +3,6 @@ package databroker import ( "context" - "encoding/base64" "sync/atomic" "github.com/pomerium/pomerium/config" @@ -35,7 +34,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { cert, _ := cfg.Options.GetDataBrokerCertificate() return []databroker.ServerOption{ - databroker.WithSharedKey(cfg.Options.SharedKey), + databroker.WithGetSharedKey(cfg.Options.GetSharedKey), databroker.WithStorageType(cfg.Options.DataBrokerStorageType), databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString), databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile), @@ -45,7 +44,7 @@ func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerO } func (srv *dataBrokerServer) setKey(cfg *config.Config) { - bs, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + bs, _ := cfg.Options.GetSharedKey() if bs == nil { bs = make([]byte, 0) } diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go index 434f013cb..3695bade2 100644 --- a/internal/autocert/manager.go +++ b/internal/autocert/manager.go @@ -23,6 +23,9 @@ import ( var ( errObtainCertFailed = errors.New("obtain cert failed") errRenewCertFailed = errors.New("renew cert failed") + + // RenewCert is not thread-safe + renewCertLock sync.Mutex ) // Manager manages TLS certificates. @@ -188,7 +191,9 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { expired := time.Now().After(cert.Leaf.NotAfter) log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate") + renewCertLock.Lock() err := cm.RenewCert(context.Background(), domain, false) + renewCertLock.Unlock() if err != nil { if expired { return certmagic.Certificate{}, errRenewCertFailed diff --git a/internal/databroker/config.go b/internal/databroker/config.go index f4018e664..71fb98723 100644 --- a/internal/databroker/config.go +++ b/internal/databroker/config.go @@ -3,7 +3,6 @@ package databroker import ( "context" "crypto/tls" - "encoding/base64" "time" "github.com/pomerium/pomerium/internal/log" @@ -61,15 +60,15 @@ func WithGetAllPageSize(pageSize int) ServerOption { } } -// WithSharedKey sets the secret in the config. -func WithSharedKey(sharedKey string) ServerOption { +// WithGetSharedKey sets the secret in the config. +func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption { return func(cfg *serverConfig) { - key, err := base64.StdEncoding.DecodeString(sharedKey) - if err != nil || len(key) != cryptutil.DefaultKeySize { + sharedKey, err := getSharedKey() + if err != nil { log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) return } - cfg.secret = key + cfg.secret = sharedKey } } diff --git a/internal/databroker/config_source.go b/internal/databroker/config_source.go index 3160c00d8..17437acba 100644 --- a/internal/databroker/config_source.go +++ b/internal/databroker/config_source.go @@ -2,7 +2,6 @@ package databroker import ( "context" - "encoding/base64" "sync" "github.com/pomerium/pomerium/config" @@ -164,7 +163,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { return } - sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + sharedKey, _ := cfg.Options.GetSharedKey() connectionOptions := &grpc.Options{ Addrs: urls, OverrideCertificateName: cfg.Options.OverrideCertificateName, @@ -172,7 +171,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) { CAFile: cfg.Options.CAFile, RequestTimeout: cfg.Options.GRPCClientTimeout, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, - WithInsecure: cfg.Options.GRPCInsecure, + WithInsecure: cfg.Options.GetGRPCInsecure(), ServiceName: cfg.Options.Services, SignedJWTKey: sharedKey, } diff --git a/internal/httputil/reproxy/reproxy.go b/internal/httputil/reproxy/reproxy.go index 0178b6ade..4fc31ac8e 100644 --- a/internal/httputil/reproxy/reproxy.go +++ b/internal/httputil/reproxy/reproxy.go @@ -125,7 +125,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) { h.mu.Lock() defer h.mu.Unlock() - h.key, _ = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + h.key, _ = cfg.Options.GetSharedKey() h.options = cfg.Options h.policies = make(map[uint64]*config.Policy) for i, p := range cfg.Options.Policies { diff --git a/internal/registry/reporter.go b/internal/registry/reporter.go index b78210efb..4f09895a4 100644 --- a/internal/registry/reporter.go +++ b/internal/registry/reporter.go @@ -33,7 +33,7 @@ func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) { log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled") } - sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + sharedKey, err := cfg.Options.GetSharedKey() if err != nil { log.Error(ctx).Err(err).Msg("decoding shared key") return @@ -52,7 +52,7 @@ func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) { CAFile: cfg.Options.CAFile, RequestTimeout: cfg.Options.GRPCClientTimeout, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, - WithInsecure: cfg.Options.GRPCInsecure, + WithInsecure: cfg.Options.GetGRPCInsecure(), InstallationID: cfg.Options.InstallationID, ServiceName: cfg.Options.Services, SignedJWTKey: sharedKey, diff --git a/proxy/proxy.go b/proxy/proxy.go index bf6081334..8c3879bff 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -31,7 +31,12 @@ const ( // ValidateOptions checks that proper configuration settings are set to create // a proper Proxy instance func ValidateOptions(o *config.Options) error { - if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { + sharedKey, err := o.GetSharedKey() + if err != nil { + return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err) + } + + if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil { return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err) } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 07c0874f4..de4e3596c 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -48,6 +48,7 @@ func TestOptions_Validate(t *testing.T) { shortCookieLength := testOptions(t) shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" badSharedKey := testOptions(t) + badSharedKey.Services = "proxy" badSharedKey.SharedKey = "" sharedKeyBadBas64 := testOptions(t) sharedKeyBadBas64.SharedKey = "%(*@389" diff --git a/proxy/state.go b/proxy/state.go index acf6580dd..54cc082ea 100644 --- a/proxy/state.go +++ b/proxy/state.go @@ -44,12 +44,12 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) { } state := new(proxyState) - state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) + state.sharedKey, err = cfg.Options.GetSharedKey() if err != nil { return nil, err } - state.sharedCipher, err = cryptutil.NewAEADCipherFromBase64(cfg.Options.SharedKey) + state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey) if err != nil { return nil, err }