diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 7c343f0c4..7f57fdee6 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -110,6 +110,8 @@ type Authenticate struct { jwk *jose.JSONWebKeySet templates *template.Template + + options *config.AtomicOptions } // New validates and creates a new authenticate service from a set of Options. @@ -138,11 +140,6 @@ func New(opts *config.Options) (*Authenticate, error) { Expire: opts.CookieExpire, } - cookieStore, err := cookie.NewStore(cookieOptions, sharedEncoder) - if err != nil { - return nil, err - } - dataBrokerConn, err := grpc.NewGRPCClientConn( &grpc.Options{ Addr: opts.DataBrokerURL, @@ -192,9 +189,7 @@ func New(opts *config.Options) (*Authenticate, error) { cookieSecret: decodedCookieSecret, cookieCipher: cookieCipher, cookieOptions: cookieOptions, - sessionStore: cookieStore, encryptedEncoder: encryptedEncoder, - sessionLoaders: []sessions.SessionLoader{qpStore, headerStore, cookieStore}, // IdP provider: provider, providerName: opts.Provider, @@ -202,8 +197,26 @@ func New(opts *config.Options) (*Authenticate, error) { dataBrokerClient: dataBrokerClient, jwk: &jose.JSONWebKeySet{}, templates: template.Must(frontend.NewTemplates()), + options: config.NewAtomicOptions(), } + cookieStore, err := cookie.NewStore(func() cookie.Options { + opts := a.options.Load() + return cookie.Options{ + Name: opts.CookieName, + Domain: opts.CookieDomain, + Secure: opts.CookieSecure, + HTTPOnly: opts.CookieHTTPOnly, + Expire: opts.CookieExpire, + } + }, sharedEncoder) + if err != nil { + return nil, err + } + + a.sessionStore = cookieStore + a.sessionLoaders = []sessions.SessionLoader{qpStore, headerStore, cookieStore} + if opts.SigningKey != "" { decodedCert, err := base64.StdEncoding.DecodeString(opts.SigningKey) if err != nil { @@ -236,5 +249,6 @@ func (a *Authenticate) OnConfigChange(cfg *config.Config) { } log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("authenticate: updating options") + a.options.Store(cfg.Options) a.setAdminUsers(cfg.Options) } diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 0af8de710..1e77a4d34 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -86,9 +86,6 @@ func TestNew(t *testing.T) { badRedirectURL.AuthenticateURL = nil badRedirectURL.CookieName = "B" - badCookieName := newTestOptions(t) - badCookieName.CookieName = "" - badProvider := newTestOptions(t) badProvider.Provider = "" badProvider.CookieName = "C" @@ -118,7 +115,6 @@ func TestNew(t *testing.T) { {"good", good, false}, {"empty opts", &config.Options{}, true}, {"fails to validate", badRedirectURL, true}, - {"bad cookie name", badCookieName, true}, {"bad provider", badProvider, true}, {"bad cache url", badGRPCConn, true}, {"empty provider url", emptyProviderURL, true}, diff --git a/authorize/authorize.go b/authorize/authorize.go index bfa4b9505..f1e4a4ac1 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -23,18 +23,6 @@ import ( "github.com/pomerium/pomerium/pkg/grpc/databroker" ) -type atomicOptions struct { - value atomic.Value -} - -func (a *atomicOptions) Load() *config.Options { - return a.value.Load().(*config.Options) -} - -func (a *atomicOptions) Store(options *config.Options) { - a.value.Store(options) -} - type atomicMarshalUnmarshaler struct { value atomic.Value } @@ -52,7 +40,7 @@ type Authorize struct { pe *evaluator.Evaluator store *evaluator.Store - currentOptions atomicOptions + currentOptions *config.AtomicOptions currentEncoder atomicMarshalUnmarshaler templates *template.Template @@ -84,6 +72,7 @@ func New(opts *config.Options) (*Authorize, error) { } a := Authorize{ + currentOptions: config.NewAtomicOptions(), store: evaluator.NewStore(), templates: template.Must(frontend.NewTemplates()), dataBrokerClient: databroker.NewDataBrokerServiceClient(dataBrokerConn), @@ -99,7 +88,6 @@ func New(opts *config.Options) (*Authorize, error) { return nil, err } a.currentEncoder.Store(encoder) - a.currentOptions.Store(new(config.Options)) return &a, nil } diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index cddeca576..773d0bdb2 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -35,7 +35,7 @@ func TestAuthorize_okResponse(t *testing.T) { }}, JWTClaimsHeaders: []string{"email"}, } - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") a.currentEncoder.Store(encoder) a.currentOptions.Store(opt) @@ -204,7 +204,7 @@ func TestAuthorize_okResponse(t *testing.T) { } func TestAuthorize_deniedResponse(t *testing.T) { - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") a.currentEncoder.Store(encoder) a.currentOptions.Store(&config.Options{ diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index c061c0939..1be7ea1f3 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -47,7 +47,7 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` func Test_getEvaluatorRequest(t *testing.T) { - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") a.currentEncoder.Store(encoder) a.currentOptions.Store(&config.Options{ @@ -273,7 +273,7 @@ func Test_handleForwardAuth(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} var fau *url.URL if tc.forwardAuthURL != "" { fau = mustParseURL(tc.forwardAuthURL) @@ -288,7 +288,7 @@ func Test_handleForwardAuth(t *testing.T) { } func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") a.currentEncoder.Store(encoder) a.currentOptions.Store(&config.Options{ diff --git a/authorize/session.go b/authorize/session.go index 8aa1e2eb5..931a62bb1 100644 --- a/authorize/session.go +++ b/authorize/session.go @@ -52,14 +52,15 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions. } func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { - cookieOptions := &cookie.Options{ - Name: options.CookieName, - Domain: options.CookieDomain, - Secure: options.CookieSecure, - HTTPOnly: options.CookieHTTPOnly, - Expire: options.CookieExpire, - } - cookieStore, err := cookie.NewStore(cookieOptions, encoder) + cookieStore, err := cookie.NewStore(func() cookie.Options { + return cookie.Options{ + Name: options.CookieName, + Domain: options.CookieDomain, + Secure: options.CookieSecure, + HTTPOnly: options.CookieHTTPOnly, + Expire: options.CookieExpire, + } + }, encoder) if err != nil { return nil, err } diff --git a/authorize/session_test.go b/authorize/session_test.go index 4cb655e63..6c8856cd6 100644 --- a/authorize/session_test.go +++ b/authorize/session_test.go @@ -116,7 +116,7 @@ func TestAuthorize_getJWTClaimHeaders(t *testing.T) { }}, }}, } - a := new(Authorize) + a := &Authorize{currentOptions: config.NewAtomicOptions()} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}, "") a.currentEncoder.Store(encoder) a.currentOptions.Store(opt) diff --git a/config/options.go b/config/options.go index 4f0200d31..9e0b7d04a 100644 --- a/config/options.go +++ b/config/options.go @@ -12,6 +12,7 @@ import ( "reflect" "sort" "strings" + "sync/atomic" "time" "github.com/cespare/xxhash/v2" @@ -986,3 +987,25 @@ func min(x, y int) int { } return y } + +// AtomicOptions are Options that can be access atomically. +type AtomicOptions struct { + value atomic.Value +} + +// NewAtomicOptions creates a new AtomicOptions. +func NewAtomicOptions() *AtomicOptions { + ao := new(AtomicOptions) + ao.Store(new(Options)) + return ao +} + +// Load loads the options. +func (a *AtomicOptions) Load() *Options { + return a.value.Load().(*Options) +} + +// Store stores the options. +func (a *AtomicOptions) Store(options *Options) { + a.value.Store(options) +} diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index f5bae8ed6..1a5f5d12d 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -93,7 +93,7 @@ func Run(ctx context.Context, configFile string) error { return err } } - if err := setupProxy(cfg.Options, controlPlane); err != nil { + if err := setupProxy(src, cfg, controlPlane); err != nil { return err } @@ -172,15 +172,20 @@ func setupCache(opt *config.Options, controlPlane *controlplane.Server) (*cache. return svc, nil } -func setupProxy(opt *config.Options, controlPlane *controlplane.Server) error { - if !config.IsProxy(opt.Services) { +func setupProxy(src config.Source, cfg *config.Config, controlPlane *controlplane.Server) error { + if !config.IsProxy(cfg.Options.Services) { return nil } - svc, err := proxy.New(*opt) + svc, err := proxy.New(cfg.Options) if err != nil { return fmt.Errorf("error creating proxy service: %w", err) } controlPlane.HTTPRouter.PathPrefix("/").Handler(svc) + + log.Info().Msg("enabled proxy service") + src.OnConfigChange(svc.OnConfigChange) + svc.OnConfigChange(cfg) + return nil } diff --git a/internal/sessions/cookie/cookie_store.go b/internal/sessions/cookie/cookie_store.go index 1703a7a2d..dc13ca905 100644 --- a/internal/sessions/cookie/cookie_store.go +++ b/internal/sessions/cookie/cookie_store.go @@ -33,18 +33,6 @@ const ( MaxNumChunks = 5 ) -// Store implements the session store interface for session cookies. -type Store struct { - Name string - Domain string - Expire time.Duration - HTTPOnly bool - Secure bool - - encoder encoding.Marshaler - decoder encoding.Unmarshaler -} - // Options holds options for Store type Options struct { Name string @@ -54,10 +42,20 @@ type Options struct { Secure bool } +// A GetOptionsFunc is a getter for cookie options. +type GetOptionsFunc func() Options + +// Store implements the session store interface for session cookies. +type Store struct { + getOptions GetOptionsFunc + encoder encoding.Marshaler + decoder encoding.Unmarshaler +} + // NewStore returns a new store that implements the SessionStore interface // using http cookies. -func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { - cs, err := NewCookieLoader(opts, encoder) +func NewStore(getOptions GetOptionsFunc, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { + cs, err := NewCookieLoader(getOptions, encoder) if err != nil { return nil, err } @@ -67,41 +65,31 @@ func NewStore(opts *Options, encoder encoding.MarshalUnmarshaler) (sessions.Sess // NewCookieLoader returns a new store that implements the SessionLoader // interface using http cookies. -func NewCookieLoader(opts *Options, dencoder encoding.Unmarshaler) (*Store, error) { +func NewCookieLoader(getOptions GetOptionsFunc, dencoder encoding.Unmarshaler) (*Store, error) { if dencoder == nil { return nil, fmt.Errorf("internal/sessions: dencoder cannot be nil") } - cs, err := newStore(opts) - if err != nil { - return nil, err - } + cs := newStore(getOptions) cs.decoder = dencoder return cs, nil } -func newStore(opts *Options) (*Store, error) { - if opts.Name == "" { - return nil, fmt.Errorf("internal/sessions: cookie name cannot be empty") - } - +func newStore(getOptions GetOptionsFunc) *Store { return &Store{ - Name: opts.Name, - Secure: opts.Secure, - HTTPOnly: opts.HTTPOnly, - Domain: opts.Domain, - Expire: opts.Expire, - }, nil + getOptions: getOptions, + } } func (cs *Store) makeCookie(value string) *http.Cookie { + opts := cs.getOptions() return &http.Cookie{ - Name: cs.Name, + Name: opts.Name, Value: value, Path: "/", - Domain: cs.Domain, - HttpOnly: cs.HTTPOnly, - Secure: cs.Secure, - Expires: timeNow().Add(cs.Expire), + Domain: opts.Domain, + HttpOnly: opts.HTTPOnly, + Secure: opts.Secure, + Expires: timeNow().Add(opts.Expire), } } @@ -126,7 +114,8 @@ func getCookies(r *http.Request, name string) []*http.Cookie { // LoadSession returns a State from the cookie in the request. func (cs *Store) LoadSession(r *http.Request) (string, error) { - cookies := getCookies(r, cs.Name) + opts := cs.getOptions() + cookies := getCookies(r, opts.Name) if len(cookies) == 0 { return "", sessions.ErrNoSessionFound } diff --git a/internal/sessions/cookie/cookie_store_test.go b/internal/sessions/cookie/cookie_store_test.go index d84859a1a..a4a4fb75d 100644 --- a/internal/sessions/cookie/cookie_store_test.go +++ b/internal/sessions/cookie/cookie_store_test.go @@ -32,13 +32,16 @@ func TestNewStore(t *testing.T) { want sessions.SessionStore wantErr bool }{ - {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, - {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, + {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{getOptions: func() Options { + return Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second} + }}, false}, {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewStore(tt.opts, tt.encoder) + got, err := NewStore(func() Options { + return *tt.opts + }, tt.encoder) if (err != nil) != tt.wantErr { t.Errorf("NewStore() error = %v, wantErr %v", err, tt.wantErr) return @@ -66,13 +69,16 @@ func TestNewCookieLoader(t *testing.T) { want *Store wantErr bool }{ - {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, false}, - {"missing name", &Options{Name: "", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, nil, true}, + {"good", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, encoder, &Store{getOptions: func() Options { + return Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second} + }}, false}, {"missing encoder", &Options{Name: "_cookie", Secure: true, HTTPOnly: true, Domain: "pomerium.io", Expire: 10 * time.Second}, nil, nil, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewCookieLoader(tt.opts, tt.encoder) + got, err := NewCookieLoader(func() Options { + return *tt.opts + }, tt.encoder) if (err != nil) != tt.wantErr { t.Errorf("NewCookieLoader() error = %v, wantErr %v", err, tt.wantErr) return @@ -117,13 +123,17 @@ func TestStore_SaveSession(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := &Store{ - Name: "_pomerium", - Secure: true, - HTTPOnly: true, - Domain: "pomerium.io", - Expire: 10 * time.Second, - encoder: tt.encoder, - decoder: tt.decoder, + getOptions: func() Options { + return Options{ + Name: "_pomerium", + Secure: true, + HTTPOnly: true, + Domain: "pomerium.io", + Expire: 10 * time.Second, + } + }, + encoder: tt.encoder, + decoder: tt.decoder, } r := httptest.NewRequest("GET", "/", nil) diff --git a/internal/sessions/cookie/middleware_test.go b/internal/sessions/cookie/middleware_test.go index 2b49bd680..90bba2f42 100644 --- a/internal/sessions/cookie/middleware_test.go +++ b/internal/sessions/cookie/middleware_test.go @@ -65,8 +65,10 @@ func TestVerifier(t *testing.T) { encSession = append(encSession, cryptutil.NewKey()...) } - cs, err := NewStore(&Options{ - Name: "_pomerium", + cs, err := NewStore(func() Options { + return Options{ + Name: "_pomerium", + } }, encoder) if err != nil { t.Fatal(err) diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go index 305b869a6..86e0ae974 100644 --- a/proxy/forward_auth_test.go +++ b/proxy/forward_auth_test.go @@ -47,7 +47,7 @@ func TestProxy_ForwardAuth(t *testing.T) { opts := testOptions(t) tests := []struct { name string - options config.Options + options *config.Options ctxError error method string @@ -94,7 +94,7 @@ func TestProxy_ForwardAuth(t *testing.T) { t.Fatal(err) } p.encoder = signer - p.UpdateOptions(tt.options) + p.OnConfigChange(&config.Config{Options: tt.options}) uri, err := url.Parse(tt.requestURI) if err != nil { t.Fatal(err) diff --git a/proxy/handlers.go b/proxy/handlers.go index 27f0a8e09..a116e858e 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -25,12 +25,15 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { // 2. AuthN - Verify the user is authenticated. Set email, group, & id headers h.Use(p.AuthenticateSession) // 3. Enforce CSRF protections for any non-idempotent http method - h.Use(csrf.Protect( - p.cookieSecret, - csrf.Secure(p.cookieOptions.Secure), - csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)), - csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)), - )) + h.Use(func(h http.Handler) http.Handler { + opts := p.currentOptions.Load() + return csrf.Protect( + p.cookieSecret, + csrf.Secure(opts.CookieSecure), + csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)), + csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)), + )(h) + }) // dashboard endpoints can be used by user's to view, or modify their session h.Path("/").HandlerFunc(p.UserDashboard).Methods(http.MethodGet) h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost) diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 8a20fb22e..d58df1bf4 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -118,7 +118,7 @@ func TestProxy_Callback(t *testing.T) { opts := testOptions(t) tests := []struct { name string - options config.Options + options *config.Options method string @@ -227,7 +227,7 @@ func TestProxy_Callback(t *testing.T) { } p.encoder = tt.cipher p.sessionStore = tt.sessionStore - p.UpdateOptions(tt.options) + p.OnConfigChange(&config.Config{Options: tt.options}) redirectURI := &url.URL{Scheme: tt.scheme, Host: tt.host, Path: tt.path} queryString := redirectURI.Query() for k, v := range tt.qp { @@ -276,7 +276,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { opts := testOptions(t) tests := []struct { name string - options config.Options + options *config.Options method string @@ -337,7 +337,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { opts := testOptions(t) tests := []struct { name string - options config.Options + options *config.Options method string @@ -434,7 +434,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { } p.encoder = tt.cipher p.sessionStore = tt.sessionStore - p.UpdateOptions(tt.options) + p.OnConfigChange(&config.Config{Options: tt.options}) redirectURI, _ := url.Parse(tt.redirectURI) queryString := redirectURI.Query() for k, v := range tt.qp { diff --git a/proxy/proxy.go b/proxy/proxy.go index 3a1fa3c59..add53e860 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -43,7 +43,7 @@ const ( // ValidateOptions checks that proper configuration settings are set to create // a proper Proxy instance -func ValidateOptions(o config.Options) error { +func ValidateOptions(o *config.Options) error { if _, err := cryptutil.NewAEADCipherFromBase64(o.SharedKey); err != nil { return fmt.Errorf("proxy: invalid 'SHARED_SECRET': %w", err) } @@ -76,7 +76,6 @@ type Proxy struct { authenticateRefreshURL *url.URL encoder encoding.Unmarshaler - cookieOptions *cookie.Options cookieSecret []byte refreshCooldown time.Duration sessionStore sessions.SessionStore @@ -85,12 +84,13 @@ type Proxy struct { jwtClaimHeaders []string authzClient envoy_service_auth_v2.AuthorizationClient - currentRouter atomic.Value + currentOptions *config.AtomicOptions + currentRouter atomic.Value } // New takes a Proxy service from options and a validation function. // Function returns an error if options fail to validate. -func New(opts config.Options) (*Proxy, error) { +func New(opts *config.Options) (*Proxy, error) { if err := ValidateOptions(opts); err != nil { return nil, err } @@ -104,34 +104,16 @@ func New(opts config.Options) (*Proxy, error) { return nil, err } - cookieOptions := &cookie.Options{ - Name: opts.CookieName, - Domain: opts.CookieDomain, - Secure: opts.CookieSecure, - HTTPOnly: opts.CookieHTTPOnly, - Expire: opts.CookieExpire, - } - - cookieStore, err := cookie.NewStore(cookieOptions, encoder) - if err != nil { - return nil, err - } - p := &Proxy{ SharedKey: opts.SharedKey, sharedCipher: sharedCipher, encoder: encoder, cookieSecret: decodedCookieSecret, - cookieOptions: cookieOptions, refreshCooldown: opts.RefreshCooldown, - sessionStore: cookieStore, - sessionLoaders: []sessions.SessionLoader{ - cookieStore, - header.NewStore(encoder, httputil.AuthorizationTypePomerium), - queryparam.NewStore(encoder, "pomerium_session")}, templates: template.Must(frontend.NewTemplates()), jwtClaimHeaders: opts.JWTClaimsHeaders, + currentOptions: config.NewAtomicOptions(), } p.currentRouter.Store(httputil.NewRouter()) // errors checked in ValidateOptions @@ -142,6 +124,25 @@ func New(opts config.Options) (*Proxy, error) { p.authenticateSignoutURL = p.authenticateURL.ResolveReference(&url.URL{Path: signoutURL}) p.authenticateRefreshURL = p.authenticateURL.ResolveReference(&url.URL{Path: refreshURL}) + cookieStore, err := cookie.NewStore(func() cookie.Options { + opts := p.currentOptions.Load() + return cookie.Options{ + Name: opts.CookieName, + Domain: opts.CookieDomain, + Secure: opts.CookieSecure, + HTTPOnly: opts.CookieHTTPOnly, + Expire: opts.CookieExpire, + } + }, encoder) + if err != nil { + return nil, err + } + p.sessionStore = cookieStore + p.sessionLoaders = []sessions.SessionLoader{ + cookieStore, + header.NewStore(encoder, httputil.AuthorizationTypePomerium), + queryparam.NewStore(encoder, "pomerium_session")} + authzConn, err := grpc.NewGRPCClientConn(&grpc.Options{ Addr: p.authorizeURL, OverrideCertificateName: opts.OverrideCertificateName, @@ -157,11 +158,6 @@ func New(opts config.Options) (*Proxy, error) { } p.authzClient = envoy_service_auth_v2.NewAuthorizationClient(authzConn) - err = p.UpdateOptions(opts) - if err != nil { - return nil, err - } - metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { return int64(len(opts.Policies)) }) @@ -169,14 +165,15 @@ func New(opts config.Options) (*Proxy, error) { return p, nil } -// UpdateOptions updates internal structures based on config.Options -func (p *Proxy) UpdateOptions(o config.Options) error { +// OnConfigChange updates internal structures based on config.Options +func (p *Proxy) OnConfigChange(cfg *config.Config) { if p == nil { - return nil + return } - log.Info().Str("checksum", fmt.Sprintf("%x", o.Checksum())).Msg("proxy: updating options") - p.setHandlers(&o) - return nil + + log.Info().Str("checksum", fmt.Sprintf("%x", cfg.Options.Checksum())).Msg("proxy: updating options") + p.currentOptions.Store(cfg.Options) + p.setHandlers(cfg.Options) } func (p *Proxy) setHandlers(opts *config.Options) { diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index ac65ae86c..57109a963 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -10,7 +10,7 @@ import ( "github.com/pomerium/pomerium/config" ) -func testOptions(t *testing.T) config.Options { +func testOptions(t *testing.T) *config.Options { opts := config.NewDefaultOptions() opts.AuthenticateURLString = "https://authenticate.example" opts.AuthorizeURLString = "https://authorize.example" @@ -26,7 +26,7 @@ func testOptions(t *testing.T) config.Options { if err != nil { t.Fatal(err) } - return *opts + return opts } func TestOptions_Validate(t *testing.T) { @@ -57,11 +57,11 @@ func TestOptions_Validate(t *testing.T) { tests := []struct { name string - o config.Options + o *config.Options wantErr bool }{ {"good - minimum options", good, false}, - {"nil options", config.Options{}, true}, + {"nil options", &config.Options{}, true}, {"authenticate service url", badAuthURL, true}, {"authenticate service url no scheme", authenticateBadScheme, true}, {"authorize service url no scheme", authorizeBadSCheme, true}, @@ -93,14 +93,13 @@ func TestNew(t *testing.T) { tests := []struct { name string - opts config.Options + opts *config.Options wantProxy bool wantErr bool }{ {"good", good, true, false}, - {"empty options", config.Options{}, false, true}, + {"empty options", &config.Options{}, false, true}, {"short secret/validate sanity check", shortCookieLength, false, true}, - {"invalid cookie name, empty", badCookie, false, true}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -169,8 +168,8 @@ func Test_UpdateOptions(t *testing.T) { tests := []struct { name string - originalOptions config.Options - updatedOptions config.Options + originalOptions *config.Options + updatedOptions *config.Options host string wantErr bool wantRoute bool @@ -198,26 +197,18 @@ func Test_UpdateOptions(t *testing.T) { t.Fatal(err) } - err = p.UpdateOptions(tt.updatedOptions) - if (err != nil) != tt.wantErr { - t.Errorf("UpdateOptions: err = %v, wantErr = %v", err, tt.wantErr) + p.OnConfigChange(&config.Config{Options: tt.updatedOptions}) + r := httptest.NewRequest("GET", tt.host, nil) + w := httptest.NewRecorder() + p.ServeHTTP(w, r) + if tt.wantRoute && w.Code != http.StatusNotFound { + t.Errorf("Failed to find route handler") return } - - // This is only safe if we actually can load policies - if err == nil { - r := httptest.NewRequest("GET", tt.host, nil) - w := httptest.NewRecorder() - p.ServeHTTP(w, r) - if tt.wantRoute && w.Code != http.StatusNotFound { - t.Errorf("Failed to find route handler") - return - } - } }) } // Test nil var p *Proxy - p.UpdateOptions(config.Options{}) + p.OnConfigChange(&config.Config{}) }