mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-30 19:06:33 +02:00
options refactor (#1088)
* refactor config loading * wip * move autocert to its own config source * refactor options updaters * fix stuttering * fix autocert validate check
This commit is contained in:
parent
eef4c6f2c0
commit
d3a7ee38be
18 changed files with 385 additions and 489 deletions
|
@ -36,7 +36,7 @@ import (
|
||||||
|
|
||||||
// ValidateOptions checks that configuration are complete and valid.
|
// ValidateOptions checks that configuration are complete and valid.
|
||||||
// Returns on first error found.
|
// Returns on first error found.
|
||||||
func ValidateOptions(o config.Options) error {
|
func ValidateOptions(o *config.Options) error {
|
||||||
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err)
|
return fmt.Errorf("authenticate: 'SHARED_SECRET' invalid: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -118,7 +118,7 @@ type Authenticate struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new authenticate service from a set of Options.
|
// New validates and creates a new authenticate service from a set of Options.
|
||||||
func New(opts config.Options) (*Authenticate, error) {
|
func New(opts *config.Options) (*Authenticate, error) {
|
||||||
if err := ValidateOptions(opts); err != nil {
|
if err := ValidateOptions(opts); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -238,15 +238,13 @@ func (a *Authenticate) setAdminUsers(opts *config.Options) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOptions implements the OptionsUpdater interface and updates internal
|
// OnConfigChange implements the OptionsUpdater interface and updates internal
|
||||||
// structures based on config.Options
|
// structures based on config.Options
|
||||||
func (a *Authenticate) UpdateOptions(opts config.Options) error {
|
func (a *Authenticate) OnConfigChange(cfg *config.Config) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str("checksum", fmt.Sprintf("%x", opts.Checksum())).Msg("authenticate: updating options")
|
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options")
|
||||||
a.setAdminUsers(&opts)
|
a.setAdminUsers(cfg.Options)
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -71,7 +71,7 @@ func TestOptions_Validate(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
if err := ValidateOptions(*tt.o); (err != nil) != tt.wantErr {
|
if err := ValidateOptions(tt.o); (err != nil) != tt.wantErr {
|
||||||
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("Options.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -128,7 +128,7 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := New(*tt.opts)
|
_, err := New(tt.opts)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -155,8 +155,8 @@ func TestIsAdmin(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
opts := newTestOptions(t)
|
opts := newTestOptions(t)
|
||||||
opts.Administrators = tc.admins
|
opts.Administrators = tc.admins
|
||||||
a, err := New(*opts)
|
a, err := New(opts)
|
||||||
assert.NoError(t, a.UpdateOptions(*opts))
|
a.OnConfigChange(&config.Config{Options: opts})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, a.isAdmin(tc.user) == tc.isAdmin)
|
assert.True(t, a.isAdmin(tc.user) == tc.isAdmin)
|
||||||
})
|
})
|
||||||
|
|
|
@ -511,7 +511,7 @@ func TestWellKnownEndpoint(t *testing.T) {
|
||||||
func TestJwksEndpoint(t *testing.T) {
|
func TestJwksEndpoint(t *testing.T) {
|
||||||
o := newTestOptions(t)
|
o := newTestOptions(t)
|
||||||
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||||
auth, err := New(*o)
|
auth, err := New(o)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -27,11 +27,11 @@ type atomicOptions struct {
|
||||||
value atomic.Value
|
value atomic.Value
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *atomicOptions) Load() config.Options {
|
func (a *atomicOptions) Load() *config.Options {
|
||||||
return a.value.Load().(config.Options)
|
return a.value.Load().(*config.Options)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *atomicOptions) Store(options config.Options) {
|
func (a *atomicOptions) Store(options *config.Options) {
|
||||||
a.value.Store(options)
|
a.value.Store(options)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -63,7 +63,7 @@ type Authorize struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new Authorize service from a set of config options.
|
// New validates and creates a new Authorize service from a set of config options.
|
||||||
func New(opts config.Options) (*Authorize, error) {
|
func New(opts *config.Options) (*Authorize, error) {
|
||||||
if err := validateOptions(opts); err != nil {
|
if err := validateOptions(opts); err != nil {
|
||||||
return nil, fmt.Errorf("authorize: bad options: %w", err)
|
return nil, fmt.Errorf("authorize: bad options: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -98,16 +98,11 @@ func New(opts config.Options) (*Authorize, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
a.currentEncoder.Store(encoder)
|
a.currentEncoder.Store(encoder)
|
||||||
|
a.currentOptions.Store(new(config.Options))
|
||||||
a.currentOptions.Store(config.Options{})
|
|
||||||
err = a.UpdateOptions(opts)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &a, nil
|
return &a, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateOptions(o config.Options) error {
|
func validateOptions(o *config.Options) error {
|
||||||
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil {
|
||||||
return fmt.Errorf("bad shared_secret: %w", err)
|
return fmt.Errorf("bad shared_secret: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -128,19 +123,19 @@ func newPolicyEvaluator(opts *config.Options) (*evaluator.Evaluator, error) {
|
||||||
return evaluator.New(opts)
|
return evaluator.New(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOptions implements the OptionsUpdater interface and updates internal
|
// OnConfigChange implements the OptionsUpdater interface and updates internal
|
||||||
// structures based on config.Options
|
// structures based on config.Options
|
||||||
func (a *Authorize) UpdateOptions(opts config.Options) error {
|
func (a *Authorize) OnConfigChange(cfg *config.Config) {
|
||||||
if a == nil {
|
if a == nil {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info().Str("checksum", fmt.Sprintf("%x", opts.Checksum())).Msg("authorize: updating options")
|
log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authorize: updating options")
|
||||||
a.currentOptions.Store(opts)
|
a.currentOptions.Store(cfg.Options)
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
if a.pe, err = newPolicyEvaluator(&opts); err != nil {
|
if a.pe, err = newPolicyEvaluator(cfg.Options); err != nil {
|
||||||
return err
|
log.Error().Err(err).Msg("authorize: failed to update policy with options")
|
||||||
|
return
|
||||||
}
|
}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,13 +24,13 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
o := config.Options{
|
o := &config.Options{
|
||||||
AuthenticateURL: mustParseURL("https://authN.example.com"),
|
AuthenticateURL: mustParseURL("https://authN.example.com"),
|
||||||
DataBrokerURL: mustParseURL("https://cache.example.com"),
|
DataBrokerURL: mustParseURL("https://cache.example.com"),
|
||||||
SharedKey: tt.SharedKey,
|
SharedKey: tt.SharedKey,
|
||||||
Policies: tt.Policies}
|
Policies: tt.Policies}
|
||||||
if tt.name == "empty options" {
|
if tt.name == "empty options" {
|
||||||
o = config.Options{}
|
o = &config.Options{}
|
||||||
}
|
}
|
||||||
_, err := New(o)
|
_, err := New(o)
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
|
|
|
@ -240,7 +240,7 @@ func Test_handleForwardAuth(t *testing.T) {
|
||||||
if tc.forwardAuthURL != "" {
|
if tc.forwardAuthURL != "" {
|
||||||
fau = mustParseURL(tc.forwardAuthURL)
|
fau = mustParseURL(tc.forwardAuthURL)
|
||||||
}
|
}
|
||||||
a.currentOptions.Store(config.Options{ForwardAuthURL: fau})
|
a.currentOptions.Store(&config.Options{ForwardAuthURL: fau})
|
||||||
assert.Equal(t, tc.isForwardAuth, a.handleForwardAuth(tc.checkReq))
|
assert.Equal(t, tc.isForwardAuth, a.handleForwardAuth(tc.checkReq))
|
||||||
if tc.attrCtxHTTPReq != nil {
|
if tc.attrCtxHTTPReq != nil {
|
||||||
assert.Equal(t, tc.attrCtxHTTPReq, tc.checkReq.Attributes.Request.Http)
|
assert.Equal(t, tc.attrCtxHTTPReq, tc.checkReq.Attributes.Request.Http)
|
||||||
|
|
|
@ -18,7 +18,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadRawSession(req *http.Request, options config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) {
|
func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) {
|
||||||
var loaders []sessions.SessionLoader
|
var loaders []sessions.SessionLoader
|
||||||
cookieStore, err := getCookieStore(options, encoder)
|
cookieStore, err := getCookieStore(options, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -51,7 +51,7 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCookieStore(options config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
||||||
cookieOptions := &cookie.Options{
|
cookieOptions := &cookie.Options{
|
||||||
Name: options.CookieName,
|
Name: options.CookieName,
|
||||||
Domain: options.CookieDomain,
|
Domain: options.CookieDomain,
|
||||||
|
@ -85,7 +85,7 @@ func getJWTSetCookieHeaders(cookieStore sessions.SessionStore, rawjwt []byte) (m
|
||||||
return hdrs, nil
|
return hdrs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) getJWTClaimHeaders(options config.Options, signedJWT string) (map[string]string, error) {
|
func (a *Authorize) getJWTClaimHeaders(options *config.Options, signedJWT string) (map[string]string, error) {
|
||||||
if len(signedJWT) == 0 {
|
if len(signedJWT) == 0 {
|
||||||
return make(map[string]string), nil
|
return make(map[string]string), nil
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestLoadSession(t *testing.T) {
|
func TestLoadSession(t *testing.T) {
|
||||||
opts := *config.NewDefaultOptions()
|
opts := config.NewDefaultOptions()
|
||||||
encoder, err := jws.NewHS256Signer(nil, "example.com")
|
encoder, err := jws.NewHS256Signer(nil, "example.com")
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,16 +1,5 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/caddyserver/certmagic"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// AutocertOptions contains the options to control the behavior of autocert.
|
// AutocertOptions contains the options to control the behavior of autocert.
|
||||||
type AutocertOptions struct {
|
type AutocertOptions struct {
|
||||||
// Enable enables fully automated certificate management including issuance
|
// Enable enables fully automated certificate management including issuance
|
||||||
|
@ -35,94 +24,3 @@ type AutocertOptions struct {
|
||||||
// defaults to $XDG_DATA_HOME/pomerium
|
// defaults to $XDG_DATA_HOME/pomerium
|
||||||
Folder string `mapstructure:"autocert_dir" yaml:"autocert_dir,omitempty"`
|
Folder string `mapstructure:"autocert_dir" yaml:"autocert_dir,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// AutocertManager manages Let's Encrypt certificates based on configuration options.
|
|
||||||
var AutocertManager = newAutocertManager()
|
|
||||||
|
|
||||||
type autocertManager struct {
|
|
||||||
mu sync.RWMutex
|
|
||||||
certmagic *certmagic.Config
|
|
||||||
acmeMgr *certmagic.ACMEManager
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAutocertManager() *autocertManager {
|
|
||||||
mgr := &autocertManager{}
|
|
||||||
return mgr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *autocertManager) getConfig(options *Options) (*certmagic.Config, error) {
|
|
||||||
mgr.mu.Lock()
|
|
||||||
defer mgr.mu.Unlock()
|
|
||||||
|
|
||||||
cm := mgr.certmagic
|
|
||||||
if cm == nil {
|
|
||||||
cm = certmagic.NewDefault()
|
|
||||||
cm.MustStaple = options.AutocertOptions.MustStaple
|
|
||||||
}
|
|
||||||
|
|
||||||
cm.OnDemand = nil // disable on-demand
|
|
||||||
cm.Storage = &certmagic.FileStorage{Path: options.AutocertOptions.Folder}
|
|
||||||
// add existing certs to the cache, and staple OCSP
|
|
||||||
for _, cert := range options.Certificates {
|
|
||||||
if err := cm.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
|
||||||
return nil, fmt.Errorf("config: failed caching cert: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
acmeMgr := certmagic.NewACMEManager(cm, certmagic.DefaultACME)
|
|
||||||
acmeMgr.Agreed = true
|
|
||||||
if options.AutocertOptions.UseStaging {
|
|
||||||
acmeMgr.CA = certmagic.LetsEncryptStagingCA
|
|
||||||
}
|
|
||||||
acmeMgr.DisableTLSALPNChallenge = true
|
|
||||||
cm.Issuer = acmeMgr
|
|
||||||
mgr.acmeMgr = acmeMgr
|
|
||||||
|
|
||||||
return cm, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *autocertManager) update(options *Options) error {
|
|
||||||
if !options.AutocertOptions.Enable {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
cm, err := mgr.getConfig(options)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, domain := range options.sourceHostnames() {
|
|
||||||
cert, err := cm.CacheManagedCertificate(domain)
|
|
||||||
if err != nil {
|
|
||||||
log.Info().Str("domain", domain).Msg("obtaining certificate")
|
|
||||||
err = cm.ObtainCert(context.Background(), domain, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("config: failed to obtain client certificate: %w", err)
|
|
||||||
}
|
|
||||||
cert, err = cm.CacheManagedCertificate(domain)
|
|
||||||
}
|
|
||||||
if err == nil && cert.NeedsRenewal(cm) {
|
|
||||||
log.Info().Str("domain", domain).Msg("renewing certificate")
|
|
||||||
err = cm.RenewCert(context.Background(), domain, false)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("config: failed to renew client certificate: %w", err)
|
|
||||||
}
|
|
||||||
cert, err = cm.CacheManagedCertificate(domain)
|
|
||||||
}
|
|
||||||
if err == nil {
|
|
||||||
options.Certificates = append(options.Certificates, cert.Certificate)
|
|
||||||
} else {
|
|
||||||
log.Error().Err(err).Msg("config: failed to obtain client certificate")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (mgr *autocertManager) HandleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
|
|
||||||
mgr.mu.RLock()
|
|
||||||
acmeMgr := mgr.acmeMgr
|
|
||||||
mgr.mu.RUnlock()
|
|
||||||
if acmeMgr == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return acmeMgr.HandleHTTPChallenge(w, r)
|
|
||||||
}
|
|
||||||
|
|
103
config/config_source.go
Normal file
103
config/config_source.go
Normal file
|
@ -0,0 +1,103 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"github.com/mitchellh/copystructure"
|
||||||
|
"github.com/spf13/viper"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Config holds pomerium configuration options.
|
||||||
|
type Config struct {
|
||||||
|
Options *Options
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clone creates a deep clone of the config.
|
||||||
|
func (cfg *Config) Clone() *Config {
|
||||||
|
return copystructure.Must(copystructure.Config{
|
||||||
|
Copiers: map[reflect.Type]copystructure.CopierFunc{
|
||||||
|
reflect.TypeOf((*viper.Viper)(nil)): func(i interface{}) (interface{}, error) {
|
||||||
|
return i, nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}.Copy(cfg)).(*Config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A ChangeListener is called when configuration changes.
|
||||||
|
type ChangeListener = func(*Config)
|
||||||
|
|
||||||
|
// A ChangeDispatcher manages listeners on config changes.
|
||||||
|
type ChangeDispatcher struct {
|
||||||
|
sync.Mutex
|
||||||
|
onConfigChangeListeners []ChangeListener
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trigger triggers a change.
|
||||||
|
func (dispatcher *ChangeDispatcher) Trigger(cfg *Config) {
|
||||||
|
dispatcher.Lock()
|
||||||
|
defer dispatcher.Unlock()
|
||||||
|
|
||||||
|
for _, li := range dispatcher.onConfigChangeListeners {
|
||||||
|
li(cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OnConfigChange adds a listener.
|
||||||
|
func (dispatcher *ChangeDispatcher) OnConfigChange(li ChangeListener) {
|
||||||
|
dispatcher.Lock()
|
||||||
|
defer dispatcher.Unlock()
|
||||||
|
dispatcher.onConfigChangeListeners = append(dispatcher.onConfigChangeListeners, li)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Source gets configuration.
|
||||||
|
type Source interface {
|
||||||
|
GetConfig() *Config
|
||||||
|
OnConfigChange(ChangeListener)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A FileOrEnvironmentSource retrieves config options from a file or the environment.
|
||||||
|
type FileOrEnvironmentSource struct {
|
||||||
|
configFile string
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
config *Config
|
||||||
|
|
||||||
|
ChangeDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileOrEnvironmentSource creates a new FileOrEnvironmentSource.
|
||||||
|
func NewFileOrEnvironmentSource(configFile string) (*FileOrEnvironmentSource, error) {
|
||||||
|
options, err := newOptionsFromConfig(configFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
src := &FileOrEnvironmentSource{
|
||||||
|
configFile: configFile,
|
||||||
|
config: &Config{Options: options},
|
||||||
|
}
|
||||||
|
options.viper.OnConfigChange(src.onConfigChange)
|
||||||
|
go options.viper.WatchConfig()
|
||||||
|
|
||||||
|
return src, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *FileOrEnvironmentSource) onConfigChange(evt fsnotify.Event) {
|
||||||
|
src.mu.Lock()
|
||||||
|
newOptions := handleConfigUpdate(src.configFile, src.config.Options)
|
||||||
|
cfg := &Config{Options: newOptions}
|
||||||
|
src.config = cfg
|
||||||
|
src.mu.Unlock()
|
||||||
|
|
||||||
|
src.Trigger(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig gets the config.
|
||||||
|
func (src *FileOrEnvironmentSource) GetConfig() *Config {
|
||||||
|
src.mu.RLock()
|
||||||
|
defer src.mu.RUnlock()
|
||||||
|
|
||||||
|
return src.config
|
||||||
|
}
|
|
@ -15,7 +15,6 @@ import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/fsnotify/fsnotify"
|
|
||||||
"github.com/mitchellh/hashstructure"
|
"github.com/mitchellh/hashstructure"
|
||||||
"github.com/spf13/viper"
|
"github.com/spf13/viper"
|
||||||
"gopkg.in/yaml.v2"
|
"gopkg.in/yaml.v2"
|
||||||
|
@ -285,9 +284,9 @@ func NewDefaultOptions() *Options {
|
||||||
return &newOpts
|
return &newOpts
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewOptionsFromConfig builds the main binary's configuration options by parsing
|
// newOptionsFromConfig builds the main binary's configuration options by parsing
|
||||||
// environmental variables and config file
|
// environmental variables and config file
|
||||||
func NewOptionsFromConfig(configFile string) (*Options, error) {
|
func newOptionsFromConfig(configFile string) (*Options, error) {
|
||||||
o, err := optionsFromViper(configFile)
|
o, err := optionsFromViper(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config: options from config file %w", err)
|
return nil, fmt.Errorf("config: options from config file %w", err)
|
||||||
|
@ -366,13 +365,6 @@ func (o *Options) parsePolicy() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange starts a go routine and watches for any changes. If any are
|
|
||||||
// detected, via an fsnotify event the provided function is run.
|
|
||||||
func (o *Options) OnConfigChange(run func(in fsnotify.Event)) {
|
|
||||||
go o.viper.WatchConfig()
|
|
||||||
o.viper.OnConfigChange(run)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o *Options) viperUnmarshalKey(key string, rawVal interface{}) error {
|
func (o *Options) viperUnmarshalKey(key string, rawVal interface{}) error {
|
||||||
return o.viper.UnmarshalKey(key, &rawVal)
|
return o.viper.UnmarshalKey(key, &rawVal)
|
||||||
}
|
}
|
||||||
|
@ -457,8 +449,6 @@ func bindEnvs(o *Options, v *viper.Viper) error {
|
||||||
|
|
||||||
// Validate ensures the Options fields are valid, and hydrated.
|
// Validate ensures the Options fields are valid, and hydrated.
|
||||||
func (o *Options) Validate() error {
|
func (o *Options) Validate() error {
|
||||||
var err error
|
|
||||||
|
|
||||||
if !IsValidService(o.Services) {
|
if !IsValidService(o.Services) {
|
||||||
return fmt.Errorf("config: %s is an invalid service type", o.Services)
|
return fmt.Errorf("config: %s is an invalid service type", o.Services)
|
||||||
}
|
}
|
||||||
|
@ -605,47 +595,18 @@ func (o *Options) Validate() error {
|
||||||
// strip quotes from redirect address (#811)
|
// strip quotes from redirect address (#811)
|
||||||
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
|
o.HTTPRedirectAddr = strings.Trim(o.HTTPRedirectAddr, `"'`)
|
||||||
|
|
||||||
RedirectAndAutocertServer.update(o)
|
|
||||||
|
|
||||||
err = AutocertManager.update(o)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("config: failed to setup autocert: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// sort the certificates so we get a consistent hash
|
// sort the certificates so we get a consistent hash
|
||||||
sort.Slice(o.Certificates, func(i, j int) bool {
|
sort.Slice(o.Certificates, func(i, j int) bool {
|
||||||
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
|
return compareByteSliceSlice(o.Certificates[i].Certificate, o.Certificates[j].Certificate) < 0
|
||||||
})
|
})
|
||||||
|
|
||||||
if !o.InsecureServer && len(o.Certificates) == 0 {
|
if !o.InsecureServer && len(o.Certificates) == 0 && !o.AutocertOptions.Enable {
|
||||||
return fmt.Errorf("config: server must be run with `autocert`, " +
|
return fmt.Errorf("config: server must be run with `autocert`, " +
|
||||||
"`insecure_server` or manually provided certificates to start")
|
"`insecure_server` or manually provided certificates to start")
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (o *Options) sourceHostnames() []string {
|
|
||||||
if len(o.Policies) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
dedupe := map[string]struct{}{}
|
|
||||||
for _, p := range o.Policies {
|
|
||||||
dedupe[p.Source.Hostname()] = struct{}{}
|
|
||||||
}
|
|
||||||
if o.AuthenticateURL != nil {
|
|
||||||
dedupe[o.AuthenticateURL.Hostname()] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
var h []string
|
|
||||||
for k := range dedupe {
|
|
||||||
h = append(h, k)
|
|
||||||
}
|
|
||||||
sort.Strings(h)
|
|
||||||
|
|
||||||
return h
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAuthenticateURL returns the AuthenticateURL in the options or localhost.
|
// GetAuthenticateURL returns the AuthenticateURL in the options or localhost.
|
||||||
func (o *Options) GetAuthenticateURL() *url.URL {
|
func (o *Options) GetAuthenticateURL() *url.URL {
|
||||||
if o != nil && o.AuthenticateURL != nil {
|
if o != nil && o.AuthenticateURL != nil {
|
||||||
|
@ -697,11 +658,6 @@ func (o *Options) GetOauthOptions() oauth.Options {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// OptionsUpdater updates local state based on an Options struct
|
|
||||||
type OptionsUpdater interface {
|
|
||||||
UpdateOptions(Options) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checksum returns the checksum of the current options struct
|
// Checksum returns the checksum of the current options struct
|
||||||
func (o *Options) Checksum() uint64 {
|
func (o *Options) Checksum() uint64 {
|
||||||
hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
|
hash, err := hashstructure.Hash(o, &hashstructure.HashOptions{Hasher: xxhash.New()})
|
||||||
|
@ -712,40 +668,13 @@ func (o *Options) Checksum() uint64 {
|
||||||
return hash
|
return hash
|
||||||
}
|
}
|
||||||
|
|
||||||
// WatchChanges takes a configuration file, an existing options struct, and
|
|
||||||
// updates each service in the services slice OptionsUpdater with a new set
|
|
||||||
// of options if any change is detected. It also periodically rechecks if
|
|
||||||
// any computed properties have changed.
|
|
||||||
func WatchChanges(configFile string, opt *Options, services []OptionsUpdater) {
|
|
||||||
onchange := make(chan struct{}, 1)
|
|
||||||
ticker := time.NewTicker(10 * time.Minute) // force check every 10 minutes
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
opt.OnConfigChange(func(fs fsnotify.Event) {
|
|
||||||
log.Info().Str("file", fs.Name).Msg("config: file changed")
|
|
||||||
select {
|
|
||||||
case onchange <- struct{}{}:
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-onchange:
|
|
||||||
case <-ticker.C:
|
|
||||||
}
|
|
||||||
|
|
||||||
opt = handleConfigUpdate(configFile, opt, services)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// handleConfigUpdate takes configuration file, an existing options struct, and
|
// handleConfigUpdate takes configuration file, an existing options struct, and
|
||||||
// updates each service in the services slice OptionsUpdater with a new set of
|
// updates each service in the services slice OptionsUpdater with a new set of
|
||||||
// options if any change is detected.
|
// options if any change is detected.
|
||||||
func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdater) *Options {
|
func handleConfigUpdate(configFile string, opt *Options) *Options {
|
||||||
serviceName := telemetry.ServiceName(opt.Services)
|
serviceName := telemetry.ServiceName(opt.Services)
|
||||||
|
|
||||||
newOpt, err := NewOptionsFromConfig(configFile)
|
newOpt, err := newOptionsFromConfig(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error().Err(err).Msg("config: could not reload configuration")
|
log.Error().Err(err).Msg("config: could not reload configuration")
|
||||||
metrics.SetConfigInfo(serviceName, false)
|
metrics.SetConfigInfo(serviceName, false)
|
||||||
|
@ -761,19 +690,6 @@ func handleConfigUpdate(configFile string, opt *Options, services []OptionsUpdat
|
||||||
return opt
|
return opt
|
||||||
}
|
}
|
||||||
|
|
||||||
var updateFailed bool
|
|
||||||
for _, service := range services {
|
|
||||||
if err := service.UpdateOptions(*newOpt); err != nil {
|
|
||||||
log.Error().Err(err).Msg("config: could not update options")
|
|
||||||
updateFailed = true
|
|
||||||
metrics.SetConfigInfo(serviceName, false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !updateFailed {
|
|
||||||
metrics.SetConfigInfo(serviceName, true)
|
|
||||||
metrics.SetConfigChecksum(serviceName, newOptChecksum)
|
|
||||||
}
|
|
||||||
return newOpt
|
return newOpt
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -265,7 +265,7 @@ func TestOptionsFromViper(t *testing.T) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
|
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
|
||||||
t.Errorf("NewOptionsFromConfig() = %s", diff)
|
t.Errorf("newOptionsFromConfig() = %s", diff)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -305,9 +305,9 @@ func Test_NewOptionsFromConfigEnvVar(t *testing.T) {
|
||||||
os.Setenv(k, v)
|
os.Setenv(k, v)
|
||||||
defer os.Unsetenv(k)
|
defer os.Unsetenv(k)
|
||||||
}
|
}
|
||||||
_, err := NewOptionsFromConfig("")
|
_, err := newOptionsFromConfig("")
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("NewOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("newOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
@ -327,7 +327,7 @@ func Test_AutoCertOptionsFromEnvVar(t *testing.T) {
|
||||||
defer os.Unsetenv(k)
|
defer os.Unsetenv(k)
|
||||||
}
|
}
|
||||||
|
|
||||||
o, err := NewOptionsFromConfig("")
|
o, err := newOptionsFromConfig("")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -343,160 +343,6 @@ func Test_AutoCertOptionsFromEnvVar(t *testing.T) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type mockService struct {
|
|
||||||
fail bool
|
|
||||||
Updated bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockService) UpdateOptions(o Options) error {
|
|
||||||
|
|
||||||
m.Updated = true
|
|
||||||
if m.fail {
|
|
||||||
return fmt.Errorf("failed")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_HandleConfigUpdate(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
oldEnvKeyPairs map[string]string
|
|
||||||
newEnvKeyPairs map[string]string
|
|
||||||
service *mockService
|
|
||||||
wantUpdate bool
|
|
||||||
}{
|
|
||||||
{"good",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: false},
|
|
||||||
true},
|
|
||||||
{"good set debug",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"POMERIUM_DEBUG": "true",
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: false},
|
|
||||||
true},
|
|
||||||
{"bad",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: true},
|
|
||||||
true},
|
|
||||||
{"bad policy file unmarshal error",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"POLICY": base64.StdEncoding.EncodeToString([]byte("{json:}")),
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: false},
|
|
||||||
false},
|
|
||||||
{"bad header key",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"SERVICES": "error",
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: false},
|
|
||||||
false},
|
|
||||||
{"bad header header value",
|
|
||||||
map[string]string{
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
map[string]string{
|
|
||||||
"HEADERS": "x;y;z",
|
|
||||||
"INSECURE_SERVER": "true",
|
|
||||||
"AUTHENTICATE_SERVICE_URL": "https://authenticate.example",
|
|
||||||
"AUTHORIZE_SERVICE_URL": "https://authorize.example"},
|
|
||||||
&mockService{fail: false},
|
|
||||||
false},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
for k, v := range tt.oldEnvKeyPairs {
|
|
||||||
os.Setenv(k, v)
|
|
||||||
}
|
|
||||||
oldOpts, err := NewOptionsFromConfig("")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
for k := range tt.oldEnvKeyPairs {
|
|
||||||
os.Unsetenv(k)
|
|
||||||
}
|
|
||||||
for k, v := range tt.newEnvKeyPairs {
|
|
||||||
os.Setenv(k, v)
|
|
||||||
defer os.Unsetenv(k)
|
|
||||||
}
|
|
||||||
handleConfigUpdate("", oldOpts, []OptionsUpdater{tt.service})
|
|
||||||
if tt.service.Updated != tt.wantUpdate {
|
|
||||||
t.Errorf("Failed to update config on service")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestOptions_sourceHostnames(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
testOptions := func() *Options {
|
|
||||||
o := NewDefaultOptions()
|
|
||||||
o.SharedKey = "test"
|
|
||||||
o.Services = "all"
|
|
||||||
o.InsecureServer = true
|
|
||||||
return o
|
|
||||||
}
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
policies []Policy
|
|
||||||
authenticateURL string
|
|
||||||
want []string
|
|
||||||
}{
|
|
||||||
{"empty", []Policy{}, "", nil},
|
|
||||||
{"good no authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "", []string{"from.example"}},
|
|
||||||
{"good with authN", []Policy{{From: "https://from.example", To: "https://to.example"}}, "https://authn.example.com", []string{"authn.example.com", "from.example"}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
o := testOptions()
|
|
||||||
o.Policies = tt.policies
|
|
||||||
o.AuthenticateURLString = tt.authenticateURL
|
|
||||||
err := o.Validate()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
got := o.sourceHostnames()
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
|
||||||
t.Errorf("Options.sourceHostnames() = %v", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHTTPRedirectAddressStripQuotes(t *testing.T) {
|
func TestHTTPRedirectAddressStripQuotes(t *testing.T) {
|
||||||
o := NewDefaultOptions()
|
o := NewDefaultOptions()
|
||||||
o.InsecureServer = true
|
o.InsecureServer = true
|
||||||
|
|
|
@ -1,60 +0,0 @@
|
||||||
package config
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
|
||||||
)
|
|
||||||
|
|
||||||
// RedirectAndAutocertServer is an HTTP server which handles redirecting to HTTPS and autocerts.
|
|
||||||
var RedirectAndAutocertServer = newRedirectAndAutoCertServer()
|
|
||||||
|
|
||||||
type redirectAndAutoCertServer struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
srv *http.Server
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRedirectAndAutoCertServer() *redirectAndAutoCertServer {
|
|
||||||
return &redirectAndAutoCertServer{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (srv *redirectAndAutoCertServer) update(options *Options) {
|
|
||||||
srv.mu.Lock()
|
|
||||||
defer srv.mu.Unlock()
|
|
||||||
|
|
||||||
if srv.srv != nil {
|
|
||||||
// nothing to do if the address hasn't changed
|
|
||||||
if srv.srv.Addr == options.HTTPRedirectAddr {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// close immediately, don't care about the error
|
|
||||||
_ = srv.srv.Close()
|
|
||||||
srv.srv = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if options.HTTPRedirectAddr == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
redirect := httputil.RedirectHandler()
|
|
||||||
|
|
||||||
hsrv := &http.Server{
|
|
||||||
Addr: options.HTTPRedirectAddr,
|
|
||||||
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if AutocertManager.HandleHTTPChallenge(w, r) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
redirect.ServeHTTP(w, r)
|
|
||||||
}),
|
|
||||||
}
|
|
||||||
go func() {
|
|
||||||
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
|
||||||
err := hsrv.ListenAndServe()
|
|
||||||
if err != nil {
|
|
||||||
log.Error().Err(err).Msg("failed to run http redirect server")
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
srv.srv = hsrv
|
|
||||||
}
|
|
1
go.mod
1
go.mod
|
@ -28,6 +28,7 @@ require (
|
||||||
github.com/hashicorp/memberlist v0.2.2
|
github.com/hashicorp/memberlist v0.2.2
|
||||||
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
github.com/kardianos/osext v0.0.0-20190222173326-2bc1f35cddc0 // indirect
|
||||||
github.com/lithammer/shortuuid/v3 v3.0.4
|
github.com/lithammer/shortuuid/v3 v3.0.4
|
||||||
|
github.com/mitchellh/copystructure v1.0.0
|
||||||
github.com/mitchellh/hashstructure v1.0.0
|
github.com/mitchellh/hashstructure v1.0.0
|
||||||
github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09
|
github.com/natefinch/atomic v0.0.0-20200526193002-18c0533a5b09
|
||||||
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
github.com/nsf/jsondiff v0.0.0-20200515183724-f29ed568f4ce
|
||||||
|
|
4
go.sum
4
go.sum
|
@ -342,6 +342,8 @@ github.com/miekg/dns v1.1.26/go.mod h1:bPDLeHnStXmXAq1m/Ch/hvfNHr14JKNPMBo3VZKju
|
||||||
github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM=
|
github.com/miekg/dns v1.1.27 h1:aEH/kqUzUxGJ/UHcEKdJY+ugH6WEzsEBBSPa8zuy1aM=
|
||||||
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
github.com/miekg/dns v1.1.27/go.mod h1:KNUDUusw/aVsxyTYZM1oqvCicbwhgbNgztCETuNZ7xM=
|
||||||
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
|
github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc=
|
||||||
|
github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ=
|
||||||
|
github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw=
|
||||||
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||||
github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
|
github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI=
|
||||||
|
@ -353,6 +355,8 @@ github.com/mitchellh/iochan v1.0.0/go.mod h1:JwYml1nuB7xOzsp52dPpHFffvOCDupsG0Qu
|
||||||
github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||||
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
|
github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE=
|
||||||
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y=
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY=
|
||||||
|
github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||||
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||||
|
|
198
internal/autocert/manager.go
Normal file
198
internal/autocert/manager.go
Normal file
|
@ -0,0 +1,198 @@
|
||||||
|
// Package autocert implements automatic management of TLS certificates.
|
||||||
|
package autocert
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"sort"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/caddyserver/certmagic"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager manages TLS certificates.
|
||||||
|
type Manager struct {
|
||||||
|
src config.Source
|
||||||
|
|
||||||
|
mu sync.RWMutex
|
||||||
|
config *config.Config
|
||||||
|
certmagic *certmagic.Config
|
||||||
|
acmeMgr *certmagic.ACMEManager
|
||||||
|
srv *http.Server
|
||||||
|
|
||||||
|
config.ChangeDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new autocert manager.
|
||||||
|
func New(src config.Source) (*Manager, error) {
|
||||||
|
mgr := &Manager{
|
||||||
|
src: src,
|
||||||
|
certmagic: certmagic.NewDefault(),
|
||||||
|
}
|
||||||
|
err := mgr.update(src.GetConfig())
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
mgr.src.OnConfigChange(func(cfg *config.Config) {
|
||||||
|
err := mgr.update(cfg)
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("autocert: error updating config")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr.Trigger(mgr.GetConfig())
|
||||||
|
})
|
||||||
|
return mgr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *Manager) getCertMagicConfig(options *config.Options) (*certmagic.Config, error) {
|
||||||
|
mgr.certmagic.MustStaple = options.AutocertOptions.MustStaple
|
||||||
|
mgr.certmagic.OnDemand = nil // disable on-demand
|
||||||
|
mgr.certmagic.Storage = &certmagic.FileStorage{Path: options.AutocertOptions.Folder}
|
||||||
|
// add existing certs to the cache, and staple OCSP
|
||||||
|
for _, cert := range options.Certificates {
|
||||||
|
if err := mgr.certmagic.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
||||||
|
return nil, fmt.Errorf("config: failed caching cert: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
acmeMgr := certmagic.NewACMEManager(mgr.certmagic, certmagic.DefaultACME)
|
||||||
|
acmeMgr.Agreed = true
|
||||||
|
if options.AutocertOptions.UseStaging {
|
||||||
|
acmeMgr.CA = certmagic.LetsEncryptStagingCA
|
||||||
|
}
|
||||||
|
acmeMgr.DisableTLSALPNChallenge = true
|
||||||
|
mgr.certmagic.Issuer = acmeMgr
|
||||||
|
mgr.acmeMgr = acmeMgr
|
||||||
|
|
||||||
|
return mgr.certmagic, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *Manager) update(cfg *config.Config) error {
|
||||||
|
cfg = cfg.Clone()
|
||||||
|
|
||||||
|
mgr.mu.Lock()
|
||||||
|
defer mgr.mu.Unlock()
|
||||||
|
defer func() { mgr.config = cfg }()
|
||||||
|
|
||||||
|
mgr.updateServer(cfg)
|
||||||
|
return mgr.updateAutocert(cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *Manager) updateAutocert(cfg *config.Config) error {
|
||||||
|
if !cfg.Options.AutocertOptions.Enable {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cm, err := mgr.getCertMagicConfig(cfg.Options)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, domain := range sourceHostnames(cfg) {
|
||||||
|
cert, err := cm.CacheManagedCertificate(domain)
|
||||||
|
if err != nil {
|
||||||
|
log.Info().Str("domain", domain).Msg("obtaining certificate")
|
||||||
|
err = cm.ObtainCert(context.Background(), domain, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("autocert: failed to obtain client certificate: %w", err)
|
||||||
|
}
|
||||||
|
cert, err = cm.CacheManagedCertificate(domain)
|
||||||
|
}
|
||||||
|
if err == nil && cert.NeedsRenewal(cm) {
|
||||||
|
log.Info().Str("domain", domain).Msg("renewing certificate")
|
||||||
|
err = cm.RenewCert(context.Background(), domain, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("autocert: failed to renew client certificate: %w", err)
|
||||||
|
}
|
||||||
|
cert, err = cm.CacheManagedCertificate(domain)
|
||||||
|
}
|
||||||
|
if err == nil {
|
||||||
|
cfg.Options.Certificates = append(cfg.Options.Certificates, cert.Certificate)
|
||||||
|
} else {
|
||||||
|
log.Error().Err(err).Msg("autocert: failed to obtain client certificate")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *Manager) updateServer(cfg *config.Config) {
|
||||||
|
if mgr.srv != nil {
|
||||||
|
// nothing to do if the address hasn't changed
|
||||||
|
if mgr.srv.Addr == cfg.Options.HTTPRedirectAddr {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// close immediately, don't care about the error
|
||||||
|
_ = mgr.srv.Close()
|
||||||
|
mgr.srv = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Options.HTTPRedirectAddr == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
redirect := httputil.RedirectHandler()
|
||||||
|
|
||||||
|
hsrv := &http.Server{
|
||||||
|
Addr: cfg.Options.HTTPRedirectAddr,
|
||||||
|
Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
if mgr.handleHTTPChallenge(w, r) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
redirect.ServeHTTP(w, r)
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
log.Info().Str("addr", hsrv.Addr).Msg("starting http redirect server")
|
||||||
|
err := hsrv.ListenAndServe()
|
||||||
|
if err != nil {
|
||||||
|
log.Error().Err(err).Msg("failed to run http redirect server")
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
mgr.srv = hsrv
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mgr *Manager) handleHTTPChallenge(w http.ResponseWriter, r *http.Request) bool {
|
||||||
|
mgr.mu.RLock()
|
||||||
|
acmeMgr := mgr.acmeMgr
|
||||||
|
mgr.mu.RUnlock()
|
||||||
|
if acmeMgr == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return acmeMgr.HandleHTTPChallenge(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig gets the config.
|
||||||
|
func (mgr *Manager) GetConfig() *config.Config {
|
||||||
|
mgr.mu.RLock()
|
||||||
|
defer mgr.mu.RUnlock()
|
||||||
|
|
||||||
|
return mgr.config
|
||||||
|
}
|
||||||
|
|
||||||
|
func sourceHostnames(cfg *config.Config) []string {
|
||||||
|
if len(cfg.Options.Policies) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
dedupe := map[string]struct{}{}
|
||||||
|
for _, p := range cfg.Options.Policies {
|
||||||
|
dedupe[p.Source.Hostname()] = struct{}{}
|
||||||
|
}
|
||||||
|
if cfg.Options.AuthenticateURL != nil {
|
||||||
|
dedupe[cfg.Options.AuthenticateURL.Hostname()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
var h []string
|
||||||
|
for k := range dedupe {
|
||||||
|
h = append(h, k)
|
||||||
|
}
|
||||||
|
sort.Strings(h)
|
||||||
|
|
||||||
|
return h
|
||||||
|
}
|
|
@ -11,8 +11,6 @@ import (
|
||||||
"sync"
|
"sync"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/telemetry"
|
|
||||||
|
|
||||||
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2"
|
||||||
"golang.org/x/sync/errgroup"
|
"golang.org/x/sync/errgroup"
|
||||||
|
|
||||||
|
@ -20,10 +18,12 @@ import (
|
||||||
"github.com/pomerium/pomerium/authorize"
|
"github.com/pomerium/pomerium/authorize"
|
||||||
"github.com/pomerium/pomerium/cache"
|
"github.com/pomerium/pomerium/cache"
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/autocert"
|
||||||
"github.com/pomerium/pomerium/internal/controlplane"
|
"github.com/pomerium/pomerium/internal/controlplane"
|
||||||
"github.com/pomerium/pomerium/internal/envoy"
|
"github.com/pomerium/pomerium/internal/envoy"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
"github.com/pomerium/pomerium/internal/telemetry"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -33,31 +33,36 @@ import (
|
||||||
|
|
||||||
// Run runs the main pomerium application.
|
// Run runs the main pomerium application.
|
||||||
func Run(ctx context.Context, configFile string) error {
|
func Run(ctx context.Context, configFile string) error {
|
||||||
opt, err := config.NewOptionsFromConfig(configFile)
|
var src config.Source
|
||||||
|
|
||||||
|
src, err := config.NewFileOrEnvironmentSource(configFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var optionsUpdaters []config.OptionsUpdater
|
|
||||||
|
src, err = autocert.New(src)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg := src.GetConfig()
|
||||||
|
|
||||||
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
|
log.Info().Str("version", version.FullVersion()).Msg("cmd/pomerium")
|
||||||
|
|
||||||
if err := setupMetrics(ctx, opt); err != nil {
|
if err := setupMetrics(ctx, cfg.Options); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if err := setupTracing(ctx, opt); err != nil {
|
if err := setupTracing(ctx, cfg.Options); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// setup the control plane
|
// setup the control plane
|
||||||
controlPlane, err := controlplane.NewServer(opt.Services)
|
controlPlane, err := controlplane.NewServer(cfg.Options.Services)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating control plane: %w", err)
|
return fmt.Errorf("error creating control plane: %w", err)
|
||||||
}
|
}
|
||||||
optionsUpdaters = append(optionsUpdaters, controlPlane)
|
src.OnConfigChange(controlPlane.OnConfigChange)
|
||||||
err = controlPlane.UpdateOptions(*opt)
|
controlPlane.OnConfigChange(cfg)
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error updating control plane options: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String())
|
_, grpcPort, _ := net.SplitHostPort(controlPlane.GRPCListener.Addr().String())
|
||||||
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String())
|
_, httpPort, _ := net.SplitHostPort(controlPlane.HTTPListener.Addr().String())
|
||||||
|
@ -66,36 +71,33 @@ func Run(ctx context.Context, configFile string) error {
|
||||||
log.Info().Str("port", httpPort).Msg("HTTP server started")
|
log.Info().Str("port", httpPort).Msg("HTTP server started")
|
||||||
|
|
||||||
// create envoy server
|
// create envoy server
|
||||||
envoyServer, err := envoy.NewServer(opt, grpcPort, httpPort)
|
envoyServer, err := envoy.NewServer(cfg.Options, grpcPort, httpPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating envoy server: %w", err)
|
return fmt.Errorf("error creating envoy server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// add services
|
// add services
|
||||||
if err := setupAuthenticate(opt, controlPlane, &optionsUpdaters); err != nil {
|
if err := setupAuthenticate(src, cfg, controlPlane); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
var authorizeServer *authorize.Authorize
|
var authorizeServer *authorize.Authorize
|
||||||
if config.IsAuthorize(opt.Services) {
|
if config.IsAuthorize(cfg.Options.Services) {
|
||||||
authorizeServer, err = setupAuthorize(opt, controlPlane, &optionsUpdaters)
|
authorizeServer, err = setupAuthorize(src, cfg, controlPlane)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var cacheServer *cache.Cache
|
var cacheServer *cache.Cache
|
||||||
if config.IsCache(opt.Services) {
|
if config.IsCache(cfg.Options.Services) {
|
||||||
cacheServer, err = setupCache(opt, controlPlane)
|
cacheServer, err = setupCache(cfg.Options, controlPlane)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := setupProxy(opt, controlPlane); err != nil {
|
if err := setupProxy(cfg.Options, controlPlane); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// start the config change listener
|
|
||||||
go config.WatchChanges(configFile, opt, optionsUpdaters)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
go func(ctx context.Context) {
|
go func(ctx context.Context) {
|
||||||
ch := make(chan os.Signal, 2)
|
ch := make(chan os.Signal, 2)
|
||||||
|
@ -132,21 +134,21 @@ func Run(ctx context.Context, configFile string) error {
|
||||||
return eg.Wait()
|
return eg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) error {
|
func setupAuthenticate(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) error {
|
||||||
if !config.IsAuthenticate(opt.Services) {
|
if !config.IsAuthenticate(cfg.Options.Services) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
svc, err := authenticate.New(*opt)
|
svc, err := authenticate.New(cfg.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating authenticate service: %w", err)
|
return fmt.Errorf("error creating authenticate service: %w", err)
|
||||||
}
|
}
|
||||||
*optionsUpdaters = append(*optionsUpdaters, svc)
|
src.OnConfigChange(svc.OnConfigChange)
|
||||||
err = svc.UpdateOptions(*opt)
|
svc.OnConfigChange(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error updating authenticate options: %w", err)
|
return fmt.Errorf("error updating authenticate options: %w", err)
|
||||||
}
|
}
|
||||||
host := urlutil.StripPort(opt.GetAuthenticateURL().Host)
|
host := urlutil.StripPort(cfg.Options.GetAuthenticateURL().Host)
|
||||||
sr := controlPlane.HTTPRouter.Host(host).Subrouter()
|
sr := controlPlane.HTTPRouter.Host(host).Subrouter()
|
||||||
svc.Mount(sr)
|
svc.Mount(sr)
|
||||||
log.Info().Str("host", host).Msg("enabled authenticate service")
|
log.Info().Str("host", host).Msg("enabled authenticate service")
|
||||||
|
@ -154,20 +156,16 @@ func setupAuthenticate(opt *config.Options, controlPlane *controlplane.Server, o
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func setupAuthorize(opt *config.Options, controlPlane *controlplane.Server, optionsUpdaters *[]config.OptionsUpdater) (*authorize.Authorize, error) {
|
func setupAuthorize(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) (*authorize.Authorize, error) {
|
||||||
svc, err := authorize.New(*opt)
|
svc, err := authorize.New(cfg.Options)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
return nil, fmt.Errorf("error creating authorize service: %w", err)
|
||||||
}
|
}
|
||||||
envoy_service_auth_v2.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
|
envoy_service_auth_v2.RegisterAuthorizationServer(controlPlane.GRPCServer, svc)
|
||||||
|
|
||||||
log.Info().Msg("enabled authorize service")
|
log.Info().Msg("enabled authorize service")
|
||||||
|
src.OnConfigChange(svc.OnConfigChange)
|
||||||
*optionsUpdaters = append(*optionsUpdaters, svc)
|
svc.OnConfigChange(cfg)
|
||||||
err = svc.UpdateOptions(*opt)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("error updating authorize options: %w", err)
|
|
||||||
}
|
|
||||||
return svc, nil
|
return svc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -141,17 +141,16 @@ func (srv *Server) Run(ctx context.Context) error {
|
||||||
return eg.Wait()
|
return eg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateOptions updates the pomerium config options.
|
// OnConfigChange updates the pomerium config options.
|
||||||
func (srv *Server) UpdateOptions(options config.Options) error {
|
func (srv *Server) OnConfigChange(cfg *config.Config) {
|
||||||
select {
|
select {
|
||||||
case <-srv.configUpdated:
|
case <-srv.configUpdated:
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
prev := srv.currentConfig.Load()
|
prev := srv.currentConfig.Load()
|
||||||
srv.currentConfig.Store(versionedOptions{
|
srv.currentConfig.Store(versionedOptions{
|
||||||
Options: options,
|
Options: *cfg.Options,
|
||||||
version: prev.version + 1,
|
version: prev.version + 1,
|
||||||
})
|
})
|
||||||
srv.configUpdated <- struct{}{}
|
srv.configUpdated <- struct{}{}
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Reference in a new issue