diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 3be0b0c6a..5e798b28a 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -36,7 +36,7 @@ import ( // ValidateOptions checks that configuration are complete and valid. // Returns on first error found. -func ValidateOptions(o config.Options) error { +func ValidateOptions(o *config.Options) error { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err) } @@ -118,7 +118,7 @@ type Authenticate struct { } // New validates and creates a new authenticate service from a set of Options. -func New(opts config.Options) (*Authenticate, error) { +func New(opts *config.Options) (*Authenticate, error) { if err := ValidateOptions(opts); err != nil { return nil, err } @@ -238,15 +238,13 @@ func (a *Authenticate) setAdminUsers(opts *config.Options) { } } -// UpdateOptions implements the OptionsUpdater interface and updates internal +// OnConfigChange implements the OptionsUpdater interface and updates internal // structures based on config.Options -func (a *Authenticate) UpdateOptions(opts config.Options) error { +func (a *Authenticate) OnConfigChange(cfg *config.Config) { if a == nil { - return nil + return } - log.Info().Str("checksum", fmt.Sprintf("%x", opts.Checksum())).Msg("authenticate: updating options") - a.setAdminUsers(&opts) - - return nil + log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options") + a.setAdminUsers(cfg.Options) } diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 56c2bb4c3..0af8de710 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -71,7 +71,7 @@ func TestOptions_Validate(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := ValidateOptions(*tt.o); (err != nil) != tt.wantErr { + if err := ValidateOptions(tt.o); (err != nil) != tt.wantErr { t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -128,7 +128,7 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - _, err := New(*tt.opts) + _, err := New(tt.opts) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -155,8 +155,8 @@ func TestIsAdmin(t *testing.T) { t.Parallel() opts := newTestOptions(t) opts.Administrators = tc.admins - a, err := New(*opts) - assert.NoError(t, a.UpdateOptions(*opts)) + a, err := New(opts) + a.OnConfigChange(&config.Config{Options: opts}) require.NoError(t, err) assert.True(t, a.isAdmin(tc.user) == tc.isAdmin) }) diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index 220d6d09c..b37ef5098 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -511,7 +511,7 @@ func TestWellKnownEndpoint(t *testing.T) { func TestJwksEndpoint(t *testing.T) { o := newTestOptions(t) o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" - auth, err := New(*o) + auth, err := New(o) if err != nil { t.Fatal(err) } diff --git a/authorize/authorize.go b/authorize/authorize.go index 37ae18903..14203c141 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -27,11 +27,11 @@ type atomicOptions struct { value atomic.Value } -func (a *atomicOptions) Load() config.Options { - return a.value.Load().(config.Options) +func (a *atomicOptions) Load() *config.Options { + return a.value.Load().(*config.Options) } -func (a *atomicOptions) Store(options config.Options) { +func (a *atomicOptions) Store(options *config.Options) { a.value.Store(options) } @@ -63,7 +63,7 @@ type Authorize struct { } // New validates and creates a new Authorize service from a set of config options. -func New(opts config.Options) (*Authorize, error) { +func New(opts *config.Options) (*Authorize, error) { if err := validateOptions(opts); err != nil { return nil, fmt.Errorf("authorize: bad options: %w", err) } @@ -98,16 +98,11 @@ func New(opts config.Options) (*Authorize, error) { return nil, err } a.currentEncoder.Store(encoder) - - a.currentOptions.Store(config.Options{}) - err = a.UpdateOptions(opts) - if err != nil { - return nil, err - } + a.currentOptions.Store(new(config.Options)) return &a, nil } -func validateOptions(o config.Options) error { +func validateOptions(o *config.Options) error { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { return fmt.Errorf("bad shared_secret: %w", err) } @@ -128,19 +123,19 @@ func newPolicyEvaluator(opts *config.Options) (*evaluator.Evaluator, error) { return evaluator.New(opts) } -// UpdateOptions implements the OptionsUpdater interface and updates internal +// OnConfigChange implements the OptionsUpdater interface and updates internal // structures based on config.Options -func (a *Authorize) UpdateOptions(opts config.Options) error { +func (a *Authorize) OnConfigChange(cfg *config.Config) { if a == nil { - return nil + return } - log.Info().Str("checksum", fmt.Sprintf("%x", opts.Checksum())).Msg("authorize: updating options") - a.currentOptions.Store(opts) + log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authorize: updating options") + a.currentOptions.Store(cfg.Options) var err error - if a.pe, err = newPolicyEvaluator(&opts); err != nil { - return err + if a.pe, err = newPolicyEvaluator(cfg.Options); err != nil { + log.Error().Err(err).Msg("authorize: failed to update policy with options") + return } - return nil } diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index 244220db6..c77ebef38 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -24,13 +24,13 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - o := config.Options{ + o := &config.Options{ AuthenticateURL: mustParseURL("https://authN.example.com"), DataBrokerURL: mustParseURL("https://cache.example.com"), SharedKey: tt.SharedKey, Policies: tt.Policies} if tt.name == "empty options" { - o = config.Options{} + o = &config.Options{} } _, err := New(o) if (err != nil) != tt.wantErr { diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 7f91a5075..5d46cdd05 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -240,7 +240,7 @@ func Test_handleForwardAuth(t *testing.T) { if tc.forwardAuthURL != "" { fau = mustParseURL(tc.forwardAuthURL) } - a.currentOptions.Store(config.Options{ForwardAuthURL: fau}) + a.currentOptions.Store(&config.Options{ForwardAuthURL: fau}) assert.Equal(t, tc.isForwardAuth, a.handleForwardAuth(tc.checkReq)) if tc.attrCtxHTTPReq != nil { assert.Equal(t, tc.attrCtxHTTPReq, tc.checkReq.Attributes.Request.Http) diff --git a/authorize/session.go b/authorize/session.go index 67d663320..8aa1e2eb5 100644 --- a/authorize/session.go +++ b/authorize/session.go @@ -18,7 +18,7 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -func loadRawSession(req *http.Request, options config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) { +func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) { var loaders []sessions.SessionLoader cookieStore, err := getCookieStore(options, encoder) if err != nil { @@ -51,7 +51,7 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions. return &s, nil } -func getCookieStore(options config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { +func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { cookieOptions := &cookie.Options{ Name: options.CookieName, Domain: options.CookieDomain, @@ -85,7 +85,7 @@ func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (m return hdrs, nil } -func (a *Authorize) getJWTClaimHeaders(options config.Options, signedJWT string) (map[string]string, error) { +func (a *Authorize) getJWTClaimHeaders(options *config.Options, signedJWT string) (map[string]string, error) { if len(signedJWT) == 0 { return make(map[string]string), nil } diff --git a/authorize/session_test.go b/authorize/session_test.go index dff7eb61f..3c133d0c5 100644 --- a/authorize/session_test.go +++ b/authorize/session_test.go @@ -14,7 +14,7 @@ import ( ) func TestLoadSession(t *testing.T) { - opts := *config.NewDefaultOptions() + opts := config.NewDefaultOptions() encoder, err := jws.NewHS256Signer(nil, "example.com") if !assert.NoError(t, err) { return diff --git a/config/autocert.go b/config/autocert.go index 4cfebce7d..dea614ba6 100644 --- a/config/autocert.go +++ b/config/autocert.go @@ -1,16 +1,5 @@ package config -import ( - "context" - "fmt" - "net/http" - "sync" - - "github.com/caddyserver/certmagic" - - "github.com/pomerium/pomerium/internal/log" -) - // AutocertOptions contains the options to control the behavior of autocert. type AutocertOptions struct { // Enable enables fully automated certificate management including issuance @@ -35,94 +24,3 @@ type AutocertOptions struct { // defaults to $XDG_DATA_HOME/pomerium Folder string `mapstructure:"autocert_dir" yaml:"autocert_dir,omitempty"` } - -// AutocertManager manages Let's Encrypt certificates based on configuration options. -var AutocertManager = newAutocertManager() - -type autocertManager struct { - mu sync.RWMutex - certmagic *certmagic.Config - acmeMgr *certmagic.ACMEManager -} - -func newAutocertManager() *autocertManager { - mgr := &autocertManager{} - return mgr -} - -func (mgr *autocertManager) getConfig(options *Options) (*certmagic.Config, error) { - mgr.mu.Lock() - defer mgr.mu.Unlock() - - cm := mgr.certmagic - if cm == nil { - cm = certmagic.NewDefault() - cm.MustStaple = options.AutocertOptions.MustStaple - } - - cm.OnDemand = nil // disable on-demand - cm.Storage = &certmagic.FileStorage{Path: options.AutocertOptions.Folder} - // add existing certs to the cache, and staple OCSP - for _, cert := range options.Certificates { - if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil { - return nil, fmt.Errorf("config: failed caching cert: %w", err) - } - } - acmeMgr := certmagic.NewACMEManager(cm, certmagic.DefaultACME) - acmeMgr.Agreed = true - if options.AutocertOptions.UseStaging { - acmeMgr.CA = certmagic.LetsEncryptStagingCA - } - acmeMgr.DisableTLSALPNChallenge = true - cm.Issuer = acmeMgr - mgr.acmeMgr = acmeMgr - - return cm, nil -} - -func (mgr *autocertManager) update(options *Options) error { - if !options.AutocertOptions.Enable { - return nil - } - - cm, err := mgr.getConfig(options) - if err != nil { - return err - } - - for _, domain := range options.sourceHostnames() { - cert, err := cm.CacheManagedCertificate(domain) - if err != nil { - log.Info().Str("domain", domain).Msg("obtaining certificate") - err = cm.ObtainCert(context.Background(), domain, false) - if err != nil { - return fmt.Errorf("config: failed to obtain client certificate: %w", err) - } - cert, err = cm.CacheManagedCertificate(domain) - } - if err == nil && cert.NeedsRenewal(cm) { - log.Info().Str("domain", domain).Msg("renewing certificate") - err = cm.RenewCert(context.Background(), domain, false) - if err != nil { - return fmt.Errorf("config: failed to renew client certificate: %w", err) - } - cert, err = cm.CacheManagedCertificate(domain) - } - if err == nil { - options.Certificates = append(options.Certificates, cert.Certificate) - } else { - log.Error().Err(err).Msg("config: failed to obtain client certificate") - } - } - return nil -} - -func (mgr *autocertManager) HandleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool { - mgr.mu.RLock() - acmeMgr := mgr.acmeMgr - mgr.mu.RUnlock() - if acmeMgr == nil { - return false - } - return acmeMgr.HandleHTTPChallenge(w, r) -} diff --git a/config/config_source.go b/config/config_source.go new file mode 100644 index 000000000..a01874672 --- /dev/null +++ b/config/config_source.go @@ -0,0 +1,103 @@ +package config + +import ( + "reflect" + "sync" + + "github.com/fsnotify/fsnotify" + "github.com/mitchellh/copystructure" + "github.com/spf13/viper" +) + +// Config holds pomerium configuration options. +type Config struct { + Options *Options +} + +// Clone creates a deep clone of the config. +func (cfg *Config) Clone() *Config { + return copystructure.Must(copystructure.Config{ + Copiers: map[reflect.Type]copystructure.CopierFunc{ + reflect.TypeOf((*viper.Viper)(nil)): func(i interface{}) (interface{}, error) { + return i, nil + }, + }, + }.Copy(cfg)).(*Config) +} + +// A ChangeListener is called when configuration changes. +type ChangeListener = func(*Config) + +// A ChangeDispatcher manages listeners on config changes. +type ChangeDispatcher struct { + sync.Mutex + onConfigChangeListeners []ChangeListener +} + +// Trigger triggers a change. +func (dispatcher *ChangeDispatcher) Trigger(cfg *Config) { + dispatcher.Lock() + defer dispatcher.Unlock() + + for _, li := range dispatcher.onConfigChangeListeners { + li(cfg) + } +} + +// OnConfigChange adds a listener. +func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) { + dispatcher.Lock() + defer dispatcher.Unlock() + dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li) +} + +// A Source gets configuration. +type Source interface { + GetConfig() *Config + OnConfigChange(ChangeListener) +} + +// A FileOrEnvironmentSource retrieves config options from a file or the environment. +type FileOrEnvironmentSource struct { + configFile string + + mu sync.RWMutex + config *Config + + ChangeDispatcher +} + +// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource. +func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) { + options, err := newOptionsFromConfig(configFile) + if err != nil { + return nil, err + } + + src := &FileOrEnvironmentSource{ + configFile: configFile, + config: &Config{Options: options}, + } + options.viper.OnConfigChange(src.onConfigChange) + go options.viper.WatchConfig() + + return src, nil +} + +func (src *FileOrEnvironmentSource) onConfigChange(evt fsnotify.Event) { + src.mu.Lock() + newOptions := handleConfigUpdate(src.configFile, src.config.Options) + cfg := &Config{Options: newOptions} + src.config = cfg + src.mu.Unlock() + + src.Trigger(cfg) +} + +// GetConfig gets the config. +func (src *FileOrEnvironmentSource) GetConfig() *Config { + src.mu.RLock() + defer src.mu.RUnlock() + + return src.config +} diff --git a/config/options.go b/config/options.go index 213c31a12..b1a7ddf6f 100644 --- a/config/options.go +++ b/config/options.go @@ -15,7 +15,6 @@ import ( "time" "github.com/cespare/xxhash/v2" - "github.com/fsnotify/fsnotify" "github.com/mitchellh/hashstructure" "github.com/spf13/viper" "gopkg.in/yaml.v2" @@ -285,9 +284,9 @@ func NewDefaultOptions() *Options { return &newOpts } -// NewOptionsFromConfig builds the main binary's configuration options by parsing +// newOptionsFromConfig builds the main binary's configuration options by parsing // environmental variables and config file -func NewOptionsFromConfig(configFile string) (*Options, error) { +func newOptionsFromConfig(configFile string) (*Options, error) { o, err := optionsFromViper(configFile) if err != nil { return nil, fmt.Errorf("config: options from config file %w", err) @@ -366,13 +365,6 @@ func (o *Options) parsePolicy() error { return nil } -// OnConfigChange starts a go routine and watches for any changes. If any are -// detected, via an fsnotify event the provided function is run. -func (o *Options) OnConfigChange(run func(in fsnotify.Event)) { - go o.viper.WatchConfig() - o.viper.OnConfigChange(run) -} - func (o *Options) viperUnmarshalKey(key string, rawVal interface{}) error { return o.viper.UnmarshalKey(key, &rawVal) } @@ -457,8 +449,6 @@ func bindEnvs(o *Options, v *viper.Viper) error { // Validate ensures the Options fields are valid, and hydrated. func (o *Options) Validate() error { - var err error - if !IsValidService(o.Services) { return fmt.Errorf("config: %s is an invalid service type", o.Services) } @@ -605,47 +595,18 @@ func (o *Options) Validate() error { // strip quotes from redirect address (#811) o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`) - RedirectAndAutocertServer.update(o) - - err = AutocertManager.update(o) - if err != nil { - return fmt.Errorf("config: failed to setup autocert: %w", err) - } - // sort the certificates so we get a consistent hash sort.Slice(o.Certificates, func(i, j int) bool { return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0 }) - if !o.InsecureServer && len(o.Certificates) == 0 { + if !o.InsecureServer && len(o.Certificates) == 0 && !o.AutocertOptions.Enable { return fmt.Errorf("config: server must be run with `autocert`, " + "`insecure_server` or manually provided certificates to start") } return nil } -func (o *Options) sourceHostnames() []string { - if len(o.Policies) == 0 { - return nil - } - - dedupe := map[string]struct{}{} - for _, p := range o.Policies { - dedupe[p.Source.Hostname()] = struct{}{} - } - if o.AuthenticateURL != nil { - dedupe[o.AuthenticateURL.Hostname()] = struct{}{} - } - - var h []string - for k := range dedupe { - h = append(h, k) - } - sort.Strings(h) - - return h -} - // GetAuthenticateURL returns the AuthenticateURL in the options or localhost. func (o *Options) GetAuthenticateURL() *url.URL { if o != nil && o.AuthenticateURL != nil { @@ -697,11 +658,6 @@ func (o *Options) GetOauthOptions() oauth.Options { } } -// OptionsUpdater updates local state based on an Options struct -type OptionsUpdater interface { - UpdateOptions(Options) error -} - // Checksum returns the checksum of the current options struct func (o *Options) Checksum() uint64 { hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()}) @@ -712,40 +668,13 @@ func (o *Options) Checksum() uint64 { return hash } -// WatchChanges takes a configuration file, an existing options struct, and -// updates each service in the services slice OptionsUpdater with a new set -// of options if any change is detected. It also periodically rechecks if -// any computed properties have changed. -func WatchChanges(configFile string, opt *Options, services []OptionsUpdater) { - onchange := make(chan struct{}, 1) - ticker := time.NewTicker(10 * time.Minute) // force check every 10 minutes - defer ticker.Stop() - - opt.OnConfigChange(func(fs fsnotify.Event) { - log.Info().Str("file", fs.Name).Msg("config: file changed") - select { - case onchange <- struct{}{}: - default: - } - }) - - for { - select { - case <-onchange: - case <-ticker.C: - } - - opt = handleConfigUpdate(configFile, opt, services) - } -} - // handleConfigUpdate takes configuration file, an existing options struct, and // updates each service in the services slice OptionsUpdater with a new set of // options if any change is detected. -func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options { +func handleConfigUpdate(configFile string, opt *Options) *Options { serviceName := telemetry.ServiceName(opt.Services) - newOpt, err := NewOptionsFromConfig(configFile) + newOpt, err := newOptionsFromConfig(configFile) if err != nil { log.Error().Err(err).Msg("config: could not reload configuration") metrics.SetConfigInfo(serviceName, false) @@ -761,19 +690,6 @@ func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdat return opt } - var updateFailed bool - for _, service := range services { - if err := service.UpdateOptions(*newOpt); err != nil { - log.Error().Err(err).Msg("config: could not update options") - updateFailed = true - metrics.SetConfigInfo(serviceName, false) - } - } - - if !updateFailed { - metrics.SetConfigInfo(serviceName, true) - metrics.SetConfigChecksum(serviceName, newOptChecksum) - } return newOpt } diff --git a/config/options_test.go b/config/options_test.go index 8317c17b9..282a26cef 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -265,7 +265,7 @@ func TestOptionsFromViper(t *testing.T) { return } if diff := cmp.Diff(got, tt.want, opts...); diff != "" { - t.Errorf("NewOptionsFromConfig() = %s", diff) + t.Errorf("newOptionsFromConfig() = %s", diff) } }) } @@ -305,9 +305,9 @@ func Test_NewOptionsFromConfigEnvVar(t *testing.T) { os.Setenv(k, v) defer os.Unsetenv(k) } - _, err := NewOptionsFromConfig("") + _, err := newOptionsFromConfig("") if (err != nil) != tt.wantErr { - t.Errorf("NewOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("newOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr) return } }) @@ -327,7 +327,7 @@ func Test_AutoCertOptionsFromEnvVar(t *testing.T) { defer os.Unsetenv(k) } - o, err := NewOptionsFromConfig("") + o, err := newOptionsFromConfig("") if err != nil { t.Fatal(err) } @@ -343,160 +343,6 @@ func Test_AutoCertOptionsFromEnvVar(t *testing.T) { } -type mockService struct { - fail bool - Updated bool -} - -func (m *mockService) UpdateOptions(o Options) error { - - m.Updated = true - if m.fail { - return fmt.Errorf("failed") - } - return nil -} - -func Test_HandleConfigUpdate(t *testing.T) { - tests := []struct { - name string - oldEnvKeyPairs map[string]string - newEnvKeyPairs map[string]string - service *mockService - wantUpdate bool - }{ - {"good", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: false}, - true}, - {"good set debug", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "POMERIUM_DEBUG": "true", - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: false}, - true}, - {"bad", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: true}, - true}, - {"bad policy file unmarshal error", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "POLICY": base64.StdEncoding.EncodeToString([]byte("{json:}")), - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: false}, - false}, - {"bad header key", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "SERVICES": "error", - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: false}, - false}, - {"bad header header value", - map[string]string{ - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - map[string]string{ - "HEADERS": "x;y;z", - "INSECURE_SERVER": "true", - "AUTHENTICATE_SERVICE_URL": "https://authenticate.example", - "AUTHORIZE_SERVICE_URL": "https://authorize.example"}, - &mockService{fail: false}, - false}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - for k, v := range tt.oldEnvKeyPairs { - os.Setenv(k, v) - } - oldOpts, err := NewOptionsFromConfig("") - if err != nil { - t.Fatal(err) - } - for k := range tt.oldEnvKeyPairs { - os.Unsetenv(k) - } - for k, v := range tt.newEnvKeyPairs { - os.Setenv(k, v) - defer os.Unsetenv(k) - } - handleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service}) - if tt.service.Updated != tt.wantUpdate { - t.Errorf("Failed to update config on service") - } - }) - } -} - -func TestOptions_sourceHostnames(t *testing.T) { - t.Parallel() - testOptions := func() *Options { - o := NewDefaultOptions() - o.SharedKey = "test" - o.Services = "all" - o.InsecureServer = true - return o - } - tests := []struct { - name string - policies []Policy - authenticateURL string - want []string - }{ - {"empty", []Policy{}, "", nil}, - {"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}}, - {"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"authn.example.com", "from.example"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - o := testOptions() - o.Policies = tt.policies - o.AuthenticateURLString = tt.authenticateURL - err := o.Validate() - if err != nil { - t.Fatal(err) - } - got := o.sourceHostnames() - if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("Options.sourceHostnames() = %v", diff) - } - }) - } -} - func TestHTTPRedirectAddressStripQuotes(t *testing.T) { o := NewDefaultOptions() o.InsecureServer = true diff --git a/config/redirect.go b/config/redirect.go deleted file mode 100644 index 34afe4399..000000000 --- a/config/redirect.go +++ /dev/null @@ -1,60 +0,0 @@ -package config - -import ( - "net/http" - "sync" - - "github.com/pomerium/pomerium/internal/httputil" - "github.com/pomerium/pomerium/internal/log" -) - -// RedirectAndAutocertServer is an HTTP server which handles redirecting to HTTPS and autocerts. -var RedirectAndAutocertServer = newRedirectAndAutoCertServer() - -type redirectAndAutoCertServer struct { - mu sync.Mutex - srv *http.Server -} - -func newRedirectAndAutoCertServer() *redirectAndAutoCertServer { - return &redirectAndAutoCertServer{} -} - -func (srv *redirectAndAutoCertServer) update(options *Options) { - srv.mu.Lock() - defer srv.mu.Unlock() - - if srv.srv != nil { - // nothing to do if the address hasn't changed - if srv.srv.Addr == options.HTTPRedirectAddr { - return - } - // close immediately, don't care about the error - _ = srv.srv.Close() - srv.srv = nil - } - - if options.HTTPRedirectAddr == "" { - return - } - - redirect := httputil.RedirectHandler() - - hsrv := &http.Server{ - Addr: options.HTTPRedirectAddr, - Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if AutocertManager.HandleHTTPChallenge(w, r) { - return - } - redirect.ServeHTTP(w, r) - }), - } - go func() { - log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server") - err := hsrv.ListenAndServe() - if err != nil { - log.Error().Err(err).Msg("failed to run http redirect server") - } - }() - srv.srv = hsrv -} diff --git a/go.mod b/go.mod index 365fc1637..05fafd243 100644 --- a/go.mod +++ b/go.mod @@ -28,6 +28,7 @@ require ( github.com/hashicorp/memberlist v0.2.2 github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect github.com/lithammer/shortuuid/v3 v3.0.4 + github.com/mitchellh/copystructure v1.0.0 github.com/mitchellh/hashstructure v1.0.0 github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09 github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce diff --git a/go.sum b/go.sum index c7cfe8383..5ab55ad63 100644 --- a/go.sum +++ b/go.sum @@ -342,6 +342,8 @@ github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM= github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= +github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= @@ -353,6 +355,8 @@ github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0Qu github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= +github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= +github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= diff --git a/internal/autocert/manager.go b/internal/autocert/manager.go new file mode 100644 index 000000000..88670b153 --- /dev/null +++ b/internal/autocert/manager.go @@ -0,0 +1,198 @@ +// Package autocert implements automatic management of TLS certificates. +package autocert + +import ( + "context" + "fmt" + "net/http" + "sort" + "sync" + + "github.com/caddyserver/certmagic" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/log" +) + +// Manager manages TLS certificates. +type Manager struct { + src config.Source + + mu sync.RWMutex + config *config.Config + certmagic *certmagic.Config + acmeMgr *certmagic.ACMEManager + srv *http.Server + + config.ChangeDispatcher +} + +// New creates a new autocert manager. +func New(src config.Source) (*Manager, error) { + mgr := &Manager{ + src: src, + certmagic: certmagic.NewDefault(), + } + err := mgr.update(src.GetConfig()) + if err != nil { + return nil, err + } + mgr.src.OnConfigChange(func(cfg *config.Config) { + err := mgr.update(cfg) + if err != nil { + log.Error().Err(err).Msg("autocert: error updating config") + return + } + + mgr.Trigger(mgr.GetConfig()) + }) + return mgr, nil +} + +func (mgr *Manager) getCertMagicConfig(options *config.Options) (*certmagic.Config, error) { + mgr.certmagic.MustStaple = options.AutocertOptions.MustStaple + mgr.certmagic.OnDemand = nil // disable on-demand + mgr.certmagic.Storage = &certmagic.FileStorage{Path: options.AutocertOptions.Folder} + // add existing certs to the cache, and staple OCSP + for _, cert := range options.Certificates { + if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil { + return nil, fmt.Errorf("config: failed caching cert: %w", err) + } + } + acmeMgr := certmagic.NewACMEManager(mgr.certmagic, certmagic.DefaultACME) + acmeMgr.Agreed = true + if options.AutocertOptions.UseStaging { + acmeMgr.CA = certmagic.LetsEncryptStagingCA + } + acmeMgr.DisableTLSALPNChallenge = true + mgr.certmagic.Issuer = acmeMgr + mgr.acmeMgr = acmeMgr + + return mgr.certmagic, nil +} + +func (mgr *Manager) update(cfg *config.Config) error { + cfg = cfg.Clone() + + mgr.mu.Lock() + defer mgr.mu.Unlock() + defer func() { mgr.config = cfg }() + + mgr.updateServer(cfg) + return mgr.updateAutocert(cfg) +} + +func (mgr *Manager) updateAutocert(cfg *config.Config) error { + if !cfg.Options.AutocertOptions.Enable { + return nil + } + + cm, err := mgr.getCertMagicConfig(cfg.Options) + if err != nil { + return err + } + + for _, domain := range sourceHostnames(cfg) { + cert, err := cm.CacheManagedCertificate(domain) + if err != nil { + log.Info().Str("domain", domain).Msg("obtaining certificate") + err = cm.ObtainCert(context.Background(), domain, false) + if err != nil { + return fmt.Errorf("autocert: failed to obtain client certificate: %w", err) + } + cert, err = cm.CacheManagedCertificate(domain) + } + if err == nil && cert.NeedsRenewal(cm) { + log.Info().Str("domain", domain).Msg("renewing certificate") + err = cm.RenewCert(context.Background(), domain, false) + if err != nil { + return fmt.Errorf("autocert: failed to renew client certificate: %w", err) + } + cert, err = cm.CacheManagedCertificate(domain) + } + if err == nil { + cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate) + } else { + log.Error().Err(err).Msg("autocert: failed to obtain client certificate") + } + } + + return nil +} + +func (mgr *Manager) updateServer(cfg *config.Config) { + if mgr.srv != nil { + // nothing to do if the address hasn't changed + if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr { + return + } + // close immediately, don't care about the error + _ = mgr.srv.Close() + mgr.srv = nil + } + + if cfg.Options.HTTPRedirectAddr == "" { + return + } + + redirect := httputil.RedirectHandler() + + hsrv := &http.Server{ + Addr: cfg.Options.HTTPRedirectAddr, + Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if mgr.handleHTTPChallenge(w, r) { + return + } + redirect.ServeHTTP(w, r) + }), + } + go func() { + log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server") + err := hsrv.ListenAndServe() + if err != nil { + log.Error().Err(err).Msg("failed to run http redirect server") + } + }() + mgr.srv = hsrv +} + +func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool { + mgr.mu.RLock() + acmeMgr := mgr.acmeMgr + mgr.mu.RUnlock() + if acmeMgr == nil { + return false + } + return acmeMgr.HandleHTTPChallenge(w, r) +} + +// GetConfig gets the config. +func (mgr *Manager) GetConfig() *config.Config { + mgr.mu.RLock() + defer mgr.mu.RUnlock() + + return mgr.config +} + +func sourceHostnames(cfg *config.Config) []string { + if len(cfg.Options.Policies) == 0 { + return nil + } + + dedupe := map[string]struct{}{} + for _, p := range cfg.Options.Policies { + dedupe[p.Source.Hostname()] = struct{}{} + } + if cfg.Options.AuthenticateURL != nil { + dedupe[cfg.Options.AuthenticateURL.Hostname()] = struct{}{} + } + + var h []string + for k := range dedupe { + h = append(h, k) + } + sort.Strings(h) + + return h +} diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index e3251ef19..bc916c751 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -11,8 +11,6 @@ import ( "sync" "syscall" - "github.com/pomerium/pomerium/internal/telemetry" - envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" "golang.org/x/sync/errgroup" @@ -20,10 +18,12 @@ import ( "github.com/pomerium/pomerium/authorize" "github.com/pomerium/pomerium/cache" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/autocert" "github.com/pomerium/pomerium/internal/controlplane" "github.com/pomerium/pomerium/internal/envoy" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" + "github.com/pomerium/pomerium/internal/telemetry" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" @@ -33,31 +33,36 @@ import ( // Run runs the main pomerium application. func Run(ctx context.Context, configFile string) error { - opt, err := config.NewOptionsFromConfig(configFile) + var src config.Source + + src, err := config.NewFileOrEnvironmentSource(configFile) if err != nil { return err } - var optionsUpdaters []config.OptionsUpdater + + src, err = autocert.New(src) + if err != nil { + return err + } + + cfg := src.GetConfig() log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium") - if err := setupMetrics(ctx, opt); err != nil { + if err := setupMetrics(ctx, cfg.Options); err != nil { return err } - if err := setupTracing(ctx, opt); err != nil { + if err := setupTracing(ctx, cfg.Options); err != nil { return err } // setup the control plane - controlPlane, err := controlplane.NewServer(opt.Services) + controlPlane, err := controlplane.NewServer(cfg.Options.Services) if err != nil { return fmt.Errorf("error creating control plane: %w", err) } - optionsUpdaters = append(optionsUpdaters, controlPlane) - err = controlPlane.UpdateOptions(*opt) - if err != nil { - return fmt.Errorf("error updating control plane options: %w", err) - } + src.OnConfigChange(controlPlane.OnConfigChange) + controlPlane.OnConfigChange(cfg) _, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String()) _, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String()) @@ -66,36 +71,33 @@ func Run(ctx context.Context, configFile string) error { log.Info().Str("port", httpPort).Msg("HTTP server started") // create envoy server - envoyServer, err := envoy.NewServer(opt, grpcPort, httpPort) + envoyServer, err := envoy.NewServer(cfg.Options, grpcPort, httpPort) if err != nil { return fmt.Errorf("error creating envoy server: %w", err) } // add services - if err := setupAuthenticate(opt, controlPlane, &optionsUpdaters); err != nil { + if err := setupAuthenticate(src, cfg, controlPlane); err != nil { return err } var authorizeServer *authorize.Authorize - if config.IsAuthorize(opt.Services) { - authorizeServer, err = setupAuthorize(opt, controlPlane, &optionsUpdaters) + if config.IsAuthorize(cfg.Options.Services) { + authorizeServer, err = setupAuthorize(src, cfg, controlPlane) if err != nil { return err } } var cacheServer *cache.Cache - if config.IsCache(opt.Services) { - cacheServer, err = setupCache(opt, controlPlane) + if config.IsCache(cfg.Options.Services) { + cacheServer, err = setupCache(cfg.Options, controlPlane) if err != nil { return err } } - if err := setupProxy(opt, controlPlane); err != nil { + if err := setupProxy(cfg.Options, controlPlane); err != nil { return err } - // start the config change listener - go config.WatchChanges(configFile, opt, optionsUpdaters) - ctx, cancel := context.WithCancel(ctx) go func(ctx context.Context) { ch := make(chan os.Signal, 2) @@ -132,21 +134,21 @@ func Run(ctx context.Context, configFile string) error { return eg.Wait() } -func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) error { - if !config.IsAuthenticate(opt.Services) { +func setupAuthenticate(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) error { + if !config.IsAuthenticate(cfg.Options.Services) { return nil } - svc, err := authenticate.New(*opt) + svc, err := authenticate.New(cfg.Options) if err != nil { return fmt.Errorf("error creating authenticate service: %w", err) } - *optionsUpdaters = append(*optionsUpdaters, svc) - err = svc.UpdateOptions(*opt) + src.OnConfigChange(svc.OnConfigChange) + svc.OnConfigChange(cfg) if err != nil { return fmt.Errorf("error updating authenticate options: %w", err) } - host := urlutil.StripPort(opt.GetAuthenticateURL().Host) + host := urlutil.StripPort(cfg.Options.GetAuthenticateURL().Host) sr := controlPlane.HTTPRouter.Host(host).Subrouter() svc.Mount(sr) log.Info().Str("host", host).Msg("enabled authenticate service") @@ -154,20 +156,16 @@ func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server, o return nil } -func setupAuthorize(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) (*authorize.Authorize, error) { - svc, err := authorize.New(*opt) +func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*authorize.Authorize, error) { + svc, err := authorize.New(cfg.Options) if err != nil { return nil, fmt.Errorf("error creating authorize service: %w", err) } envoy_service_auth_v2.RegisterAuthorizationServer(controlPlane.GRPCServer, svc) log.Info().Msg("enabled authorize service") - - *optionsUpdaters = append(*optionsUpdaters, svc) - err = svc.UpdateOptions(*opt) - if err != nil { - return nil, fmt.Errorf("error updating authorize options: %w", err) - } + src.OnConfigChange(svc.OnConfigChange) + svc.OnConfigChange(cfg) return svc, nil } diff --git a/internal/controlplane/server.go b/internal/controlplane/server.go index 6919ae297..8eb1f47eb 100644 --- a/internal/controlplane/server.go +++ b/internal/controlplane/server.go @@ -141,17 +141,16 @@ func (srv *Server) Run(ctx context.Context) error { return eg.Wait() } -// UpdateOptions updates the pomerium config options. -func (srv *Server) UpdateOptions(options config.Options) error { +// OnConfigChange updates the pomerium config options. +func (srv *Server) OnConfigChange(cfg *config.Config) { select { case <-srv.configUpdated: default: } prev := srv.currentConfig.Load() srv.currentConfig.Store(versionedOptions{ - Options: options, + Options: *cfg.Options, version: prev.version + 1, }) srv.configUpdated <- struct{}{} - return nil }