config: remove validate side effects (#2109)

* config: default shared key

* handle additional errors

* update grpc addr and grpc insecure

* update google cloud service authentication service account

* fix set response headers

* fix qps

* fix test
This commit is contained in:
Caleb Doxsey 2021-04-22 15:10:50 -06:00 committed by GitHub
parent 2806b67bee
commit b1d62bb541
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 138 additions and 81 deletions

View file

@ -20,7 +20,11 @@ import (
// ValidateOptions checks that configuration are complete and valid. // ValidateOptions checks that configuration are complete and valid.
// Returns on first error found. // Returns on first error found.
func ValidateOptions(o *config.Options) error { func ValidateOptions(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { sharedKey, err := o.GetSharedKey()
if err != nil {
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err)
}
if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil {
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err) return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err)
} }
if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil { if _, err := cryptutil.NewAEADCipherFromBase64(o.CookieSecret); err != nil {

View file

@ -39,6 +39,7 @@ func TestOptions_Validate(t *testing.T) {
shortCookieLength := newTestOptions(t) shortCookieLength := newTestOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
badSharedKey := newTestOptions(t) badSharedKey := newTestOptions(t)
badSharedKey.Services = "authenticate"
badSharedKey.SharedKey = "" badSharedKey.SharedKey = ""
badAuthenticateURL := newTestOptions(t) badAuthenticateURL := newTestOptions(t)
badAuthenticateURL.AuthenticateURLString = "BAD_URL" badAuthenticateURL.AuthenticateURLString = "BAD_URL"

View file

@ -78,7 +78,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
state.redirectURL.Path = cfg.Options.AuthenticateCallbackPath state.redirectURL.Path = cfg.Options.AuthenticateCallbackPath
// shared cipher to encrypt data before passing data between services // shared cipher to encrypt data before passing data between services
state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -140,7 +140,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
state.jwk.Keys = append(state.jwk.Keys, *jwk) state.jwk.Keys = append(state.jwk.Keys, *jwk)
} }
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, err := cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -157,7 +157,7 @@ func newAuthenticateStateFromConfig(cfg *config.Config) (*authenticateState, err
CAFile: cfg.Options.CAFile, CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout, RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -61,7 +61,11 @@ func (a *Authorize) WaitForInitialSync(ctx context.Context) error {
} }
func validateOptions(o *config.Options) error { func validateOptions(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { sharedKey, err := o.GetSharedKey()
if err != nil {
return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err)
}
if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil {
return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err) return fmt.Errorf("authorize: bad 'SHARED_SECRET': %w", err)
} }
if _, err := o.GetAuthenticateURL(); err != nil { if _, err := o.GetAuthenticateURL(); err != nil {

View file

@ -48,7 +48,9 @@ func New(options *config.Options, store *Store) (*Evaluator, error) {
} }
store.UpdateIssuer(authenticateURL.Host) store.UpdateIssuer(authenticateURL.Host)
store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(options.GoogleCloudServerlessAuthenticationServiceAccount) store.UpdateGoogleCloudServerlessAuthenticationServiceAccount(
options.GetGoogleCloudServerlessAuthenticationServiceAccount(),
)
store.UpdateJWTClaimHeaders(options.JWTClaimsHeaders) store.UpdateJWTClaimHeaders(options.JWTClaimsHeaders)
store.UpdateRoutePolicies(options.GetAllPolicies()) store.UpdateRoutePolicies(options.GetAllPolicies())
store.UpdateSigningKey(jwk) store.UpdateSigningKey(jwk)

View file

@ -1,7 +1,6 @@
package authorize package authorize
import ( import (
"encoding/base64"
"fmt" "fmt"
"sync/atomic" "sync/atomic"
@ -36,7 +35,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
} }
state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -46,7 +45,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
return nil, err return nil, err
} }
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, err := cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -63,7 +62,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *evaluator.Store) (*a
CAFile: cfg.Options.CAFile, CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout, RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -307,7 +307,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
Domains: []string{domain}, Domains: []string{domain},
} }
if options.Addr == options.GRPCAddr { if options.Addr == options.GetGRPCAddr() {
// if this is a gRPC service domain and we're supposed to handle that, add those routes // if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { (config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
@ -337,7 +337,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
// if we're the proxy or authenticate service, add our global headers // if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
vh.ResponseHeadersToAdd = toEnvoyHeaders(options.SetResponseHeaders) vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders())
} }
if len(vh.Routes) > 0 { if len(vh.Routes) > 0 {
@ -531,10 +531,10 @@ func (b *Builder) buildGRPCListener(cfg *config.Config) (*envoy_config_listener_
return nil, err return nil, err
} }
if cfg.Options.GRPCInsecure { if cfg.Options.GetGRPCInsecure() {
return &envoy_config_listener_v3.Listener{ return &envoy_config_listener_v3.Listener{
Name: "grpc-ingress", Name: "grpc-ingress",
Address: buildAddress(cfg.Options.GRPCAddr, 80), Address: buildAddress(cfg.Options.GetGRPCAddr(), 80),
FilterChains: []*envoy_config_listener_v3.FilterChain{{ FilterChains: []*envoy_config_listener_v3.FilterChain{{
Filters: []*envoy_config_listener_v3.Filter{ Filters: []*envoy_config_listener_v3.Filter{
filter, filter,
@ -572,7 +572,7 @@ func (b *Builder) buildGRPCListener(cfg *config.Config) (*envoy_config_listener_
tlsInspectorCfg := marshalAny(new(emptypb.Empty)) tlsInspectorCfg := marshalAny(new(emptypb.Empty))
li := &envoy_config_listener_v3.Listener{ li := &envoy_config_listener_v3.Listener{
Name: "grpc-ingress", Name: "grpc-ingress",
Address: buildAddress(cfg.Options.GRPCAddr, 443), Address: buildAddress(cfg.Options.GetGRPCAddr(), 443),
ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{ ListenerFilters: []*envoy_config_listener_v3.ListenerFilter{{
Name: "envoy.filters.listener.tls_inspector", Name: "envoy.filters.listener.tls_inspector",
ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{ ConfigType: &envoy_config_listener_v3.ListenerFilter_TypedConfig{
@ -713,14 +713,14 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
if config.IsAuthorize(options.Services) && addr == options.GRPCAddr { if config.IsAuthorize(options.Services) && addr == options.GetGRPCAddr() {
for _, u := range authorizeURLs { for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
} }
if config.IsDataBroker(options.Services) && addr == options.GRPCAddr { if config.IsDataBroker(options.Services) && addr == options.GetGRPCAddr() {
for _, u := range dataBrokerURLs { for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) { for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{} lookup[h] = struct{}{}

View file

@ -50,6 +50,9 @@ const DefaultAlternativeAddr = ":5443"
// EnvoyAdminURL indicates where the envoy control plane is listening // EnvoyAdminURL indicates where the envoy control plane is listening
var EnvoyAdminURL = &url.URL{Host: "127.0.0.1:9901", Scheme: "http"} var EnvoyAdminURL = &url.URL{Host: "127.0.0.1:9901", Scheme: "http"}
// The randomSharedKey is used if no shared key is supplied in all-in-one mode.
var randomSharedKey = cryptutil.NewBase64Key()
// Options are the global environmental flags used to set up pomerium's services. // Options are the global environmental flags used to set up pomerium's services.
// Use NewXXXOptions() methods for a safely initialized data structure. // Use NewXXXOptions() methods for a safely initialized data structure.
type Options struct { type Options struct {
@ -513,26 +516,6 @@ func (o *Options) Validate() error {
return fmt.Errorf("config: %s is an invalid service type", o.Services) return fmt.Errorf("config: %s is an invalid service type", o.Services)
} }
if IsAll(o.Services) {
// mutual auth between services on the same host can be generated at runtime
if o.SharedKey == "" && o.DataBrokerStorageType == StorageInMemoryName {
o.SharedKey = cryptutil.NewBase64Key()
}
// in all in one mode we are running just over the local socket
o.GRPCInsecure = true
// to avoid port collision when running on localhost
if o.GRPCAddr == defaultOptions.GRPCAddr {
o.GRPCAddr = DefaultAlternativeAddr
}
// and we can set the corresponding client
if o.AuthorizeURLString == "" && len(o.AuthorizeURLStrings) == 0 {
o.AuthorizeURLString = "http://127.0.0.1" + DefaultAlternativeAddr
}
if o.DataBrokerURLString == "" && len(o.DataBrokerURLStrings) == 0 {
o.DataBrokerURLString = "http://127.0.0.1" + DefaultAlternativeAddr
}
}
switch o.DataBrokerStorageType { switch o.DataBrokerStorageType {
case StorageInMemoryName: case StorageInMemoryName:
case StorageRedisName: case StorageRedisName:
@ -543,12 +526,9 @@ func (o *Options) Validate() error {
return errors.New("config: unknown databroker storage backend type") return errors.New("config: unknown databroker storage backend type")
} }
if o.SharedKey == "" { _, err := o.GetSharedKey()
return errors.New("config: shared-key cannot be empty") if err != nil {
} return fmt.Errorf("config: invalid shared-key: %w", err)
if o.SharedKey != strings.TrimSpace(o.SharedKey) {
return errors.New("config: shared-key contains whitespace")
} }
if o.AuthenticateURLString != "" { if o.AuthenticateURLString != "" {
@ -597,10 +577,6 @@ func (o *Options) Validate() error {
return fmt.Errorf("config: failed to parse headers: %w", err) return fmt.Errorf("config: failed to parse headers: %w", err)
} }
if _, disable := o.SetResponseHeaders[DisableHeaderKey]; disable {
o.SetResponseHeaders = make(map[string]string)
}
hasCert := false hasCert := false
if o.Cert != "" || o.Key != "" { if o.Cert != "" || o.Key != "" {
@ -667,13 +643,6 @@ func (o *Options) Validate() error {
} }
} }
// if we are using google provider, default to using ServiceAccount for
// GoogleCloudServerlessAuthenticationServiceAccount
if o.Provider == "google" && o.GoogleCloudServerlessAuthenticationServiceAccount == "" {
o.GoogleCloudServerlessAuthenticationServiceAccount = o.ServiceAccount
log.Info(ctx).Msg("defaulting to idp_service_account for google_cloud_serverless_authentication_service_account")
}
// strip quotes from redirect address (#811) // strip quotes from redirect address (#811)
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`) o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
@ -690,10 +659,6 @@ func (o *Options) Validate() error {
default: default:
} }
if o.QPS < 1.0 {
o.QPS = 1.0
}
if err := ValidateDNSLookupFamily(o.DNSLookupFamily); err != nil { if err := ValidateDNSLookupFamily(o.DNSLookupFamily); err != nil {
return fmt.Errorf("config: %w", err) return fmt.Errorf("config: %w", err)
} }
@ -744,11 +709,25 @@ func (o *Options) GetAuthenticateURL() (*url.URL, error) {
// GetAuthorizeURLs returns the AuthorizeURLs in the options or 127.0.0.1:5443. // GetAuthorizeURLs returns the AuthorizeURLs in the options or 127.0.0.1:5443.
func (o *Options) GetAuthorizeURLs() ([]*url.URL, error) { func (o *Options) GetAuthorizeURLs() ([]*url.URL, error) {
if IsAll(o.Services) && o.AuthorizeURLString == "" && len(o.AuthorizeURLStrings) == 0 {
u, err := urlutil.ParseAndValidateURL("http://127.0.0.1" + DefaultAlternativeAddr)
if err != nil {
return nil, err
}
return []*url.URL{u}, nil
}
return o.getURLs(append([]string{o.AuthorizeURLString}, o.AuthorizeURLStrings...)...) return o.getURLs(append([]string{o.AuthorizeURLString}, o.AuthorizeURLStrings...)...)
} }
// GetDataBrokerURLs returns the DataBrokerURLs in the options or 127.0.0.1:5443. // GetDataBrokerURLs returns the DataBrokerURLs in the options or 127.0.0.1:5443.
func (o *Options) GetDataBrokerURLs() ([]*url.URL, error) { func (o *Options) GetDataBrokerURLs() ([]*url.URL, error) {
if IsAll(o.Services) && o.DataBrokerURLString == "" && len(o.DataBrokerURLStrings) == 0 {
u, err := urlutil.ParseAndValidateURL("http://127.0.0.1" + DefaultAlternativeAddr)
if err != nil {
return nil, err
}
return []*url.URL{u}, nil
}
return o.getURLs(append([]string{o.DataBrokerURLString}, o.DataBrokerURLStrings...)...) return o.getURLs(append([]string{o.DataBrokerURLString}, o.DataBrokerURLStrings...)...)
} }
@ -782,6 +761,23 @@ func (o *Options) GetForwardAuthURL() (*url.URL, error) {
return urlutil.ParseAndValidateURL(rawurl) return urlutil.ParseAndValidateURL(rawurl)
} }
// GetGRPCAddr gets the gRPC address.
func (o *Options) GetGRPCAddr() string {
// to avoid port collision when running on localhost
if IsAll(o.Services) && o.GRPCAddr == defaultOptions.GRPCAddr {
return DefaultAlternativeAddr
}
return o.GRPCAddr
}
// GetGRPCInsecure gets whether or not gRPC is insecure.
func (o *Options) GetGRPCInsecure() bool {
if IsAll(o.Services) {
return true
}
return o.GRPCInsecure
}
// GetSignOutRedirectURL gets the SignOutRedirectURL. // GetSignOutRedirectURL gets the SignOutRedirectURL.
func (o *Options) GetSignOutRedirectURL() (*url.URL, error) { func (o *Options) GetSignOutRedirectURL() (*url.URL, error) {
rawurl := o.SignOutRedirectURLString rawurl := o.SignOutRedirectURLString
@ -904,6 +900,46 @@ func (o *Options) GetCertificates() ([]tls.Certificate, error) {
return certs, nil return certs, nil
} }
// GetSharedKey gets the decoded shared key.
func (o *Options) GetSharedKey() ([]byte, error) {
sharedKey := o.SharedKey
// mutual auth between services on the same host can be generated at runtime
if IsAll(o.Services) && o.SharedKey == "" && o.DataBrokerStorageType == StorageInMemoryName {
sharedKey = randomSharedKey
}
if sharedKey == "" {
return nil, errors.New("empty shared-key")
}
if strings.TrimSpace(sharedKey) != sharedKey {
return nil, errors.New("shared-key contains whitespace")
}
return base64.StdEncoding.DecodeString(sharedKey)
}
// GetGoogleCloudServerlessAuthenticationServiceAccount gets the GoogleCloudServerlessAuthenticationServiceAccount.
func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string {
if o.GoogleCloudServerlessAuthenticationServiceAccount == "" && o.Provider == "google" {
return o.ServiceAccount
}
return o.GoogleCloudServerlessAuthenticationServiceAccount
}
// GetSetResponseHeaders gets the SetResponseHeaders.
func (o *Options) GetSetResponseHeaders() map[string]string {
if _, ok := o.SetResponseHeaders[DisableHeaderKey]; ok {
return map[string]string{}
}
return o.SetResponseHeaders
}
// GetQPS gets the QPS.
func (o *Options) GetQPS() float64 {
if o.QPS < 1 {
return 1
}
return o.QPS
}
// Checksum returns the checksum of the current options struct // Checksum returns the checksum of the current options struct
func (o *Options) Checksum() uint64 { func (o *Options) Checksum() uint64 {
return hashutil.MustHash(o) return hashutil.MustHash(o)

View file

@ -338,7 +338,7 @@ func TestOptionsFromViper(t *testing.T) {
InsecureServer: true, InsecureServer: true,
GRPCServerMaxConnectionAge: 5 * time.Minute, GRPCServerMaxConnectionAge: 5 * time.Minute,
GRPCServerMaxConnectionAgeGrace: 5 * time.Minute, GRPCServerMaxConnectionAgeGrace: 5 * time.Minute,
SetResponseHeaders: map[string]string{}, SetResponseHeaders: map[string]string{"disable": "true"},
RefreshDirectoryTimeout: 1 * time.Minute, RefreshDirectoryTimeout: 1 * time.Minute,
RefreshDirectoryInterval: 10 * time.Minute, RefreshDirectoryInterval: 10 * time.Minute,
QPS: 1.0, QPS: 1.0,

View file

@ -5,7 +5,6 @@ package databroker
import ( import (
"context" "context"
"encoding/base64"
"fmt" "fmt"
"net" "net"
"sync" "sync"
@ -49,7 +48,7 @@ func New(cfg *config.Config) (*DataBroker, error) {
return nil, err return nil, err
} }
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, _ := cfg.Options.GetSharedKey()
ui, si := grpcutil.AttachMetadataInterceptors( ui, si := grpcutil.AttachMetadataInterceptors(
metadata.Pairs(grpcutil.MetadataKeyPomeriumVersion, version.FullVersion()), metadata.Pairs(grpcutil.MetadataKeyPomeriumVersion, version.FullVersion()),
@ -155,7 +154,7 @@ func (c *DataBroker) update(cfg *config.Config) error {
ServiceAccount: cfg.Options.ServiceAccount, ServiceAccount: cfg.Options.ServiceAccount,
Provider: cfg.Options.Provider, Provider: cfg.Options.Provider,
ProviderURL: cfg.Options.ProviderURL, ProviderURL: cfg.Options.ProviderURL,
QPS: cfg.Options.QPS, QPS: cfg.Options.GetQPS(),
ClientID: cfg.Options.ClientID, ClientID: cfg.Options.ClientID,
ClientSecret: cfg.Options.ClientSecret, ClientSecret: cfg.Options.ClientSecret,
}) })
@ -185,7 +184,11 @@ func (c *DataBroker) update(cfg *config.Config) error {
// validate checks that proper configuration settings are set to create // validate checks that proper configuration settings are set to create
// a databroker instance // a databroker instance
func validate(o *config.Options) error { func validate(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { sharedKey, err := o.GetSharedKey()
if err != nil {
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
}
if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil {
return fmt.Errorf("invalid 'SHARED_SECRET': %w", err) return fmt.Errorf("invalid 'SHARED_SECRET': %w", err)
} }
return nil return nil

View file

@ -3,7 +3,6 @@ package databroker
import ( import (
"context" "context"
"encoding/base64"
"sync/atomic" "sync/atomic"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -35,7 +34,7 @@ func (srv *dataBrokerServer) OnConfigChange(ctx context.Context, cfg *config.Con
func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption { func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerOption {
cert, _ := cfg.Options.GetDataBrokerCertificate() cert, _ := cfg.Options.GetDataBrokerCertificate()
return []databroker.ServerOption{ return []databroker.ServerOption{
databroker.WithSharedKey(cfg.Options.SharedKey), databroker.WithGetSharedKey(cfg.Options.GetSharedKey),
databroker.WithStorageType(cfg.Options.DataBrokerStorageType), databroker.WithStorageType(cfg.Options.DataBrokerStorageType),
databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString), databroker.WithStorageConnectionString(cfg.Options.DataBrokerStorageConnectionString),
databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile), databroker.WithStorageCAFile(cfg.Options.DataBrokerStorageCAFile),
@ -45,7 +44,7 @@ func (srv *dataBrokerServer) getOptions(cfg *config.Config) []databroker.ServerO
} }
func (srv *dataBrokerServer) setKey(cfg *config.Config) { func (srv *dataBrokerServer) setKey(cfg *config.Config) {
bs, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) bs, _ := cfg.Options.GetSharedKey()
if bs == nil { if bs == nil {
bs = make([]byte, 0) bs = make([]byte, 0)
} }

View file

@ -23,6 +23,9 @@ import (
var ( var (
errObtainCertFailed = errors.New("obtain cert failed") errObtainCertFailed = errors.New("obtain cert failed")
errRenewCertFailed = errors.New("renew cert failed") errRenewCertFailed = errors.New("renew cert failed")
// RenewCert is not thread-safe
renewCertLock sync.Mutex
) )
// Manager manages TLS certificates. // Manager manages TLS certificates.
@ -188,7 +191,9 @@ func (mgr *Manager) obtainCert(domain string, cm *certmagic.Config) (certmagic.C
func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) { func (mgr *Manager) renewCert(domain string, cert certmagic.Certificate, cm *certmagic.Config) (certmagic.Certificate, error) {
expired := time.Now().After(cert.Leaf.NotAfter) expired := time.Now().After(cert.Leaf.NotAfter)
log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate") log.Info(context.TODO()).Str("domain", domain).Msg("renewing certificate")
renewCertLock.Lock()
err := cm.RenewCert(context.Background(), domain, false) err := cm.RenewCert(context.Background(), domain, false)
renewCertLock.Unlock()
if err != nil { if err != nil {
if expired { if expired {
return certmagic.Certificate{}, errRenewCertFailed return certmagic.Certificate{}, errRenewCertFailed

View file

@ -3,7 +3,6 @@ package databroker
import ( import (
"context" "context"
"crypto/tls" "crypto/tls"
"encoding/base64"
"time" "time"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
@ -61,15 +60,15 @@ func WithGetAllPageSize(pageSize int) ServerOption {
} }
} }
// WithSharedKey sets the secret in the config. // WithGetSharedKey sets the secret in the config.
func WithSharedKey(sharedKey string) ServerOption { func WithGetSharedKey(getSharedKey func() ([]byte, error)) ServerOption {
return func(cfg *serverConfig) { return func(cfg *serverConfig) {
key, err := base64.StdEncoding.DecodeString(sharedKey) sharedKey, err := getSharedKey()
if err != nil || len(key) != cryptutil.DefaultKeySize { if err != nil {
log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize) log.Error(context.TODO()).Err(err).Msgf("shared key is required and must be %d bytes long", cryptutil.DefaultKeySize)
return return
} }
cfg.secret = key cfg.secret = sharedKey
} }
} }

View file

@ -2,7 +2,6 @@ package databroker
import ( import (
"context" "context"
"encoding/base64"
"sync" "sync"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
@ -164,7 +163,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
return return
} }
sharedKey, _ := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, _ := cfg.Options.GetSharedKey()
connectionOptions := &grpc.Options{ connectionOptions := &grpc.Options{
Addrs: urls, Addrs: urls,
OverrideCertificateName: cfg.Options.OverrideCertificateName, OverrideCertificateName: cfg.Options.OverrideCertificateName,
@ -172,7 +171,7 @@ func (src *ConfigSource) runUpdater(cfg *config.Config) {
CAFile: cfg.Options.CAFile, CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout, RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GetGRPCInsecure(),
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,
} }

View file

@ -125,7 +125,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
h.mu.Lock() h.mu.Lock()
defer h.mu.Unlock() defer h.mu.Unlock()
h.key, _ = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) h.key, _ = cfg.Options.GetSharedKey()
h.options = cfg.Options h.options = cfg.Options
h.policies = make(map[uint64]*config.Policy) h.policies = make(map[uint64]*config.Policy)
for i, p := range cfg.Options.Policies { for i, p := range cfg.Options.Policies {

View file

@ -33,7 +33,7 @@ func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled") log.Warn(ctx).Err(err).Msg("metrics announce to service registry is disabled")
} }
sharedKey, err := base64.StdEncoding.DecodeString(cfg.Options.SharedKey) sharedKey, err := cfg.Options.GetSharedKey()
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("decoding shared key") log.Error(ctx).Err(err).Msg("decoding shared key")
return return
@ -52,7 +52,7 @@ func (r *Reporter) OnConfigChange(ctx context.Context, cfg *config.Config) {
CAFile: cfg.Options.CAFile, CAFile: cfg.Options.CAFile,
RequestTimeout: cfg.Options.GRPCClientTimeout, RequestTimeout: cfg.Options.GRPCClientTimeout,
ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin, ClientDNSRoundRobin: cfg.Options.GRPCClientDNSRoundRobin,
WithInsecure: cfg.Options.GRPCInsecure, WithInsecure: cfg.Options.GetGRPCInsecure(),
InstallationID: cfg.Options.InstallationID, InstallationID: cfg.Options.InstallationID,
ServiceName: cfg.Options.Services, ServiceName: cfg.Options.Services,
SignedJWTKey: sharedKey, SignedJWTKey: sharedKey,

View file

@ -31,7 +31,12 @@ const (
// ValidateOptions checks that proper configuration settings are set to create // ValidateOptions checks that proper configuration settings are set to create
// a proper Proxy instance // a proper Proxy instance
func ValidateOptions(o *config.Options) error { func ValidateOptions(o *config.Options) error {
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { sharedKey, err := o.GetSharedKey()
if err != nil {
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err)
}
if _, err := cryptutil.NewAEADCipher(sharedKey); err != nil {
return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err) return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err)
} }

View file

@ -48,6 +48,7 @@ func TestOptions_Validate(t *testing.T) {
shortCookieLength := testOptions(t) shortCookieLength := testOptions(t)
shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg==" shortCookieLength.CookieSecret = "gN3xnvfsAwfCXxnJorGLKUG4l2wC8sS8nfLMhcStPg=="
badSharedKey := testOptions(t) badSharedKey := testOptions(t)
badSharedKey.Services = "proxy"
badSharedKey.SharedKey = "" badSharedKey.SharedKey = ""
sharedKeyBadBas64 := testOptions(t) sharedKeyBadBas64 := testOptions(t)
sharedKeyBadBas64.SharedKey = "%(*@389" sharedKeyBadBas64.SharedKey = "%(*@389"

View file

@ -44,12 +44,12 @@ func newProxyStateFromConfig(cfg *config.Config) (*proxyState, error) {
} }
state := new(proxyState) state := new(proxyState)
state.sharedKey, err = base64.StdEncoding.DecodeString(cfg.Options.SharedKey) state.sharedKey, err = cfg.Options.GetSharedKey()
if err != nil { if err != nil {
return nil, err return nil, err
} }
state.sharedCipher, err = cryptutil.NewAEADCipherFromBase64(cfg.Options.SharedKey) state.sharedCipher, err = cryptutil.NewAEADCipher(state.sharedKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }