diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 206c0c3b6..60869adf6 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -39,18 +39,18 @@ func ValidateOptions(o *config.Options) error { // Authenticate contains data required to run the authenticate service. type Authenticate struct { - cfg *authenticateConfig - options *atomicutil.Value[*config.Options] - state *atomicutil.Value[*authenticateState] - webauthn *webauthn.Handler + cfg *authenticateConfig + currentConfig *atomicutil.Value[*config.Config] + state *atomicutil.Value[*authenticateState] + webauthn *webauthn.Handler } // New validates and creates a new authenticate service from a set of Options. func New(cfg *config.Config, options ...Option) (*Authenticate, error) { a := &Authenticate{ - cfg: getAuthenticateConfig(options...), - options: config.NewAtomicOptions(), - state: atomicutil.NewValue(newAuthenticateState()), + cfg: getAuthenticateConfig(options...), + currentConfig: atomicutil.NewValue(cfg), + state: atomicutil.NewValue(newAuthenticateState()), } a.webauthn = webauthn.New(a.getWebauthnState) @@ -69,7 +69,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) { return } - a.options.Store(cfg.Options) + a.currentConfig.Store(cfg) if state, err := newAuthenticateStateFromConfig(cfg); err != nil { log.Error(ctx).Err(err).Msg("authenticate: failed to update state") } else { diff --git a/authenticate/authenticate_test.go b/authenticate/authenticate_test.go index 255d6725a..0e254f62f 100644 --- a/authenticate/authenticate_test.go +++ b/authenticate/authenticate_test.go @@ -113,7 +113,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { tt := tt t.Run(tt.name, func(t *testing.T) { - _, err := New(&config.Config{Options: tt.opts}) + _, err := New(config.New(tt.opts)) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/authenticate/config.go b/authenticate/config.go index 73f0af1d6..a135dedbb 100644 --- a/authenticate/config.go +++ b/authenticate/config.go @@ -6,7 +6,7 @@ import ( ) type authenticateConfig struct { - getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error) + getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error) } // An Option customizes the Authenticate config. @@ -22,7 +22,7 @@ func getAuthenticateConfig(options ...Option) *authenticateConfig { } // WithGetIdentityProvider sets the getIdentityProvider function in the config. -func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option { +func WithGetIdentityProvider(getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)) Option { return func(cfg *authenticateConfig) { cfg.getIdentityProvider = getIdentityProvider } diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 0102107b5..e535bde04 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -47,12 +47,12 @@ func (a *Authenticate) Mount(r *mux.Router) { r.StrictSlash(true) r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) r.Use(func(h http.Handler) http.Handler { - options := a.options.Load() + cfg := a.currentConfig.Load() state := a.state.Load() - csrfKey := fmt.Sprintf("%s_csrf", options.CookieName) + csrfKey := fmt.Sprintf("%s_csrf", cfg.Options.CookieName) return csrf.Protect( state.cookieSecret, - csrf.Secure(options.CookieSecure), + csrf.Secure(cfg.Options.CookieSecure), csrf.Path("/"), csrf.UnsafePaths( []string{ @@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error { // check for an HMAC'd URL. If none is found, show a confirmation page. err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey) if err != nil { - authenticateURL, err := a.options.Load().GetAuthenticateURL() + authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL() if err != nil { return err } @@ -275,9 +275,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") defer span.End() - options := a.options.Load() + cfg := a.currentConfig.Load() - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID)) if err != nil { return err } @@ -285,7 +285,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e rawIDToken := a.revokeSession(ctx, w, r) redirectString := "" - signOutURL, err := a.options.Load().GetSignOutRedirectURL() + signOutURL, err := cfg.Options.GetSignOutRedirectURL() if err != nil { return err } @@ -330,10 +330,10 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque return httputil.NewError(http.StatusUnauthorized, err) } - options := a.options.Load() + cfg := a.currentConfig.Load() state := a.state.Load() - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID)) if err != nil { return err } @@ -381,7 +381,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") defer span.End() - options := a.options.Load() + cfg := a.currentConfig.Load() state := a.state.Load() // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 @@ -430,7 +430,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request) } idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) - idp, err := a.cfg.getIdentityProvider(options, idpID) + idp, err := a.cfg.getIdentityProvider(cfg, idpID) if err != nil { return nil, err } @@ -513,7 +513,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) { state := a.state.Load() - authenticateURL, err := a.options.Load().GetAuthenticateURL() + authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL() if err != nil { return handlers.UserInfoData{}, err } @@ -569,7 +569,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, WebAuthnRequestOptions: requestOptions, WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()), - BrandingOptions: a.options.Load().BrandingOptions, + BrandingOptions: a.currentConfig.Load().Options.BrandingOptions, }, nil } @@ -581,14 +581,14 @@ func (a *Authenticate) saveSessionToDataBroker( accessToken *oauth2.Token, ) error { state := a.state.Load() - options := a.options.Load() + cfg := a.currentConfig.Load() - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID)) if err != nil { return err } - sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire)) + sessionExpiry := timestamppb.New(time.Now().Add(cfg.Options.CookieExpire)) idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) s := &session.Session{ @@ -648,13 +648,13 @@ func (a *Authenticate) saveSessionToDataBroker( // databroker. If successful, it returns the original `id_token` of the session, if failed, returns // and empty string. func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { - options := a.options.Load() + cfg := a.currentConfig.Load() state := a.state.Load() // clear the user's local session no matter what defer state.sessionStore.ClearSession(w, r) - idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) + idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID)) if err != nil { return "" } @@ -708,6 +708,7 @@ func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*di func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) { state := a.state.Load() + cfg := a.currentConfig.Load() s, _, err := a.getCurrentSession(ctx) if err != nil { @@ -719,17 +720,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e return nil, err } - authenticateURL, err := a.options.Load().GetAuthenticateURL() + authenticateURL, err := cfg.Options.GetAuthenticateURL() if err != nil { return nil, err } - internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL() + internalAuthenticateURL, err := cfg.Options.GetInternalAuthenticateURL() if err != nil { return nil, err } - pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains() + pomeriumDomains, err := cfg.Options.GetAllRouteableHTTPDomains() if err != nil { return nil, err } @@ -744,7 +745,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e SessionState: ss, SessionStore: state.sessionStore, RelyingParty: state.webauthnRelyingParty, - BrandingOptions: a.options.Load().BrandingOptions, + BrandingOptions: cfg.Options.BrandingOptions, }, nil } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index e39abe0db..533924dff 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -49,10 +49,9 @@ func testAuthenticate() *Authenticate { redirectURL: redirectURL, cookieSecret: cryptutil.NewKey(), }) - auth.options = config.NewAtomicOptions() - auth.options.Store(&config.Options{ + auth.currentConfig = atomicutil.NewValue(config.New(&config.Options{ SharedKey: cryptutil.NewBase64Key(), - }) + })) return &auth } @@ -148,7 +147,7 @@ func TestAuthenticate_SignIn(t *testing.T) { sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) a := &Authenticate{ - cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ @@ -168,10 +167,10 @@ func TestAuthenticate_SignIn(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), - - options: config.NewAtomicOptions(), + currentConfig: atomicutil.NewValue(config.New(&config.Options{ + SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()), + })), } - a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())}) uri := &url.URL{Scheme: tt.scheme, Host: tt.host} queryString := uri.Query() @@ -304,7 +303,7 @@ func TestAuthenticate_SignOut(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() a := &Authenticate{ - cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ @@ -325,12 +324,9 @@ func TestAuthenticate_SignOut(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), - options: config.NewAtomicOptions(), - } - if tt.signoutRedirectURL != "" { - opts := a.options.Load() - opts.SignOutRedirectURLString = tt.signoutRedirectURL - a.options.Store(opts) + currentConfig: atomicutil.NewValue(config.New(&config.Options{ + SignOutRedirectURLString: tt.signoutRedirectURL, + })), } u, _ := url.Parse("/sign_out") params, _ := url.ParseQuery(u.RawQuery) @@ -417,7 +413,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { } authURL, _ := url.Parse(tt.authenticateURL) a := &Authenticate{ - cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ @@ -435,7 +431,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) { cookieCipher: aead, encryptedEncoder: signer, }), - options: config.NewAtomicOptions(), + currentConfig: atomicutil.NewValue(config.New(nil)), } u, _ := url.Parse("/oauthGet") params, _ := url.ParseQuery(u.RawQuery) @@ -552,7 +548,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { t.Fatal(err) } a := &Authenticate{ - cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { + cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) { return tt.provider, nil })), state: atomicutil.NewValue(&authenticateState{ @@ -573,7 +569,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), - options: config.NewAtomicOptions(), + currentConfig: atomicutil.NewValue(config.New(nil)), } r := httptest.NewRequest("GET", "/", nil) state, err := tt.session.LoadSession(r) @@ -604,7 +600,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { func TestJwksEndpoint(t *testing.T) { o := newTestOptions(t) o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" - auth, err := New(&config.Config{Options: o}) + auth, err := New(config.New(o)) if err != nil { t.Fatal(err) return @@ -632,12 +628,11 @@ func TestAuthenticate_userInfo(t *testing.T) { a.state = atomicutil.NewValue(&authenticateState{ cookieSecret: cryptutil.NewKey(), }) - a.options = config.NewAtomicOptions() - a.options.Store(&config.Options{ + a.currentConfig = atomicutil.NewValue(config.New(&config.Options{ SharedKey: cryptutil.NewBase64Key(), AuthenticateURLString: "https://authenticate.example.com", AuthenticateInternalURLString: "https://authenticate.service.cluster.local", - }) + })) err := a.userInfo(w, r) assert.NoError(t, err) assert.Equal(t, http.StatusFound, w.Code) @@ -687,13 +682,7 @@ func TestAuthenticate_userInfo(t *testing.T) { if err != nil { t.Fatal(err) } - o := config.NewAtomicOptions() - o.Store(&config.Options{ - AuthenticateURLString: "https://authenticate.localhost.pomerium.io", - SharedKey: "SHARED KEY", - }) a := &Authenticate{ - options: o, state: atomicutil.NewValue(&authenticateState{ sessionStore: tt.sessionStore, encryptedEncoder: signer, @@ -711,6 +700,10 @@ func TestAuthenticate_userInfo(t *testing.T) { }, directoryClient: new(mockDirectoryServiceClient), }), + currentConfig: atomicutil.NewValue(config.New(&config.Options{ + AuthenticateURLString: "https://authenticate.localhost.pomerium.io", + SharedKey: "SHARED KEY", + })), } a.webauthn = webauthn.New(a.getWebauthnState) r := httptest.NewRequest(tt.method, tt.url.String(), nil) diff --git a/authenticate/identity.go b/authenticate/identity.go index 55a123779..5067666fc 100644 --- a/authenticate/identity.go +++ b/authenticate/identity.go @@ -7,8 +7,8 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) { - authenticateURL, err := options.GetAuthenticateURL() +func defaultGetIdentityProvider(cfg *config.Config, idpID string) (identity.Authenticator, error) { + authenticateURL, err := cfg.Options.GetAuthenticateURL() if err != nil { return nil, err } @@ -17,9 +17,9 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity if err != nil { return nil, err } - redirectURL.Path = options.AuthenticateCallbackPath + redirectURL.Path = cfg.Options.AuthenticateCallbackPath - idp, err := options.GetIdentityProviderForID(idpID) + idp, err := cfg.Options.GetIdentityProviderForID(idpID) if err != nil { return nil, err } diff --git a/authenticate/middleware.go b/authenticate/middleware.go index 86cc3cbdb..884e0459f 100644 --- a/authenticate/middleware.go +++ b/authenticate/middleware.go @@ -34,14 +34,14 @@ func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Han } func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request { - options := a.options.Load() + cfg := a.currentConfig.Load() - externalURL, err := options.GetAuthenticateURL() + externalURL, err := cfg.Options.GetAuthenticateURL() if err != nil { return r } - internalURL, err := options.GetInternalAuthenticateURL() + internalURL, err := cfg.Options.GetInternalAuthenticateURL() if err != nil { return r } diff --git a/authorize/authorize.go b/authorize/authorize.go index e0d152cf1..e20f92472 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -25,11 +25,11 @@ import ( // Authorize struct holds type Authorize struct { - state *atomicutil.Value[*authorizeState] - store *store.Store - currentOptions *atomicutil.Value[*config.Options] - accessTracker *AccessTracker - globalCache storage.Cache + state *atomicutil.Value[*authorizeState] + store *store.Store + currentConfig *atomicutil.Value[*config.Config] + accessTracker *AccessTracker + globalCache storage.Cache // The stateLock prevents updating the evaluator store simultaneously with an evaluation. // This should provide a consistent view of the data at a given server/record version and @@ -40,9 +40,9 @@ type Authorize struct { // New validates and creates a new Authorize service from a set of config options. func New(cfg *config.Config) (*Authorize, error) { a := &Authorize{ - currentOptions: config.NewAtomicOptions(), - store: store.New(), - globalCache: storage.NewGlobalCache(time.Minute), + currentConfig: atomicutil.NewValue(cfg), + store: store.New(), + globalCache: storage.NewGlobalCache(time.Minute), } a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) @@ -86,42 +86,42 @@ func validateOptions(o *config.Options) error { } // newPolicyEvaluator returns an policy evaluator. -func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) { +func newPolicyEvaluator(cfg *config.Config, store *store.Store) (*evaluator.Evaluator, error) { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { - return int64(len(opts.GetAllPolicies())) + return int64(len(cfg.Options.GetAllPolicies())) }) ctx := context.Background() _, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") defer span.End() - clientCA, err := opts.GetClientCA() + clientCA, err := cfg.Options.GetClientCA() if err != nil { return nil, fmt.Errorf("authorize: invalid client CA: %w", err) } - authenticateURL, err := opts.GetInternalAuthenticateURL() + authenticateURL, err := cfg.Options.GetInternalAuthenticateURL() if err != nil { return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err) } - signingKey, err := opts.GetSigningKey() + signingKey, err := cfg.Options.GetSigningKey() if err != nil { return nil, fmt.Errorf("authorize: invalid signing key: %w", err) } return evaluator.New(ctx, store, - evaluator.WithPolicies(opts.GetAllPolicies()), + evaluator.WithPolicies(cfg.Options.GetAllPolicies()), evaluator.WithClientCA(clientCA), evaluator.WithSigningKey(signingKey), evaluator.WithAuthenticateURL(authenticateURL.String()), - evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()), - evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders), + evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(cfg.Options.GetGoogleCloudServerlessAuthenticationServiceAccount()), + evaluator.WithJWTClaimsHeaders(cfg.Options.JWTClaimsHeaders), ) } // OnConfigChange updates internal structures based on config.Options func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { - a.currentOptions.Store(cfg.Options) + a.currentConfig.Store(cfg) if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil { log.Error(ctx).Err(err).Msg("authorize: error updating state") } else { diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index d2dbf3114..c9c427343 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -74,7 +74,7 @@ func TestNew(t *testing.T) { tt := tt t.Run(tt.name, func(t *testing.T) { t.Parallel() - _, err := New(&config.Config{Options: &tt.config}) + _, err := New(config.New(&tt.config)) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -105,12 +105,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) { SharedKey: tc.SharedKey, Policies: tc.Policies, } - a, err := New(&config.Config{Options: o}) + a, err := New(config.New(o)) require.NoError(t, err) require.NotNil(t, a) oldPe := a.state.Load().evaluator - cfg := &config.Config{Options: o} + cfg := config.New(o) assertFunc := assert.True o.SigningKey = "bad-share-key" if tc.expectedChange { diff --git a/authorize/check_response.go b/authorize/check_response.go index d006337fe..41ca0a27c 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -125,7 +125,7 @@ func (a *Authorize) deniedResponse( respBody := []byte(reason) respHeader := []*envoy_config_core_v3.HeaderValueOption{} - forwardAuthURL, _ := a.currentOptions.Load().GetForwardAuthURL() + forwardAuthURL, _ := a.currentConfig.Load().Options.GetForwardAuthURL() if forwardAuthURL == nil { // create a http response writer recorder w := httptest.NewRecorder() @@ -140,7 +140,7 @@ func (a *Authorize) deniedResponse( Err: errors.New(reason), DebugURL: debugEndpoint, RequestID: requestid.FromContext(ctx), - BrandingOptions: a.currentOptions.Load().BrandingOptions, + BrandingOptions: a.currentConfig.Load().Options.BrandingOptions, } httpErr.ErrorResponse(ctx, w, r) @@ -184,7 +184,7 @@ func (a *Authorize) requireLoginResponse( request *evaluator.Request, isForwardAuthVerify bool, ) (*envoy_service_auth_v3.CheckResponse, error) { - opts := a.currentOptions.Load() + opts := a.currentConfig.Load().Options state := a.state.Load() authenticateURL, err := opts.GetAuthenticateURL() if err != nil { @@ -225,7 +225,7 @@ func (a *Authorize) requireWebAuthnResponse( result *evaluator.Result, isForwardAuthVerify bool, ) (*envoy_service_auth_v3.CheckResponse, error) { - opts := a.currentOptions.Load() + opts := a.currentConfig.Load().Options state := a.state.Load() authenticateURL, err := opts.GetAuthenticateURL() if err != nil { @@ -295,7 +295,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti // userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's // session that lives on the authenticate service. func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) { - opts := a.currentOptions.Load() + opts := a.currentConfig.Load().Options authenticateURL, err := opts.GetAuthenticateURL() if err != nil { return nil, err diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 1dbf50c9b..e94d6956d 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -29,7 +29,7 @@ func TestAuthorize_handleResult(t *testing.T) { opt.AuthenticateURLString = "https://authenticate.example.com" opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" - a, err := New(&config.Config{Options: opt}) + a, err := New(config.New(opt)) require.NoError(t, err) t.Run("user-unauthenticated", func(t *testing.T) { @@ -67,12 +67,12 @@ func TestAuthorize_okResponse(t *testing.T) { }}, JWTClaimsHeaders: config.NewJWTClaimHeaders("email"), } - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} + cfg := config.New(opt) + a := &Authorize{currentConfig: atomicutil.NewValue(cfg), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder - a.currentOptions.Store(opt) a.store = store.New() - pe, err := newPolicyEvaluator(opt, a.store) + pe, err := newPolicyEvaluator(cfg, a.store) require.NoError(t, err) a.state.Load().evaluator = pe @@ -123,17 +123,17 @@ func TestAuthorize_okResponse(t *testing.T) { } func TestAuthorize_deniedResponse(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} + a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder - a.currentOptions.Store(&config.Options{ + a.currentConfig.Store(config.New(&config.Options{ Policies: []config.Policy{{ Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, SubPolicies: []config.SubPolicy{{ Rego: []string{"allow = true"}, }}, }}, - }) + })) tests := []struct { name string @@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) { opt.AuthenticateURLString = "https://authenticate.example.com" opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" - a, err := New(&config.Config{Options: opt}) + a, err := New(config.New(opt)) require.NoError(t, err) t.Run("accept empty", func(t *testing.T) { diff --git a/authorize/grpc.go b/authorize/grpc.go index 4816fe0d2..30adac339 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe } } - rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder) + rawJWT, _ := loadRawSession(hreq, a.currentConfig.Load(), state.encoder) sessionState, _ := loadSession(state.encoder, rawJWT) var s sessionOrServiceAccount @@ -100,9 +100,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe // isForwardAuth returns if the current request is a forward auth route. func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool { - opts := a.currentOptions.Load() + cfg := a.currentConfig.Load() - forwardAuthURL, err := opts.GetForwardAuthURL() + forwardAuthURL, err := cfg.Options.GetForwardAuthURL() if err != nil || forwardAuthURL == nil { return false } @@ -136,9 +136,9 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest( } func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy { - options := a.currentOptions.Load() + cfg := a.currentConfig.Load() - for _, p := range options.GetAllPolicies() { + for _, p := range cfg.Options.GetAllPolicies() { if p.Matches(requestURL) { return &p } diff --git a/authorize/grpc_test.go b/authorize/grpc_test.go index 28a8cdbbb..ac552a899 100644 --- a/authorize/grpc_test.go +++ b/authorize/grpc_test.go @@ -47,17 +47,17 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA== -----END CERTIFICATE-----` func Test_getEvaluatorRequest(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} + a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder - a.currentOptions.Store(&config.Options{ + a.currentConfig.Store(config.New(&config.Options{ Policies: []config.Policy{{ Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, SubPolicies: []config.SubPolicy{{ Rego: []string{"allow = true"}, }}, }}, - }) + })) actual, err := a.getEvaluatorRequestFromCheckRequest( &envoy_service_auth_v3.CheckRequest{ @@ -87,7 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) { ) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &a.currentConfig.Load().Options.Policies[0], Session: evaluator.RequestSession{ ID: "SESSION_ID", }, @@ -248,8 +248,10 @@ func Test_handleForwardAuth(t *testing.T) { for _, tc := range tests { tc := tc t.Run(tc.name, func(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} - a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL}) + a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))} + a.currentConfig.Store(config.New(&config.Options{ + ForwardAuthURLString: tc.forwardAuthURL, + })) got := a.isForwardAuth(tc.checkReq) @@ -261,17 +263,17 @@ func Test_handleForwardAuth(t *testing.T) { } func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { - a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} + a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))} encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) a.state.Load().encoder = encoder - a.currentOptions.Store(&config.Options{ + a.currentConfig.Store(config.New(&config.Options{ Policies: []config.Policy{{ Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, SubPolicies: []config.SubPolicy{{ Rego: []string{"allow = true"}, }}, }}, - }) + })) actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{ Attributes: &envoy_service_auth_v3.AttributeContext{ @@ -296,7 +298,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { }, nil) require.NoError(t, err) expect := &evaluator.Request{ - Policy: &a.currentOptions.Load().Policies[0], + Policy: &a.currentConfig.Load().Options.Policies[0], Session: evaluator.RequestSession{}, HTTP: evaluator.NewRequestHTTP( "GET", @@ -332,11 +334,13 @@ func TestAuthorize_Check(t *testing.T) { opt.AuthenticateURLString = "https://authenticate.example.com" opt.DataBrokerURLString = "https://databroker.example.com" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" - a, err := New(&config.Config{Options: opt}) + a, err := New(config.New(opt)) if err != nil { t.Fatal(err) } - a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"}) + a.currentConfig.Store(config.New(&config.Options{ + ForwardAuthURLString: "https://forward-auth.example.com", + })) cmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}), diff --git a/authorize/session.go b/authorize/session.go index c40b7add6..3fbde437e 100644 --- a/authorize/session.go +++ b/authorize/session.go @@ -13,9 +13,13 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) -func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) { +func loadRawSession( + req *http.Request, + cfg *config.Config, + encoder encoding.MarshalUnmarshaler, +) ([]byte, error) { var loaders []sessions.SessionLoader - cookieStore, err := getCookieStore(options, encoder) + cookieStore, err := getCookieStore(cfg, encoder) if err != nil { return nil, err } @@ -37,7 +41,10 @@ func loadRawSession(req *http.Request, options *config.Options, encoder encoding return nil, sessions.ErrNoSessionFound } -func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.State, error) { +func loadSession( + encoder encoding.MarshalUnmarshaler, + rawJWT []byte, +) (*sessions.State, error) { var s sessions.State err := encoder.Unmarshal(rawJWT, &s) if err != nil { @@ -46,14 +53,17 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions. return &s, nil } -func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { +func getCookieStore( + cfg *config.Config, + encoder encoding.MarshalUnmarshaler, +) (sessions.SessionStore, error) { cookieStore, err := cookie.NewStore(func() cookie.Options { return cookie.Options{ - Name: options.CookieName, - Domain: options.CookieDomain, - Secure: options.CookieSecure, - HTTPOnly: options.CookieHTTPOnly, - Expire: options.CookieExpire, + Name: cfg.Options.CookieName, + Domain: cfg.Options.CookieDomain, + Secure: cfg.Options.CookieSecure, + HTTPOnly: cfg.Options.CookieHTTPOnly, + Expire: cfg.Options.CookieExpire, } }, encoder) if err != nil { diff --git a/authorize/session_test.go b/authorize/session_test.go index 53b509e80..b2042118b 100644 --- a/authorize/session_test.go +++ b/authorize/session_test.go @@ -32,7 +32,7 @@ func TestLoadSession(t *testing.T) { }, }, }) - raw, err := loadRawSession(req, opts, encoder) + raw, err := loadRawSession(req, config.New(opts), encoder) if err != nil { return nil, err } diff --git a/authorize/state.go b/authorize/state.go index 6440c4a10..ad4b777fc 100644 --- a/authorize/state.go +++ b/authorize/state.go @@ -36,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho var err error - state.evaluator, err = newPolicyEvaluator(cfg.Options, store) + state.evaluator, err = newPolicyEvaluator(cfg, store) if err != nil { return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) } diff --git a/config/config.go b/config/config.go index 42716b45a..7b42862d0 100644 --- a/config/config.go +++ b/config/config.go @@ -31,6 +31,16 @@ type Config struct { MetricsScrapeEndpoints []MetricsScrapeEndpoint } +// New creates a new Config. +func New(options *Options) *Config { + if options == nil { + options = NewDefaultOptions() + } + return &Config{ + Options: options, + } +} + // Clone creates a clone of the config. func (cfg *Config) Clone() *Config { newOptions := new(Options) diff --git a/config/envoyconfig/bootstrap_test.go b/config/envoyconfig/bootstrap_test.go index 4b18056c7..fc7585d17 100644 --- a/config/envoyconfig/bootstrap_test.go +++ b/config/envoyconfig/bootstrap_test.go @@ -13,11 +13,9 @@ import ( func TestBuilder_BuildBootstrapAdmin(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) t.Run("valid", func(t *testing.T) { - adminCfg, err := b.BuildBootstrapAdmin(&config.Config{ - Options: &config.Options{ - EnvoyAdminAddress: "localhost:9901", - }, - }) + adminCfg, err := b.BuildBootstrapAdmin(config.New(&config.Options{ + EnvoyAdminAddress: "localhost:9901", + })) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -31,11 +29,9 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) { `, adminCfg) }) t.Run("bad address", func(t *testing.T) { - _, err := b.BuildBootstrapAdmin(&config.Config{ - Options: &config.Options{ - EnvoyAdminAddress: "xyz1234:zyx4321", - }, - }) + _, err := b.BuildBootstrapAdmin(config.New(&config.Options{ + EnvoyAdminAddress: "xyz1234:zyx4321", + })) assert.Error(t, err) }) } @@ -111,11 +107,9 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) { func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) t.Run("valid", func(t *testing.T) { - statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{ - Options: &config.Options{ - Services: "all", - }, - }) + statsCfg, err := b.BuildBootstrapStatsConfig(config.New(&config.Options{ + Services: "all", + })) assert.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index b23efcf11..50d381a92 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -46,22 +46,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env return nil, err } - controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2) + controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2) if err != nil { return nil, err } - controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) + controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) if err != nil { return nil, err } - controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto) + controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto) if err != nil { return nil, err } - authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2) + authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2) if err != nil { return nil, err } @@ -70,7 +70,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection() } - databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2) + databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2) if err != nil { return nil, err } @@ -87,7 +87,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env databrokerCluster, } - tracingCluster, err := buildTracingCluster(cfg.Options) + tracingCluster, err := buildTracingCluster(cfg) if err != nil { return nil, err } else if tracingCluster != nil { @@ -101,7 +101,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env policy.EnvoyOpts = newDefaultEnvoyClusterConfig() } if len(policy.To) > 0 { - cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy) + cluster, err := b.buildPolicyCluster(ctx, cfg, &policy) if err != nil { return nil, fmt.Errorf("policy #%d: %w", i, err) } @@ -119,16 +119,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env func (b *Builder) buildInternalCluster( ctx context.Context, - options *config.Options, + cfg *config.Config, name string, dsts []*url.URL, upstreamProtocol upstreamProtocolConfig, ) (*envoy_config_cluster_v3.Cluster, error) { cluster := newDefaultEnvoyClusterConfig() - cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) + cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily) var endpoints []Endpoint for _, dst := range dsts { - ts, err := b.buildInternalTransportSocket(ctx, options, dst) + ts, err := b.buildInternalTransportSocket(ctx, cfg, dst) if err != nil { return nil, err } @@ -141,18 +141,22 @@ func (b *Builder) buildInternalCluster( return cluster, nil } -func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { +func (b *Builder) buildPolicyCluster( + ctx context.Context, + cfg *config.Config, + policy *config.Policy, +) (*envoy_config_cluster_v3.Cluster, error) { cluster := new(envoy_config_cluster_v3.Cluster) proto.Merge(cluster, policy.EnvoyOpts) - if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" { + if cfg.Options.EnvoyBindConfigFreebind.IsSet() || cfg.Options.EnvoyBindConfigSourceAddress != "" { cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig) - if options.EnvoyBindConfigFreebind.IsSet() { - cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(options.EnvoyBindConfigFreebind.Bool) + if cfg.Options.EnvoyBindConfigFreebind.IsSet() { + cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(cfg.Options.EnvoyBindConfigFreebind.Bool) } - if options.EnvoyBindConfigSourceAddress != "" { + if cfg.Options.EnvoyBindConfigSourceAddress != "" { cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{ - Address: options.EnvoyBindConfigSourceAddress, + Address: cfg.Options.EnvoyBindConfigSourceAddress, PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{ PortValue: 0, }, @@ -171,13 +175,13 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) name := getClusterID(policy) - endpoints, err := b.buildPolicyEndpoints(ctx, options, policy) + endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy) if err != nil { return nil, err } if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO { - cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) + cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily) } if policy.EnableGoogleCloudServerlessAuthentication { @@ -193,12 +197,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option func (b *Builder) buildPolicyEndpoints( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, ) ([]Endpoint, error) { var endpoints []Endpoint for _, dst := range policy.To { - ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL) + ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL) if err != nil { return nil, err } @@ -209,7 +213,7 @@ func (b *Builder) buildPolicyEndpoints( func (b *Builder) buildInternalTransportSocket( ctx context.Context, - options *config.Options, + cfg *config.Config, endpoint *url.URL, ) (*envoy_config_core_v3.TransportSocket, error) { if endpoint.Scheme != "https" { @@ -218,10 +222,10 @@ func (b *Builder) buildInternalTransportSocket( validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{ - b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName), + b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName), }, } - bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) + bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile) if err != nil { log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { @@ -234,7 +238,7 @@ func (b *Builder) buildInternalTransportSocket( ValidationContext: validationContext, }, }, - Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName), + Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName), } tlsConfig := marshalAny(tlsContext) return &envoy_config_core_v3.TransportSocket{ @@ -247,7 +251,7 @@ func (b *Builder) buildInternalTransportSocket( func (b *Builder) buildPolicyTransportSocket( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, dst url.URL, ) (*envoy_config_core_v3.TransportSocket, error) { @@ -257,7 +261,7 @@ func (b *Builder) buildPolicyTransportSocket( upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) - vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst) + vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst) if err != nil { return nil, err } @@ -318,7 +322,7 @@ func (b *Builder) buildPolicyTransportSocket( func (b *Builder) buildPolicyValidationContext( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, dst url.URL, ) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { @@ -343,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext( } validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) } else { - bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) + bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile) if err != nil { log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { diff --git a/config/envoyconfig/clusters_test.go b/config/envoyconfig/clusters_test.go index a433f3bb3..9204b112d 100644 --- a/config/envoyconfig/clusters_test.go +++ b/config/envoyconfig/clusters_test.go @@ -37,14 +37,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) { combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename() t.Run("insecure", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com"), }, *mustParseURL(t, "http://example.com")) require.NoError(t, err) assert.Nil(t, ts) }) t.Run("host as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), }, *mustParseURL(t, "https://example.com")) require.NoError(t, err) @@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_server_name as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSServerName: "use-this-name.example.com", }, *mustParseURL(t, "https://example.com")) @@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_upstream_server_name as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSUpstreamServerName: "use-this-name.example.com", }, *mustParseURL(t, "https://example.com")) @@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_skip_verify", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSSkipVerify: true, }, *mustParseURL(t, "https://example.com")) @@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("custom ca", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), }, *mustParseURL(t, "https://example.com")) @@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("options custom ca", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o2), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), }, *mustParseURL(t, "https://example.com")) require.NoError(t, err) @@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { }) t.Run("client certificate", func(t *testing.T) { clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), ClientCertificate: clientCert, }, *mustParseURL(t, "https://example.com")) @@ -438,7 +438,7 @@ func Test_buildCluster(t *testing.T) { rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() t.Run("insecure", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"), }) require.NoError(t, err) @@ -495,7 +495,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("secure", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com", "https://example.com", @@ -663,7 +663,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("ip addresses", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"), }) require.NoError(t, err) @@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("weights", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"), }) require.NoError(t, err) @@ -775,7 +775,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("localhost", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://localhost"), }) require.NoError(t, err) @@ -821,7 +821,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("outlier", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com"), }) require.NoError(t, err) @@ -904,7 +904,7 @@ func Test_bindConfig(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) t.Run("no bind config", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{ + cluster, err := b.buildPolicyCluster(ctx, config.New(nil), &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) @@ -912,9 +912,9 @@ func Test_bindConfig(t *testing.T) { assert.Nil(t, cluster.UpstreamBindConfig) }) t.Run("freebind", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{ + cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{ EnvoyBindConfigFreebind: null.BoolFrom(true), - }, &config.Policy{ + }), &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) @@ -930,9 +930,9 @@ func Test_bindConfig(t *testing.T) { `, cluster.UpstreamBindConfig) }) t.Run("source address", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{ + cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{ EnvoyBindConfigSourceAddress: "192.168.0.1", - }, &config.Policy{ + }), &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) diff --git a/config/envoyconfig/envoyconfig.go b/config/envoyconfig/envoyconfig.go index 0edcfbb08..5c1a9a719 100644 --- a/config/envoyconfig/envoyconfig.go +++ b/config/envoyconfig/envoyconfig.go @@ -76,10 +76,10 @@ func newDefaultEnvoyClusterConfig() *envoy_config_cluster_v3.Cluster { } } -func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { - lvl := options.ProxyLogLevel +func buildAccessLogs(cfg *config.Config) []*envoy_config_accesslog_v3.AccessLog { + lvl := cfg.Options.ProxyLogLevel if lvl == "" { - lvl = options.LogLevel + lvl = cfg.Options.LogLevel } if lvl == "" { lvl = "debug" diff --git a/config/envoyconfig/http_connection_manager.go b/config/envoyconfig/http_connection_manager.go index b9e30c988..34b45e673 100644 --- a/config/envoyconfig/http_connection_manager.go +++ b/config/envoyconfig/http_connection_manager.go @@ -10,7 +10,7 @@ import ( ) func (b *Builder) buildVirtualHost( - options *config.Options, + cfg *config.Config, name string, domain string, ) (*envoy_config_route_v3.VirtualHost, error) { @@ -20,15 +20,15 @@ func (b *Builder) buildVirtualHost( } // these routes match /.pomerium/... and similar paths - rs, err := b.buildPomeriumHTTPRoutes(options, domain) + rs, err := b.buildPomeriumHTTPRoutes(cfg, domain) if err != nil { return nil, err } vh.Routes = append(vh.Routes, rs...) // if we're the proxy or authenticate service, add our global headers - if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { - vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders()) + if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) { + vh.ResponseHeadersToAdd = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders()) } return vh, nil @@ -37,13 +37,13 @@ func (b *Builder) buildVirtualHost( // buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies // coming directly from envoy func (b *Builder) buildLocalReplyConfig( - options *config.Options, + cfg *config.Config, ) *envoy_http_connection_manager.LocalReplyConfig { // add global headers for HSTS headers (#2110) var headers []*envoy_config_core_v3.HeaderValueOption // if we're the proxy or authenticate service, add our global headers - if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { - headers = toEnvoyHeaders(options.GetSetResponseHeaders()) + if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) { + headers = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders()) } return &envoy_http_connection_manager.LocalReplyConfig{ diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index b5da6ece2..e92d12329 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -53,7 +53,10 @@ func init() { } // BuildListeners builds envoy listeners from the given config. -func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) { +func (b *Builder) BuildListeners( + ctx context.Context, + cfg *config.Config, +) ([]*envoy_config_listener_v3.Listener, error) { var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) { @@ -89,19 +92,22 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en return listeners, nil } -func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { +func (b *Builder) buildMainListener( + ctx context.Context, + cfg *config.Config, +) (*envoy_config_listener_v3.Listener, error) { listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} if cfg.Options.UseProxyProtocol { listenerFilters = append(listenerFilters, ProxyProtocolFilter()) } if cfg.Options.InsecureServer { - allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr) + allDomains, err := getAllRouteableDomains(cfg, cfg.Options.Addr) if err != nil { return nil, err } - filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "") + filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, allDomains, "") if err != nil { return nil, err } @@ -118,9 +124,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e } listenerFilters = append(listenerFilters, TLSInspectorFilter()) - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr, + chains, err := b.buildFilterChains(cfg, cfg.Options.Addr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { - filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) + filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, httpDomains, tlsDomain) if err != nil { return nil, err } @@ -155,7 +161,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e return li, nil } -func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { +func (b *Builder) buildMetricsListener( + cfg *config.Config, +) (*envoy_config_listener_v3.Listener, error) { filter, err := b.buildMetricsHTTPConnectionManagerFilter() if err != nil { return nil, err @@ -235,22 +243,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen } func (b *Builder) buildFilterChains( - options *config.Options, addr string, + cfg *config.Config, + addr string, callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), ) ([]*envoy_config_listener_v3.FilterChain, error) { - allDomains, err := getAllRouteableDomains(options, addr) + allDomains, err := getAllRouteableDomains(cfg, addr) if err != nil { return nil, err } - tlsDomains, err := getAllTLSDomains(options, addr) + tlsDomains, err := getAllTLSDomains(cfg, addr) if err != nil { return nil, err } var chains []*envoy_config_listener_v3.FilterChain for _, domain := range tlsDomains { - routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain) + routeableDomains, err := getRouteableDomainsForTLSServerName(cfg, addr, domain) if err != nil { return nil, err } @@ -273,31 +282,31 @@ func (b *Builder) buildFilterChains( } func (b *Builder) buildMainHTTPConnectionManagerFilter( - options *config.Options, + cfg *config.Config, domains []string, tlsDomain string, ) (*envoy_config_listener_v3.Filter, error) { - authorizeURLs, err := options.GetInternalAuthorizeURLs() + authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs() if err != nil { return nil, err } - dataBrokerURLs, err := options.GetInternalDataBrokerURLs() + dataBrokerURLs, err := cfg.Options.GetInternalDataBrokerURLs() if err != nil { return nil, err } var virtualHosts []*envoy_config_route_v3.VirtualHost for _, domain := range domains { - vh, err := b.buildVirtualHost(options, domain, domain) + vh, err := b.buildVirtualHost(cfg, domain, domain) if err != nil { return nil, err } - if options.Addr == options.GetGRPCAddr() { + if cfg.Options.Addr == cfg.Options.GetGRPCAddr() { // if this is a gRPC service domain and we're supposed to handle that, add those routes - if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || - (config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { + if (config.IsAuthorize(cfg.Options.Services) && hostsMatchDomain(authorizeURLs, domain)) || + (config.IsDataBroker(cfg.Options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { rs, err := b.buildGRPCRoutes() if err != nil { return nil, err @@ -307,8 +316,8 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( } // if we're the proxy, add all the policy routes - if config.IsProxy(options.Services) { - rs, err := b.buildPolicyRoutes(options, domain) + if config.IsProxy(cfg.Options.Services) { + rs, err := b.buildPolicyRoutes(cfg, domain) if err != nil { return nil, err } @@ -320,15 +329,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( } } - vh, err := b.buildVirtualHost(options, "catch-all", "*") + vh, err := b.buildVirtualHost(cfg, "catch-all", "*") if err != nil { return nil, err } virtualHosts = append(virtualHosts, vh) var grpcClientTimeout *durationpb.Duration - if options.GRPCClientTimeout != 0 { - grpcClientTimeout = durationpb.New(options.GRPCClientTimeout) + if cfg.Options.GRPCClientTimeout != 0 { + grpcClientTimeout = durationpb.New(cfg.Options.GRPCClientTimeout) } else { grpcClientTimeout = durationpb.New(30 * time.Second) } @@ -346,15 +355,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( filters = append(filters, HTTPRouterFilter()) var maxStreamDuration *durationpb.Duration - if options.WriteTimeout > 0 { - maxStreamDuration = durationpb.New(options.WriteTimeout) + if cfg.Options.WriteTimeout > 0 { + maxStreamDuration = durationpb.New(cfg.Options.WriteTimeout) } rc, err := b.buildRouteConfiguration("main", virtualHosts) if err != nil { return nil, err } - tracingProvider, err := buildTracingHTTP(options) + tracingProvider, err := buildTracingHTTP(cfg) if err != nil { return nil, err } @@ -362,27 +371,27 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter( return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ AlwaysSetRequestIdInResponse: true, - CodecType: options.GetCodecType().ToEnvoy(), + CodecType: cfg.Options.GetCodecType().ToEnvoy(), StatPrefix: "ingress", RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ RouteConfig: rc, }, HttpFilters: filters, - AccessLog: buildAccessLogs(options), + AccessLog: buildAccessLogs(cfg), CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{ - IdleTimeout: durationpb.New(options.IdleTimeout), + IdleTimeout: durationpb.New(cfg.Options.IdleTimeout), MaxStreamDuration: maxStreamDuration, }, - RequestTimeout: durationpb.New(options.ReadTimeout), + RequestTimeout: durationpb.New(cfg.Options.ReadTimeout), Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{ - RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100}, + RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100}, Provider: tracingProvider, }, // See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for UseRemoteAddress: &wrappers.BoolValue{Value: true}, - SkipXffAppend: options.SkipXffAppend, - XffNumTrustedHops: options.XffNumTrustedHops, - LocalReplyConfig: b.buildLocalReplyConfig(options), + SkipXffAppend: cfg.Options.SkipXffAppend, + XffNumTrustedHops: cfg.Options.XffNumTrustedHops, + LocalReplyConfig: b.buildLocalReplyConfig(cfg), }), nil } @@ -420,7 +429,10 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() (*envoy_config_liste }), nil } -func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { +func (b *Builder) buildGRPCListener( + ctx context.Context, + cfg *config.Config, +) (*envoy_config_listener_v3.Listener, error) { filter, err := b.buildGRPCHTTPConnectionManagerFilter() if err != nil { return nil, err @@ -437,7 +449,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e return li, nil } - chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr, + chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr, func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { filterChain := &envoy_config_listener_v3.FilterChain{ Filters: []*envoy_config_listener_v3.Filter{filter}, @@ -518,7 +530,10 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener }), nil } -func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) { +func (b *Builder) buildRouteConfiguration( + name string, + virtualHosts []*envoy_config_route_v3.VirtualHost, +) (*envoy_config_route_v3.RouteConfiguration, error) { return &envoy_config_route_v3.RouteConfiguration{ Name: name, VirtualHosts: virtualHosts, @@ -527,7 +542,8 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con }, nil } -func (b *Builder) buildDownstreamTLSContext(ctx context.Context, +func (b *Builder) buildDownstreamTLSContext( + ctx context.Context, cfg *config.Config, domain string, ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { @@ -570,7 +586,8 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context, } } -func (b *Builder) buildDownstreamValidationContext(ctx context.Context, +func (b *Builder) buildDownstreamValidationContext( + ctx context.Context, cfg *config.Config, domain string, ) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { @@ -580,7 +597,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context, needsClientCert = true } if !needsClientCert { - for _, p := range getPoliciesForDomain(cfg.Options, domain) { + for _, p := range getPoliciesForDomain(cfg, domain) { if p.TLSDownstreamClientCA != "" { needsClientCert = true break @@ -613,19 +630,23 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context, return vc } -func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) { +func getRouteableDomainsForTLSServerName( + cfg *config.Config, + addr string, + tlsServerName string, +) ([]string, error) { allDomains := sets.NewSorted[string]() - if addr == options.Addr { - domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName) + if addr == cfg.Options.Addr { + domains, err := cfg.Options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName) if err != nil { return nil, err } allDomains.Add(domains...) } - if addr == options.GetGRPCAddr() { - domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName) + if addr == cfg.Options.GetGRPCAddr() { + domains, err := cfg.Options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName) if err != nil { return nil, err } @@ -635,19 +656,22 @@ func getRouteableDomainsForTLSServerName(options *config.Options, addr string, t return allDomains.ToSlice(), nil } -func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { +func getAllRouteableDomains( + cfg *config.Config, + addr string, +) ([]string, error) { allDomains := sets.NewSorted[string]() - if addr == options.Addr { - domains, err := options.GetAllRouteableHTTPDomains() + if addr == cfg.Options.Addr { + domains, err := cfg.Options.GetAllRouteableHTTPDomains() if err != nil { return nil, err } allDomains.Add(domains...) } - if addr == options.GetGRPCAddr() { - domains, err := options.GetAllRouteableGRPCDomains() + if addr == cfg.Options.GetGRPCAddr() { + domains, err := cfg.Options.GetAllRouteableGRPCDomains() if err != nil { return nil, err } @@ -657,8 +681,11 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err return allDomains.ToSlice(), nil } -func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { - allDomains, err := getAllRouteableDomains(options, addr) +func getAllTLSDomains( + cfg *config.Config, + addr string, +) ([]string, error) { + allDomains, err := getAllRouteableDomains(cfg, addr) if err != nil { return nil, err } @@ -711,9 +738,12 @@ func hostMatchesDomain(u *url.URL, host string) bool { return h1 == h2 && p1 == p2 } -func getPoliciesForDomain(options *config.Options, domain string) []config.Policy { +func getPoliciesForDomain( + cfg *config.Config, + domain string, +) []config.Policy { var policies []config.Policy - for _, p := range options.GetAllPolicies() { + for _, p := range cfg.Options.GetAllPolicies() { if p.Source != nil && p.Source.URL.Hostname() == domain { policies = append(policies, p) } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index 861dcac08..7a80b044b 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -26,13 +26,11 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) { keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) - li, err := b.buildMetricsListener(&config.Config{ - Options: &config.Options{ - MetricsAddr: "127.0.0.1:9902", - MetricsCertificate: aExampleComCert, - MetricsCertificateKey: aExampleComKey, - }, - }) + li, err := b.buildMetricsListener(config.New(&config.Options{ + MetricsAddr: "127.0.0.1:9902", + MetricsCertificate: aExampleComCert, + MetricsCertificateKey: aExampleComKey, + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -115,7 +113,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { options := config.NewDefaultOptions() options.SkipXffAppend = true options.XffNumTrustedHops = 1 - filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*") + filter, err := b.buildMainHTTPConnectionManagerFilter(config.New(options), []string{"example.com"}, "*") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `{ "name": "envoy.filters.network.http_connection_manager", @@ -544,10 +542,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) { keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") t.Run("no-validation", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{ Cert: aExampleComCert, Key: aExampleComKey, - }}, "a.example.com") + }), "a.example.com") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { @@ -577,11 +575,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("client-ca", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{ Cert: aExampleComCert, Key: aExampleComKey, ClientCA: "TEST", - }}, "a.example.com") + }), "a.example.com") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { @@ -614,7 +612,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("policy-client-ca", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{ Cert: aExampleComCert, Key: aExampleComKey, Policies: []config.Policy{ @@ -623,7 +621,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) { TLSDownstreamClientCA: "TEST", }, }, - }}, "a.example.com") + }), "a.example.com") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { @@ -656,11 +654,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("http1", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{ Cert: aExampleComCert, Key: aExampleComKey, CodecType: config.CodecTypeHTTP1, - }}, "a.example.com") + }), "a.example.com") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { @@ -690,11 +688,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("http2", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{ Cert: aExampleComCert, Key: aExampleComKey, CodecType: config.CodecTypeHTTP2, - }}, "a.example.com") + }), "a.example.com") testutil.AssertProtoJSONEqual(t, `{ "commonTlsContext": { @@ -739,9 +737,10 @@ func Test_getAllDomains(t *testing.T) { {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, }, } + cfg := config.New(options) t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllRouteableDomains(options, "127.0.0.1:9000") + actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -756,7 +755,7 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllRouteableDomains(options, "127.0.0.1:9001") + actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ "authorize.example.com:9001", @@ -765,9 +764,9 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("both", func(t *testing.T) { - newOptions := *options - newOptions.GRPCAddr = newOptions.Addr - actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000") + newCfg := cfg.Clone() + newCfg.Options.GRPCAddr = cfg.Options.Addr + actual, err := getAllRouteableDomains(newCfg, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -786,7 +785,7 @@ func Test_getAllDomains(t *testing.T) { }) t.Run("tls", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9000") + actual, err := getAllTLSDomains(cfg, "127.0.0.1:9000") require.NoError(t, err) expect := []string{ "a.example.com", @@ -797,7 +796,7 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual, err := getAllTLSDomains(options, "127.0.0.1:9001") + actual, err := getAllTLSDomains(cfg, "127.0.0.1:9001") require.NoError(t, err) expect := []string{ "authorize.example.com", @@ -831,10 +830,10 @@ func Test_buildRouteConfiguration(t *testing.T) { func Test_requireProxyProtocol(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", nil, nil) t.Run("required", func(t *testing.T) { - li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{ + li, err := b.buildMainListener(context.Background(), config.New(&config.Options{ UseProxyProtocol: true, InsecureServer: true, - }}) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ { @@ -846,10 +845,10 @@ func Test_requireProxyProtocol(t *testing.T) { ]`, li.GetListenerFilters()) }) t.Run("not required", func(t *testing.T) { - li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{ + li, err := b.buildMainListener(context.Background(), config.New(&config.Options{ UseProxyProtocol: false, InsecureServer: true, - }}) + })) require.NoError(t, err) assert.Len(t, li.GetListenerFilters(), 0) }) diff --git a/config/envoyconfig/outbound.go b/config/envoyconfig/outbound.go index ab52b1b44..816425de8 100644 --- a/config/envoyconfig/outbound.go +++ b/config/envoyconfig/outbound.go @@ -14,7 +14,9 @@ import ( "github.com/pomerium/pomerium/config" ) -func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { +func (b *Builder) buildOutboundListener( + cfg *config.Config, +) (*envoy_config_listener_v3.Listener, error) { outboundPort, err := strconv.Atoi(cfg.OutboundPort) if err != nil { return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err) diff --git a/config/envoyconfig/routes.go b/config/envoyconfig/routes.go index be702eb74..f6985aa13 100644 --- a/config/envoyconfig/routes.go +++ b/config/envoyconfig/routes.go @@ -47,12 +47,15 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) { }}, nil } -func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { +func (b *Builder) buildPomeriumHTTPRoutes( + cfg *config.Config, + domain string, +) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route // if this is the pomerium proxy in front of the the authenticate service, don't add // these routes since they will be handled by authenticate - isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) + isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain) if err != nil { return nil, err } @@ -69,27 +72,27 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false), ) // per #837, only add robots.txt if there are no unauthenticated routes - if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) { + if !hasPublicPolicyMatchingURL(cfg, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) { routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false)) } } // if we're handling authentication, add the oauth2 callback url - authenticateURL, err := options.GetInternalAuthenticateURL() + authenticateURL, err := cfg.Options.GetInternalAuthenticateURL() if err != nil { return nil, err } - if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { + if config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) { routes = append(routes, - b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false), + b.buildControlPlanePathRoute(cfg.Options.AuthenticateCallbackPath, false), b.buildControlPlanePathRoute("/", false), ) } // if we're the proxy and this is the forward-auth url - forwardAuthURL, err := options.GetForwardAuthURL() + forwardAuthURL, err := cfg.Options.GetForwardAuthURL() if err != nil { return nil, err } - if config.IsProxy(options.Services) && hostMatchesDomain(forwardAuthURL, domain) { + if config.IsProxy(cfg.Options.Services) && hostMatchesDomain(forwardAuthURL, domain) { // disable ext_authz and pass request to proxy handlers that enable authN flow r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}) if err != nil { @@ -227,10 +230,13 @@ func getClusterStatsName(policy *config.Policy) string { return "" } -func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { +func (b *Builder) buildPolicyRoutes( + cfg *config.Config, + domain string, +) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route - for i, p := range options.GetAllPolicies() { + for i, p := range cfg.Options.GetAllPolicies() { policy := p if !hostMatchesDomain(policy.Source.URL, domain) { continue @@ -242,7 +248,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* Match: match, Metadata: &envoy_config_core_v3.Metadata{}, RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders), - RequestHeadersToRemove: getRequestHeadersToRemove(options, &policy), + RequestHeadersToRemove: getRequestHeadersToRemove(cfg, &policy), ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders), } if policy.Redirect != nil { @@ -252,7 +258,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* } envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action} } else { - action, err := b.buildPolicyRouteRouteAction(options, &policy) + action, err := b.buildPolicyRouteRouteAction(cfg, &policy) if err != nil { return nil, err } @@ -264,7 +270,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* } // disable authentication entirely when the proxy is fronting authenticate - isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) + isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain) if err != nil { return nil, err } @@ -275,7 +281,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]* } else { luaMetadata["remove_pomerium_cookie"] = &structpb.Value{ Kind: &structpb.Value_StringValue{ - StringValue: options.CookieName, + StringValue: cfg.Options.CookieName, }, } luaMetadata["remove_pomerium_authorization"] = &structpb.Value{ @@ -350,13 +356,16 @@ func (b *Builder) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*env return action, nil } -func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) { +func (b *Builder) buildPolicyRouteRouteAction( + cfg *config.Config, + policy *config.Policy, +) (*envoy_config_route_v3.RouteAction, error) { clusterName := getClusterID(policy) // kubernetes requests are sent to the http control plane to be reproxied if policy.IsForKubernetes() { clusterName = httpCluster } - routeTimeout := getRouteTimeout(options, policy) + routeTimeout := getRouteTimeout(cfg, policy) idleTimeout := getRouteIdleTimeout(policy) prefixRewrite, regexRewrite := getRewriteOptions(policy) upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{ @@ -464,13 +473,16 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch { return match } -func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string { +func getRequestHeadersToRemove( + cfg *config.Config, + policy *config.Policy, +) []string { requestHeadersToRemove := policy.RemoveRequestHeaders if !policy.PassIdentityHeaders { requestHeadersToRemove = append(requestHeadersToRemove, httputil.HeaderPomeriumJWTAssertion, httputil.HeaderPomeriumJWTAssertionFor) - for headerName := range options.JWTClaimsHeaders { + for headerName := range cfg.Options.JWTClaimsHeaders { requestHeadersToRemove = append(requestHeadersToRemove, headerName) } } @@ -482,7 +494,10 @@ func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) [ return requestHeadersToRemove } -func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration { +func getRouteTimeout( + cfg *config.Config, + policy *config.Policy, +) *durationpb.Duration { var routeTimeout *durationpb.Duration if policy.UpstreamTimeout != nil { routeTimeout = durationpb.New(*policy.UpstreamTimeout) @@ -490,7 +505,7 @@ func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb // a non-zero value would conflict with idleTimeout and/or websocket / tcp calls routeTimeout = durationpb.New(0) } else { - routeTimeout = durationpb.New(options.DefaultUpstreamTimeout) + routeTimeout = durationpb.New(cfg.Options.DefaultUpstreamTimeout) } return routeTimeout } @@ -564,8 +579,11 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3. } } -func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) bool { - for _, policy := range options.GetAllPolicies() { +func hasPublicPolicyMatchingURL( + cfg *config.Config, + requestURL url.URL, +) bool { + for _, policy := range cfg.Options.GetAllPolicies() { if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) { return true } @@ -573,13 +591,16 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo return false } -func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) { - authenticateURL, err := options.GetAuthenticateURL() +func isProxyFrontingAuthenticate( + cfg *config.Config, + domain string, +) (bool, error) { + authenticateURL, err := cfg.Options.GetAuthenticateURL() if err != nil { return false, err } - if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { + if !config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) { return true, nil } diff --git a/config/envoyconfig/routes_test.go b/config/envoyconfig/routes_test.go index 3830ac66e..423e093ec 100644 --- a/config/envoyconfig/routes_test.go +++ b/config/envoyconfig/routes_test.go @@ -84,7 +84,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AuthenticateCallbackPath: "/oauth2/callback", ForwardAuthURLString: "https://forward-auth.example.com", } - routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") + cfg := config.New(options) + routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ @@ -106,7 +107,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AuthenticateURLString: "https://authenticate.example.com", AuthenticateCallbackPath: "/oauth2/callback", } - routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") + cfg := config.New(options) + routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, "null", routes) }) @@ -122,8 +124,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { To: mustParseWeightedURLs(t, "https://to.example.com"), }}, } + cfg := config.New(options) _ = options.Policies[0].Validate() - routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com") + routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ @@ -150,8 +153,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) { AllowPublicUnauthenticatedAccess: true, }}, } + cfg := config.New(options) _ = options.Policies[0].Validate() - routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com") + routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, `[ @@ -242,7 +246,7 @@ func TestTimeouts(t *testing.T) { for _, tc := range testCases { b := &Builder{filemgr: filemgr.NewManager()} - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -253,7 +257,7 @@ func TestTimeouts(t *testing.T) { IdleTimeout: getDuration(tc.idle), AllowWebsockets: tc.allowWebsockets, }}, - }, "example.com") + }), "example.com") if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) { continue } @@ -295,7 +299,7 @@ func Test_buildPolicyRoutes(t *testing.T) { ten := time.Second * 10 b := &Builder{filemgr: filemgr.NewManager()} - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -357,7 +361,7 @@ func Test_buildPolicyRoutes(t *testing.T) { UpstreamTimeout: &ten, }, }, - }, "example.com") + }), "example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -724,7 +728,7 @@ func Test_buildPolicyRoutes(t *testing.T) { `, routes) t.Run("fronting-authenticate", func(t *testing.T) { - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ AuthenticateURLString: "https://authenticate.example.com", Services: "proxy", CookieName: "pomerium", @@ -735,7 +739,7 @@ func Test_buildPolicyRoutes(t *testing.T) { PassIdentityHeaders: true, }, }, - }, "authenticate.example.com") + }), "authenticate.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -791,7 +795,7 @@ func Test_buildPolicyRoutes(t *testing.T) { }) t.Run("tcp", func(t *testing.T) { - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -805,7 +809,7 @@ func Test_buildPolicyRoutes(t *testing.T) { UpstreamTimeout: &ten, }, }, - }, "example.com:22") + }), "example.com:22") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -905,7 +909,7 @@ func Test_buildPolicyRoutes(t *testing.T) { }) t.Run("remove-pomerium-headers", func(t *testing.T) { - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ AuthenticateURLString: "https://authenticate.example.com", Services: "proxy", CookieName: "pomerium", @@ -918,7 +922,7 @@ func Test_buildPolicyRoutes(t *testing.T) { Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")}, }, }, - }, "from.example.com") + }), "from.example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` @@ -980,7 +984,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) { }(getClusterID) getClusterID = policyNameFunc() b := &Builder{filemgr: filemgr.NewManager()} - routes, err := b.buildPolicyRoutes(&config.Options{ + routes, err := b.buildPolicyRoutes(config.New(&config.Options{ CookieName: "pomerium", DefaultUpstreamTimeout: time.Second * 3, Policies: []config.Policy{ @@ -1022,7 +1026,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) { HostPathRegexRewriteSubstitution: "\\1", }, }, - }, "example.com") + }), "example.com") require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` diff --git a/config/envoyconfig/tracing.go b/config/envoyconfig/tracing.go index 60eb69f9d..0ec0e1a24 100644 --- a/config/envoyconfig/tracing.go +++ b/config/envoyconfig/tracing.go @@ -14,8 +14,10 @@ import ( "github.com/pomerium/pomerium/pkg/protoutil" ) -func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Cluster, error) { - tracingOptions, err := config.NewTracingOptions(options) +func buildTracingCluster( + cfg *config.Config, +) (*envoy_config_cluster_v3.Cluster, error) { + tracingOptions, err := config.NewTracingOptions(cfg) if err != nil { return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err) } @@ -24,8 +26,8 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus case trace.DatadogTracingProviderName: addr, _ := parseAddress("127.0.0.1:8126") - if options.TracingDatadogAddress != "" { - addr, err = parseAddress(options.TracingDatadogAddress) + if cfg.Options.TracingDatadogAddress != "" { + addr, err = parseAddress(cfg.Options.TracingDatadogAddress) if err != nil { return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err) } @@ -94,8 +96,10 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus } } -func buildTracingHTTP(options *config.Options) (*envoy_config_trace_v3.Tracing_Http, error) { - tracingOptions, err := config.NewTracingOptions(options) +func buildTracingHTTP( + cfg *config.Config, +) (*envoy_config_trace_v3.Tracing_Http, error) { + tracingOptions, err := config.NewTracingOptions(cfg) if err != nil { return nil, fmt.Errorf("invalid tracing config: %w", err) } diff --git a/config/envoyconfig/tracing_test.go b/config/envoyconfig/tracing_test.go index d9eadec4a..5dc897884 100644 --- a/config/envoyconfig/tracing_test.go +++ b/config/envoyconfig/tracing_test.go @@ -11,9 +11,9 @@ import ( func TestBuildTracingCluster(t *testing.T) { t.Run("datadog", func(t *testing.T) { - c, err := buildTracingCluster(&config.Options{ + c, err := buildTracingCluster(config.New(&config.Options{ TracingProvider: "datadog", - }) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -38,10 +38,10 @@ func TestBuildTracingCluster(t *testing.T) { } `, c) - c, err = buildTracingCluster(&config.Options{ + c, err = buildTracingCluster(config.New(&config.Options{ TracingProvider: "datadog", TracingDatadogAddress: "example.com:8126", - }) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -67,10 +67,10 @@ func TestBuildTracingCluster(t *testing.T) { `, c) }) t.Run("zipkin", func(t *testing.T) { - c, err := buildTracingCluster(&config.Options{ + c, err := buildTracingCluster(config.New(&config.Options{ TracingProvider: "zipkin", ZipkinEndpoint: "https://example.com/api/v2/spans", - }) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -99,9 +99,9 @@ func TestBuildTracingCluster(t *testing.T) { func TestBuildTracingHTTP(t *testing.T) { t.Run("datadog", func(t *testing.T) { - h, err := buildTracingHTTP(&config.Options{ + h, err := buildTracingHTTP(config.New(&config.Options{ TracingProvider: "datadog", - }) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { @@ -115,10 +115,10 @@ func TestBuildTracingHTTP(t *testing.T) { `, h) }) t.Run("zipkin", func(t *testing.T) { - h, err := buildTracingHTTP(&config.Options{ + h, err := buildTracingHTTP(config.New(&config.Options{ TracingProvider: "zipkin", ZipkinEndpoint: "https://example.com/api/v2/spans", - }) + })) require.NoError(t, err) testutil.AssertProtoJSONEqual(t, ` { diff --git a/config/trace.go b/config/trace.go index e17389cd6..d99e1d348 100644 --- a/config/trace.go +++ b/config/trace.go @@ -18,28 +18,28 @@ import ( type TracingOptions = trace.TracingOptions // NewTracingOptions builds a new TracingOptions from core Options -func NewTracingOptions(o *Options) (*TracingOptions, error) { +func NewTracingOptions(cfg *Config) (*TracingOptions, error) { tracingOpts := TracingOptions{ - Provider: o.TracingProvider, - Service: telemetry.ServiceName(o.Services), - JaegerAgentEndpoint: o.TracingJaegerAgentEndpoint, - SampleRate: o.TracingSampleRate, + Provider: cfg.Options.TracingProvider, + Service: telemetry.ServiceName(cfg.Options.Services), + JaegerAgentEndpoint: cfg.Options.TracingJaegerAgentEndpoint, + SampleRate: cfg.Options.TracingSampleRate, } - switch o.TracingProvider { + switch cfg.Options.TracingProvider { case trace.DatadogTracingProviderName: - tracingOpts.DatadogAddress = o.TracingDatadogAddress + tracingOpts.DatadogAddress = cfg.Options.TracingDatadogAddress case trace.JaegerTracingProviderName: - if o.TracingJaegerCollectorEndpoint != "" { - jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(o.TracingJaegerCollectorEndpoint) + if cfg.Options.TracingJaegerCollectorEndpoint != "" { + jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.TracingJaegerCollectorEndpoint) if err != nil { return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err) } tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint - tracingOpts.JaegerAgentEndpoint = o.TracingJaegerAgentEndpoint + tracingOpts.JaegerAgentEndpoint = cfg.Options.TracingJaegerAgentEndpoint } case trace.ZipkinTracingProviderName: - zipkinEndpoint, err := urlutil.ParseAndValidateURL(o.ZipkinEndpoint) + zipkinEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.ZipkinEndpoint) if err != nil { return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err) } @@ -47,7 +47,7 @@ func NewTracingOptions(o *Options) (*TracingOptions, error) { case "": return &TracingOptions{}, nil default: - return nil, fmt.Errorf("config: provider %s unknown", o.TracingProvider) + return nil, fmt.Errorf("config: provider %s unknown", cfg.Options.TracingProvider) } return &tracingOpts, nil @@ -88,7 +88,7 @@ func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) { mgr.mu.Lock() defer mgr.mu.Unlock() - traceOpts, err := NewTracingOptions(cfg.Options) + traceOpts, err := NewTracingOptions(cfg) if err != nil { log.Error(ctx).Err(err).Msg("trace: failed to build tracing options") return diff --git a/config/trace_test.go b/config/trace_test.go index d20c0a420..21e5d9aed 100644 --- a/config/trace_test.go +++ b/config/trace_test.go @@ -68,7 +68,7 @@ func Test_NewTracingOptions(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NewTracingOptions(tt.opts) + got, err := NewTracingOptions(&Config{Options: tt.opts}) assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value") assert.Empty(t, cmp.Diff(tt.want, got)) }) diff --git a/databroker/cache_test.go b/databroker/cache_test.go index 308ebde12..ae26f2c33 100644 --- a/databroker/cache_test.go +++ b/databroker/cache_test.go @@ -28,7 +28,7 @@ func TestNew(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tt.opts.Provider = "google" - _, err := New(&config.Config{Options: &tt.opts}, events.New()) + _, err := New(config.New(&tt.opts), events.New()) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/internal/autocert/manager_test.go b/internal/autocert/manager_test.go index e92876b14..f169b7d58 100644 --- a/internal/autocert/manager_test.go +++ b/internal/autocert/manager_test.go @@ -236,8 +236,8 @@ func TestConfig(t *testing.T) { } _ = p1.Validate() - mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{ - Options: &config.Options{ + mgr, err := newManager(ctx, + config.NewStaticSource(config.New(&config.Options{ AutocertOptions: config.AutocertOptions{ Enable: true, UseStaging: true, @@ -247,11 +247,11 @@ func TestConfig(t *testing.T) { }, HTTPRedirectAddr: addr, Policies: []config.Policy{p1}, - }, - }), certmagic.ACMEIssuer{ - CA: srv.URL + "/acme/directory", - TestCA: srv.URL + "/acme/directory", - }, time.Millisecond*100) + })), + certmagic.ACMEIssuer{ + CA: srv.URL + "/acme/directory", + TestCA: srv.URL + "/acme/directory", + }, time.Millisecond*100) if !assert.NoError(t, err) { return } @@ -304,16 +304,14 @@ func TestRedirect(t *testing.T) { addr := li.Addr().String() _ = li.Close() - src := config.NewStaticSource(&config.Config{ - Options: &config.Options{ - HTTPRedirectAddr: addr, - SetResponseHeaders: map[string]string{ - "X-Frame-Options": "SAMEORIGIN", - "X-XSS-Protection": "1; mode=block", - "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", - }, + src := config.NewStaticSource(config.New(&config.Options{ + HTTPRedirectAddr: addr, + SetResponseHeaders: map[string]string{ + "X-Frame-Options": "SAMEORIGIN", + "X-XSS-Protection": "1; mode=block", + "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", }, - }) + })) _, err = New(src) if !assert.NoError(t, err) { return diff --git a/internal/httputil/reproxy/reproxy.go b/internal/httputil/reproxy/reproxy.go index 60ac5edec..ffd694373 100644 --- a/internal/httputil/reproxy/reproxy.go +++ b/internal/httputil/reproxy/reproxy.go @@ -28,7 +28,7 @@ import ( type Handler struct { mu sync.RWMutex key []byte - options *config.Options + cfg *config.Config policies map[uint64]config.Policy } @@ -83,7 +83,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler { } h.mu.RLock() - options := h.options + options := h.cfg.Options policy, ok := h.policies[policyID] h.mu.RUnlock() @@ -132,7 +132,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) { defer h.mu.Unlock() h.key, _ = cfg.Options.GetSharedKey() - h.options = cfg.Options + h.cfg = cfg h.policies = make(map[uint64]config.Policy) for _, p := range cfg.Options.GetAllPolicies() { id, err := p.RouteID() diff --git a/internal/httputil/reproxy/reproxy_test.go b/internal/httputil/reproxy/reproxy_test.go index 3728d9d56..5d320885e 100644 --- a/internal/httputil/reproxy/reproxy_test.go +++ b/internal/httputil/reproxy/reproxy_test.go @@ -49,15 +49,13 @@ func TestMiddleware(t *testing.T) { srv2 := httptest.NewServer(h.Middleware(next)) defer srv2.Close() - cfg := &config.Config{ - Options: &config.Options{ - SharedKey: cryptutil.NewBase64Key(), - Policies: []config.Policy{{ - To: config.WeightedURLs{{URL: *u}}, - KubernetesServiceAccountToken: "ABCD", - }}, - }, - } + cfg := config.New(&config.Options{ + SharedKey: cryptutil.NewBase64Key(), + Policies: []config.Policy{{ + To: config.WeightedURLs{{URL: *u}}, + KubernetesServiceAccountToken: "ABCD", + }}, + }) h.Update(context.Background(), cfg) policyID, _ := cfg.Options.Policies[0].RouteID() diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go index a732a9856..565533682 100644 --- a/proxy/forward_auth_test.go +++ b/proxy/forward_auth_test.go @@ -80,11 +80,11 @@ func TestProxy_ForwardAuth(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) + p, err := New(config.New(tt.options)) if err != nil { t.Fatal(err) } - p.OnConfigChange(ctx, &config.Config{Options: tt.options}) + p.OnConfigChange(ctx, config.New(tt.options)) state := p.state.Load() state.sessionStore = tt.sessionStore signer, err := jws.NewHS256Signer(nil) diff --git a/proxy/handlers.go b/proxy/handlers.go index e22980887..d75f98de9 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -58,7 +58,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error { state := p.state.Load() var redirectURL *url.URL - signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL() + signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL() if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index f842e7615..2a1a5a77e 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) { if err != nil { t.Fatal(err) } - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(config.New(opts)) if err != nil { t.Fatal(err) } @@ -70,7 +70,7 @@ func TestProxy_userInfo(t *testing.T) { if err != nil { t.Fatal(err) } - proxy, err := New(&config.Config{Options: opts}) + proxy, err := New(config.New(opts)) if err != nil { t.Fatal(err) } @@ -102,7 +102,7 @@ func TestProxy_SignOut(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { opts := testOptions(t) - p, err := New(&config.Config{Options: opts}) + p, err := New(config.New(opts)) if err != nil { t.Fatal(err) } @@ -236,11 +236,11 @@ func TestProxy_Callback(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) + p, err := New(config.New(tt.options)) if err != nil { t.Fatal(err) } - p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) + p.OnConfigChange(context.Background(), config.New(tt.options)) state := p.state.Load() state.encoder = tt.cipher state.sessionStore = tt.sessionStore @@ -350,7 +350,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) + p, err := New(config.New(tt.options)) if err != nil { t.Fatal(err) } @@ -482,11 +482,11 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.options}) + p, err := New(config.New(tt.options)) if err != nil { t.Fatal(err) } - p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) + p.OnConfigChange(context.Background(), config.New(tt.options)) state := p.state.Load() state.encoder = tt.cipher state.sessionStore = tt.sessionStore diff --git a/proxy/proxy.go b/proxy/proxy.go index 85f382be4..b3cf7655b 100644 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error { // Proxy stores all the information associated with proxying a request. type Proxy struct { - state *atomicutil.Value[*proxyState] - currentOptions *atomicutil.Value[*config.Options] - currentRouter *atomicutil.Value[*mux.Router] + state *atomicutil.Value[*proxyState] + currentConfig *atomicutil.Value[*config.Config] + currentRouter *atomicutil.Value[*mux.Router] } // New takes a Proxy service from options and a validation function. @@ -65,13 +65,13 @@ func New(cfg *config.Config) (*Proxy, error) { } p := &Proxy{ - state: atomicutil.NewValue(state), - currentOptions: config.NewAtomicOptions(), - currentRouter: atomicutil.NewValue(httputil.NewRouter()), + state: atomicutil.NewValue(state), + currentConfig: atomicutil.NewValue(cfg), + currentRouter: atomicutil.NewValue(httputil.NewRouter()), } metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { - return int64(len(p.currentOptions.Load().GetAllPolicies())) + return int64(len(p.currentConfig.Load().Options.GetAllPolicies())) }) return p, nil @@ -88,8 +88,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { return } - p.currentOptions.Store(cfg.Options) - if err := p.setHandlers(cfg.Options); err != nil { + p.currentConfig.Store(cfg) + if err := p.setHandlers(cfg); err != nil { log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") } if state, err := newProxyStateFromConfig(cfg); err != nil { @@ -99,8 +99,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) { } } -func (p *Proxy) setHandlers(opts *config.Options) error { - if len(opts.GetAllPolicies()) == 0 { +func (p *Proxy) setHandlers(cfg *config.Config) error { + if len(cfg.Options.GetAllPolicies()) == 0 { log.Warn(context.TODO()).Msg("proxy: configuration has no policies") } r := httputil.NewRouter() @@ -113,7 +113,7 @@ func (p *Proxy) setHandlers(opts *config.Options) error { // dashboard handlers are registered to all routes r = p.registerDashboardHandlers(r) - forwardAuthURL, err := opts.GetForwardAuthURL() + forwardAuthURL, err := cfg.Options.GetForwardAuthURL() if err != nil { return err } diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 44d3c0a11..bef5b69bd 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -99,7 +99,7 @@ func TestNew(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := New(&config.Config{Options: tt.opts}) + got, err := New(config.New(tt.opts)) if (err != nil) != tt.wantErr { t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) return @@ -194,12 +194,12 @@ func Test_UpdateOptions(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p, err := New(&config.Config{Options: tt.originalOptions}) + p, err := New(config.New(tt.originalOptions)) if err != nil { t.Fatal(err) } - p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions}) + p.OnConfigChange(context.Background(), config.New(tt.updatedOptions)) r := httptest.NewRequest("GET", tt.host, nil) w := httptest.NewRecorder() p.ServeHTTP(w, r) @@ -212,5 +212,5 @@ func Test_UpdateOptions(t *testing.T) { // Test nil var p *Proxy - p.OnConfigChange(context.Background(), &config.Config{}) + p.OnConfigChange(context.Background(), config.New(nil)) }