config: use config.Config instead of config.Options everywhere

This commit is contained in:
Caleb Doxsey 2022-08-26 15:39:50 -06:00
parent 5f51510e91
commit 1b80e8a6c2
40 changed files with 484 additions and 412 deletions

View file

@ -39,18 +39,18 @@ func ValidateOptions(o *config.Options) error {
// Authenticate contains data required to run the authenticate service.
type Authenticate struct {
cfg *authenticateConfig
options *atomicutil.Value[*config.Options]
state *atomicutil.Value[*authenticateState]
webauthn *webauthn.Handler
cfg *authenticateConfig
currentConfig *atomicutil.Value[*config.Config]
state *atomicutil.Value[*authenticateState]
webauthn *webauthn.Handler
}
// New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a := &Authenticate{
cfg: getAuthenticateConfig(options...),
options: config.NewAtomicOptions(),
state: atomicutil.NewValue(newAuthenticateState()),
cfg: getAuthenticateConfig(options...),
currentConfig: atomicutil.NewValue(cfg),
state: atomicutil.NewValue(newAuthenticateState()),
}
a.webauthn = webauthn.New(a.getWebauthnState)
@ -69,7 +69,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
return
}
a.options.Store(cfg.Options)
a.currentConfig.Store(cfg)
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
} else {

View file

@ -113,7 +113,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
_, err := New(&config.Config{Options: tt.opts})
_, err := New(config.New(tt.opts))
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -6,7 +6,7 @@ import (
)
type authenticateConfig struct {
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)
}
// An Option customizes the Authenticate config.
@ -22,7 +22,7 @@ func getAuthenticateConfig(options ...Option) *authenticateConfig {
}
// WithGetIdentityProvider sets the getIdentityProvider function in the config.
func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option {
func WithGetIdentityProvider(getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)) Option {
return func(cfg *authenticateConfig) {
cfg.getIdentityProvider = getIdentityProvider
}

View file

@ -47,12 +47,12 @@ func (a *Authenticate) Mount(r *mux.Router) {
r.StrictSlash(true)
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
r.Use(func(h http.Handler) http.Handler {
options := a.options.Load()
cfg := a.currentConfig.Load()
state := a.state.Load()
csrfKey := fmt.Sprintf("%s_csrf", options.CookieName)
csrfKey := fmt.Sprintf("%s_csrf", cfg.Options.CookieName)
return csrf.Protect(
state.cookieSecret,
csrf.Secure(options.CookieSecure),
csrf.Secure(cfg.Options.CookieSecure),
csrf.Path("/"),
csrf.UnsafePaths(
[]string{
@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// check for an HMAC'd URL. If none is found, show a confirmation page.
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
if err != nil {
authenticateURL, err := a.options.Load().GetAuthenticateURL()
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
if err != nil {
return err
}
@ -275,9 +275,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End()
options := a.options.Load()
cfg := a.currentConfig.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
@ -285,7 +285,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r)
redirectString := ""
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
signOutURL, err := cfg.Options.GetSignOutRedirectURL()
if err != nil {
return err
}
@ -330,10 +330,10 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return httputil.NewError(http.StatusUnauthorized, err)
}
options := a.options.Load()
cfg := a.currentConfig.Load()
state := a.state.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
@ -381,7 +381,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End()
options := a.options.Load()
cfg := a.currentConfig.Load()
state := a.state.Load()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
@ -430,7 +430,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
}
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
idp, err := a.cfg.getIdentityProvider(options, idpID)
idp, err := a.cfg.getIdentityProvider(cfg, idpID)
if err != nil {
return nil, err
}
@ -513,7 +513,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
state := a.state.Load()
authenticateURL, err := a.options.Load().GetAuthenticateURL()
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
if err != nil {
return handlers.UserInfoData{}, err
}
@ -569,7 +569,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
WebAuthnRequestOptions: requestOptions,
WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()),
BrandingOptions: a.options.Load().BrandingOptions,
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
}, nil
}
@ -581,14 +581,14 @@ func (a *Authenticate) saveSessionToDataBroker(
accessToken *oauth2.Token,
) error {
state := a.state.Load()
options := a.options.Load()
cfg := a.currentConfig.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return err
}
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire))
sessionExpiry := timestamppb.New(time.Now().Add(cfg.Options.CookieExpire))
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
s := &session.Session{
@ -648,13 +648,13 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load()
cfg := a.currentConfig.Load()
state := a.state.Load()
// clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r)
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil {
return ""
}
@ -708,6 +708,7 @@ func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*di
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
state := a.state.Load()
cfg := a.currentConfig.Load()
s, _, err := a.getCurrentSession(ctx)
if err != nil {
@ -719,17 +720,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
return nil, err
}
authenticateURL, err := a.options.Load().GetAuthenticateURL()
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, err
}
internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL()
internalAuthenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
pomeriumDomains, err := cfg.Options.GetAllRouteableHTTPDomains()
if err != nil {
return nil, err
}
@ -744,7 +745,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
SessionState: ss,
SessionStore: state.sessionStore,
RelyingParty: state.webauthnRelyingParty,
BrandingOptions: a.options.Load().BrandingOptions,
BrandingOptions: cfg.Options.BrandingOptions,
}, nil
}

View file

@ -49,10 +49,9 @@ func testAuthenticate() *Authenticate {
redirectURL: redirectURL,
cookieSecret: cryptutil.NewKey(),
})
auth.options = config.NewAtomicOptions()
auth.options.Store(&config.Options{
auth.currentConfig = atomicutil.NewValue(config.New(&config.Options{
SharedKey: cryptutil.NewBase64Key(),
})
}))
return &auth
}
@ -148,7 +147,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
@ -168,10 +167,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(config.New(&config.Options{
SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
})),
}
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
queryString := uri.Query()
@ -304,7 +303,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
@ -325,12 +324,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
}
if tt.signoutRedirectURL != "" {
opts := a.options.Load()
opts.SignOutRedirectURLString = tt.signoutRedirectURL
a.options.Store(opts)
currentConfig: atomicutil.NewValue(config.New(&config.Options{
SignOutRedirectURLString: tt.signoutRedirectURL,
})),
}
u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery)
@ -417,7 +413,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
}
authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
@ -435,7 +431,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
cookieCipher: aead,
encryptedEncoder: signer,
}),
options: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(config.New(nil)),
}
u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery)
@ -552,7 +548,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Fatal(err)
}
a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil
})),
state: atomicutil.NewValue(&authenticateState{
@ -573,7 +569,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
},
directoryClient: new(mockDirectoryServiceClient),
}),
options: config.NewAtomicOptions(),
currentConfig: atomicutil.NewValue(config.New(nil)),
}
r := httptest.NewRequest("GET", "/", nil)
state, err := tt.session.LoadSession(r)
@ -604,7 +600,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
func TestJwksEndpoint(t *testing.T) {
o := newTestOptions(t)
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
auth, err := New(&config.Config{Options: o})
auth, err := New(config.New(o))
if err != nil {
t.Fatal(err)
return
@ -632,12 +628,11 @@ func TestAuthenticate_userInfo(t *testing.T) {
a.state = atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(),
})
a.options = config.NewAtomicOptions()
a.options.Store(&config.Options{
a.currentConfig = atomicutil.NewValue(config.New(&config.Options{
SharedKey: cryptutil.NewBase64Key(),
AuthenticateURLString: "https://authenticate.example.com",
AuthenticateInternalURLString: "https://authenticate.service.cluster.local",
})
}))
err := a.userInfo(w, r)
assert.NoError(t, err)
assert.Equal(t, http.StatusFound, w.Code)
@ -687,13 +682,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
if err != nil {
t.Fatal(err)
}
o := config.NewAtomicOptions()
o.Store(&config.Options{
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
SharedKey: "SHARED KEY",
})
a := &Authenticate{
options: o,
state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore,
encryptedEncoder: signer,
@ -711,6 +700,10 @@ func TestAuthenticate_userInfo(t *testing.T) {
},
directoryClient: new(mockDirectoryServiceClient),
}),
currentConfig: atomicutil.NewValue(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
SharedKey: "SHARED KEY",
})),
}
a.webauthn = webauthn.New(a.getWebauthnState)
r := httptest.NewRequest(tt.method, tt.url.String(), nil)

View file

@ -7,8 +7,8 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
)
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL()
func defaultGetIdentityProvider(cfg *config.Config, idpID string) (identity.Authenticator, error) {
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return nil, err
}
@ -17,9 +17,9 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
if err != nil {
return nil, err
}
redirectURL.Path = options.AuthenticateCallbackPath
redirectURL.Path = cfg.Options.AuthenticateCallbackPath
idp, err := options.GetIdentityProviderForID(idpID)
idp, err := cfg.Options.GetIdentityProviderForID(idpID)
if err != nil {
return nil, err
}

View file

@ -34,14 +34,14 @@ func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Han
}
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
options := a.options.Load()
cfg := a.currentConfig.Load()
externalURL, err := options.GetAuthenticateURL()
externalURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return r
}
internalURL, err := options.GetInternalAuthenticateURL()
internalURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil {
return r
}

View file

@ -25,11 +25,11 @@ import (
// Authorize struct holds
type Authorize struct {
state *atomicutil.Value[*authorizeState]
store *store.Store
currentOptions *atomicutil.Value[*config.Options]
accessTracker *AccessTracker
globalCache storage.Cache
state *atomicutil.Value[*authorizeState]
store *store.Store
currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker
globalCache storage.Cache
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
// This should provide a consistent view of the data at a given server/record version and
@ -40,9 +40,9 @@ type Authorize struct {
// New validates and creates a new Authorize service from a set of config options.
func New(cfg *config.Config) (*Authorize, error) {
a := &Authorize{
currentOptions: config.NewAtomicOptions(),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
currentConfig: atomicutil.NewValue(cfg),
store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute),
}
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -86,42 +86,42 @@ func validateOptions(o *config.Options) error {
}
// newPolicyEvaluator returns an policy evaluator.
func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) {
func newPolicyEvaluator(cfg *config.Config, store *store.Store) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(len(opts.GetAllPolicies()))
return int64(len(cfg.Options.GetAllPolicies()))
})
ctx := context.Background()
_, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
defer span.End()
clientCA, err := opts.GetClientCA()
clientCA, err := cfg.Options.GetClientCA()
if err != nil {
return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
}
authenticateURL, err := opts.GetInternalAuthenticateURL()
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil {
return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err)
}
signingKey, err := opts.GetSigningKey()
signingKey, err := cfg.Options.GetSigningKey()
if err != nil {
return nil, fmt.Errorf("authorize: invalid signing key: %w", err)
}
return evaluator.New(ctx, store,
evaluator.WithPolicies(opts.GetAllPolicies()),
evaluator.WithPolicies(cfg.Options.GetAllPolicies()),
evaluator.WithClientCA(clientCA),
evaluator.WithSigningKey(signingKey),
evaluator.WithAuthenticateURL(authenticateURL.String()),
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()),
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(cfg.Options.GetGoogleCloudServerlessAuthenticationServiceAccount()),
evaluator.WithJWTClaimsHeaders(cfg.Options.JWTClaimsHeaders),
)
}
// OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.currentOptions.Store(cfg.Options)
a.currentConfig.Store(cfg)
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
log.Error(ctx).Err(err).Msg("authorize: error updating state")
} else {

View file

@ -74,7 +74,7 @@ func TestNew(t *testing.T) {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
_, err := New(&config.Config{Options: &tt.config})
_, err := New(config.New(&tt.config))
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
@ -105,12 +105,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
SharedKey: tc.SharedKey,
Policies: tc.Policies,
}
a, err := New(&config.Config{Options: o})
a, err := New(config.New(o))
require.NoError(t, err)
require.NotNil(t, a)
oldPe := a.state.Load().evaluator
cfg := &config.Config{Options: o}
cfg := config.New(o)
assertFunc := assert.True
o.SigningKey = "bad-share-key"
if tc.expectedChange {

View file

@ -125,7 +125,7 @@ func (a *Authorize) deniedResponse(
respBody := []byte(reason)
respHeader := []*envoy_config_core_v3.HeaderValueOption{}
forwardAuthURL, _ := a.currentOptions.Load().GetForwardAuthURL()
forwardAuthURL, _ := a.currentConfig.Load().Options.GetForwardAuthURL()
if forwardAuthURL == nil {
// create a http response writer recorder
w := httptest.NewRecorder()
@ -140,7 +140,7 @@ func (a *Authorize) deniedResponse(
Err: errors.New(reason),
DebugURL: debugEndpoint,
RequestID: requestid.FromContext(ctx),
BrandingOptions: a.currentOptions.Load().BrandingOptions,
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
}
httpErr.ErrorResponse(ctx, w, r)
@ -184,7 +184,7 @@ func (a *Authorize) requireLoginResponse(
request *evaluator.Request,
isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
@ -225,7 +225,7 @@ func (a *Authorize) requireWebAuthnResponse(
result *evaluator.Result,
isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
@ -295,7 +295,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
// session that lives on the authenticate service.
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
opts := a.currentOptions.Load()
opts := a.currentConfig.Load().Options
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err

View file

@ -29,7 +29,7 @@ func TestAuthorize_handleResult(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt})
a, err := New(config.New(opt))
require.NoError(t, err)
t.Run("user-unauthenticated", func(t *testing.T) {
@ -67,12 +67,12 @@ func TestAuthorize_okResponse(t *testing.T) {
}},
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
}
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
cfg := config.New(opt)
a := &Authorize{currentConfig: atomicutil.NewValue(cfg), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder
a.currentOptions.Store(opt)
a.store = store.New()
pe, err := newPolicyEvaluator(opt, a.store)
pe, err := newPolicyEvaluator(cfg, a.store)
require.NoError(t, err)
a.state.Load().evaluator = pe
@ -123,17 +123,17 @@ func TestAuthorize_okResponse(t *testing.T) {
}
func TestAuthorize_deniedResponse(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{
a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
}))
tests := []struct {
name string
@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt})
a, err := New(config.New(opt))
require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) {

View file

@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
}
}
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
rawJWT, _ := loadRawSession(hreq, a.currentConfig.Load(), state.encoder)
sessionState, _ := loadSession(state.encoder, rawJWT)
var s sessionOrServiceAccount
@ -100,9 +100,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
// isForwardAuth returns if the current request is a forward auth route.
func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool {
opts := a.currentOptions.Load()
cfg := a.currentConfig.Load()
forwardAuthURL, err := opts.GetForwardAuthURL()
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil || forwardAuthURL == nil {
return false
}
@ -136,9 +136,9 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
}
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
options := a.currentOptions.Load()
cfg := a.currentConfig.Load()
for _, p := range options.GetAllPolicies() {
for _, p := range cfg.Options.GetAllPolicies() {
if p.Matches(requestURL) {
return &p
}

View file

@ -47,17 +47,17 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{
a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
}))
actual, err := a.getEvaluatorRequestFromCheckRequest(
&envoy_service_auth_v3.CheckRequest{
@ -87,7 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{
ID: "SESSION_ID",
},
@ -248,8 +248,10 @@ func Test_handleForwardAuth(t *testing.T) {
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL})
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
a.currentConfig.Store(config.New(&config.Options{
ForwardAuthURLString: tc.forwardAuthURL,
}))
got := a.isForwardAuth(tc.checkReq)
@ -261,17 +263,17 @@ func Test_handleForwardAuth(t *testing.T) {
}
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{
a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"},
}},
}},
})
}))
actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{
Attributes: &envoy_service_auth_v3.AttributeContext{
@ -296,7 +298,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
}, nil)
require.NoError(t, err)
expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0],
Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP(
"GET",
@ -332,11 +334,13 @@ func TestAuthorize_Check(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt})
a, err := New(config.New(opt))
if err != nil {
t.Fatal(err)
}
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
a.currentConfig.Store(config.New(&config.Options{
ForwardAuthURLString: "https://forward-auth.example.com",
}))
cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),

View file

@ -13,9 +13,13 @@ import (
"github.com/pomerium/pomerium/internal/urlutil"
)
func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) {
func loadRawSession(
req *http.Request,
cfg *config.Config,
encoder encoding.MarshalUnmarshaler,
) ([]byte, error) {
var loaders []sessions.SessionLoader
cookieStore, err := getCookieStore(options, encoder)
cookieStore, err := getCookieStore(cfg, encoder)
if err != nil {
return nil, err
}
@ -37,7 +41,10 @@ func loadRawSession(req *http.Request, options *config.Options, encoder encoding
return nil, sessions.ErrNoSessionFound
}
func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.State, error) {
func loadSession(
encoder encoding.MarshalUnmarshaler,
rawJWT []byte,
) (*sessions.State, error) {
var s sessions.State
err := encoder.Unmarshal(rawJWT, &s)
if err != nil {
@ -46,14 +53,17 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.
return &s, nil
}
func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
func getCookieStore(
cfg *config.Config,
encoder encoding.MarshalUnmarshaler,
) (sessions.SessionStore, error) {
cookieStore, err := cookie.NewStore(func() cookie.Options {
return cookie.Options{
Name: options.CookieName,
Domain: options.CookieDomain,
Secure: options.CookieSecure,
HTTPOnly: options.CookieHTTPOnly,
Expire: options.CookieExpire,
Name: cfg.Options.CookieName,
Domain: cfg.Options.CookieDomain,
Secure: cfg.Options.CookieSecure,
HTTPOnly: cfg.Options.CookieHTTPOnly,
Expire: cfg.Options.CookieExpire,
}
}, encoder)
if err != nil {

View file

@ -32,7 +32,7 @@ func TestLoadSession(t *testing.T) {
},
},
})
raw, err := loadRawSession(req, opts, encoder)
raw, err := loadRawSession(req, config.New(opts), encoder)
if err != nil {
return nil, err
}

View file

@ -36,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
var err error
state.evaluator, err = newPolicyEvaluator(cfg.Options, store)
state.evaluator, err = newPolicyEvaluator(cfg, store)
if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
}

View file

@ -31,6 +31,16 @@ type Config struct {
MetricsScrapeEndpoints []MetricsScrapeEndpoint
}
// New creates a new Config.
func New(options *Options) *Config {
if options == nil {
options = NewDefaultOptions()
}
return &Config{
Options: options,
}
}
// Clone creates a clone of the config.
func (cfg *Config) Clone() *Config {
newOptions := new(Options)

View file

@ -13,11 +13,9 @@ import (
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("valid", func(t *testing.T) {
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
Options: &config.Options{
EnvoyAdminAddress: "localhost:9901",
},
})
adminCfg, err := b.BuildBootstrapAdmin(config.New(&config.Options{
EnvoyAdminAddress: "localhost:9901",
}))
assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -31,11 +29,9 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
`, adminCfg)
})
t.Run("bad address", func(t *testing.T) {
_, err := b.BuildBootstrapAdmin(&config.Config{
Options: &config.Options{
EnvoyAdminAddress: "xyz1234:zyx4321",
},
})
_, err := b.BuildBootstrapAdmin(config.New(&config.Options{
EnvoyAdminAddress: "xyz1234:zyx4321",
}))
assert.Error(t, err)
})
}
@ -111,11 +107,9 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("valid", func(t *testing.T) {
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{
Options: &config.Options{
Services: "all",
},
})
statsCfg, err := b.BuildBootstrapStatsConfig(config.New(&config.Options{
Services: "all",
}))
assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{

View file

@ -46,22 +46,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
return nil, err
}
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
if err != nil {
return nil, err
}
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
if err != nil {
return nil, err
}
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
@ -70,7 +70,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
}
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
@ -87,7 +87,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
databrokerCluster,
}
tracingCluster, err := buildTracingCluster(cfg.Options)
tracingCluster, err := buildTracingCluster(cfg)
if err != nil {
return nil, err
} else if tracingCluster != nil {
@ -101,7 +101,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
}
if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err)
}
@ -119,16 +119,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
func (b *Builder) buildInternalCluster(
ctx context.Context,
options *config.Options,
cfg *config.Config,
name string,
dsts []*url.URL,
upstreamProtocol upstreamProtocolConfig,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := newDefaultEnvoyClusterConfig()
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
var endpoints []Endpoint
for _, dst := range dsts {
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
if err != nil {
return nil, err
}
@ -141,18 +141,22 @@ func (b *Builder) buildInternalCluster(
return cluster, nil
}
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
func (b *Builder) buildPolicyCluster(
ctx context.Context,
cfg *config.Config,
policy *config.Policy,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster)
proto.Merge(cluster, policy.EnvoyOpts)
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
if cfg.Options.EnvoyBindConfigFreebind.IsSet() || cfg.Options.EnvoyBindConfigSourceAddress != "" {
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
if options.EnvoyBindConfigFreebind.IsSet() {
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(options.EnvoyBindConfigFreebind.Bool)
if cfg.Options.EnvoyBindConfigFreebind.IsSet() {
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(cfg.Options.EnvoyBindConfigFreebind.Bool)
}
if options.EnvoyBindConfigSourceAddress != "" {
if cfg.Options.EnvoyBindConfigSourceAddress != "" {
cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{
Address: options.EnvoyBindConfigSourceAddress,
Address: cfg.Options.EnvoyBindConfigSourceAddress,
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
PortValue: 0,
},
@ -171,13 +175,13 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
name := getClusterID(policy)
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
if err != nil {
return nil, err
}
if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO {
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
}
if policy.EnableGoogleCloudServerlessAuthentication {
@ -193,12 +197,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
func (b *Builder) buildPolicyEndpoints(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
) ([]Endpoint, error) {
var endpoints []Endpoint
for _, dst := range policy.To {
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
if err != nil {
return nil, err
}
@ -209,7 +213,7 @@ func (b *Builder) buildPolicyEndpoints(
func (b *Builder) buildInternalTransportSocket(
ctx context.Context,
options *config.Options,
cfg *config.Config,
endpoint *url.URL,
) (*envoy_config_core_v3.TransportSocket, error) {
if endpoint.Scheme != "https" {
@ -218,10 +222,10 @@ func (b *Builder) buildInternalTransportSocket(
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
},
}
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
@ -234,7 +238,7 @@ func (b *Builder) buildInternalTransportSocket(
ValidationContext: validationContext,
},
},
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
}
tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{
@ -247,7 +251,7 @@ func (b *Builder) buildInternalTransportSocket(
func (b *Builder) buildPolicyTransportSocket(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
dst url.URL,
) (*envoy_config_core_v3.TransportSocket, error) {
@ -257,7 +261,7 @@ func (b *Builder) buildPolicyTransportSocket(
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
if err != nil {
return nil, err
}
@ -318,7 +322,7 @@ func (b *Builder) buildPolicyTransportSocket(
func (b *Builder) buildPolicyValidationContext(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
@ -343,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
}
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {

View file

@ -37,14 +37,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
t.Run("insecure", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
}, *mustParseURL(t, "http://example.com"))
require.NoError(t, err)
assert.Nil(t, ts)
})
t.Run("host as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSUpstreamServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_skip_verify", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSSkipVerify: true,
}, *mustParseURL(t, "https://example.com"))
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}, *mustParseURL(t, "https://example.com"))
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("options custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o2), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
})
t.Run("client certificate", func(t *testing.T) {
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
ClientCertificate: clientCert,
}, *mustParseURL(t, "https://example.com"))
@ -438,7 +438,7 @@ func Test_buildCluster(t *testing.T) {
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions()
t.Run("insecure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
})
require.NoError(t, err)
@ -495,7 +495,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("secure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t,
"https://example.com",
"https://example.com",
@ -663,7 +663,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("ip addresses", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
})
require.NoError(t, err)
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("weights", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
})
require.NoError(t, err)
@ -775,7 +775,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("localhost", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://localhost"),
})
require.NoError(t, err)
@ -821,7 +821,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("outlier", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
})
require.NoError(t, err)
@ -904,7 +904,7 @@ func Test_bindConfig(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("no bind config", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
cluster, err := b.buildPolicyCluster(ctx, config.New(nil), &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})
@ -912,9 +912,9 @@ func Test_bindConfig(t *testing.T) {
assert.Nil(t, cluster.UpstreamBindConfig)
})
t.Run("freebind", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
EnvoyBindConfigFreebind: null.BoolFrom(true),
}, &config.Policy{
}), &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})
@ -930,9 +930,9 @@ func Test_bindConfig(t *testing.T) {
`, cluster.UpstreamBindConfig)
})
t.Run("source address", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
EnvoyBindConfigSourceAddress: "192.168.0.1",
}, &config.Policy{
}), &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})

View file

@ -76,10 +76,10 @@ func newDefaultEnvoyClusterConfig() *envoy_config_cluster_v3.Cluster {
}
}
func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
lvl := options.ProxyLogLevel
func buildAccessLogs(cfg *config.Config) []*envoy_config_accesslog_v3.AccessLog {
lvl := cfg.Options.ProxyLogLevel
if lvl == "" {
lvl = options.LogLevel
lvl = cfg.Options.LogLevel
}
if lvl == "" {
lvl = "debug"

View file

@ -10,7 +10,7 @@ import (
)
func (b *Builder) buildVirtualHost(
options *config.Options,
cfg *config.Config,
name string,
domain string,
) (*envoy_config_route_v3.VirtualHost, error) {
@ -20,15 +20,15 @@ func (b *Builder) buildVirtualHost(
}
// these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, domain)
rs, err := b.buildPomeriumHTTPRoutes(cfg, domain)
if err != nil {
return nil, err
}
vh.Routes = append(vh.Routes, rs...)
// if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders())
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
vh.ResponseHeadersToAdd = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
}
return vh, nil
@ -37,13 +37,13 @@ func (b *Builder) buildVirtualHost(
// buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies
// coming directly from envoy
func (b *Builder) buildLocalReplyConfig(
options *config.Options,
cfg *config.Config,
) *envoy_http_connection_manager.LocalReplyConfig {
// add global headers for HSTS headers (#2110)
var headers []*envoy_config_core_v3.HeaderValueOption
// if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
headers = toEnvoyHeaders(options.GetSetResponseHeaders())
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
headers = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
}
return &envoy_http_connection_manager.LocalReplyConfig{

View file

@ -53,7 +53,10 @@ func init() {
}
// BuildListeners builds envoy listeners from the given config.
func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) {
func (b *Builder) BuildListeners(
ctx context.Context,
cfg *config.Config,
) ([]*envoy_config_listener_v3.Listener, error) {
var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
@ -89,19 +92,22 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
return listeners, nil
}
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
func (b *Builder) buildMainListener(
ctx context.Context,
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
if cfg.Options.UseProxyProtocol {
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
}
if cfg.Options.InsecureServer {
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
allDomains, err := getAllRouteableDomains(cfg, cfg.Options.Addr)
if err != nil {
return nil, err
}
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "")
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, allDomains, "")
if err != nil {
return nil, err
}
@ -118,9 +124,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
}
listenerFilters = append(listenerFilters, TLSInspectorFilter())
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr,
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, httpDomains, tlsDomain)
if err != nil {
return nil, err
}
@ -155,7 +161,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
return li, nil
}
func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
func (b *Builder) buildMetricsListener(
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
filter, err := b.buildMetricsHTTPConnectionManagerFilter()
if err != nil {
return nil, err
@ -235,22 +243,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
}
func (b *Builder) buildFilterChains(
options *config.Options, addr string,
cfg *config.Config,
addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(options, addr)
allDomains, err := getAllRouteableDomains(cfg, addr)
if err != nil {
return nil, err
}
tlsDomains, err := getAllTLSDomains(options, addr)
tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil {
return nil, err
}
var chains []*envoy_config_listener_v3.FilterChain
for _, domain := range tlsDomains {
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
routeableDomains, err := getRouteableDomainsForTLSServerName(cfg, addr, domain)
if err != nil {
return nil, err
}
@ -273,31 +282,31 @@ func (b *Builder) buildFilterChains(
}
func (b *Builder) buildMainHTTPConnectionManagerFilter(
options *config.Options,
cfg *config.Config,
domains []string,
tlsDomain string,
) (*envoy_config_listener_v3.Filter, error) {
authorizeURLs, err := options.GetInternalAuthorizeURLs()
authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs()
if err != nil {
return nil, err
}
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
dataBrokerURLs, err := cfg.Options.GetInternalDataBrokerURLs()
if err != nil {
return nil, err
}
var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range domains {
vh, err := b.buildVirtualHost(options, domain, domain)
vh, err := b.buildVirtualHost(cfg, domain, domain)
if err != nil {
return nil, err
}
if options.Addr == options.GetGRPCAddr() {
if cfg.Options.Addr == cfg.Options.GetGRPCAddr() {
// if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
if (config.IsAuthorize(cfg.Options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
(config.IsDataBroker(cfg.Options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
rs, err := b.buildGRPCRoutes()
if err != nil {
return nil, err
@ -307,8 +316,8 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
}
// if we're the proxy, add all the policy routes
if config.IsProxy(options.Services) {
rs, err := b.buildPolicyRoutes(options, domain)
if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildPolicyRoutes(cfg, domain)
if err != nil {
return nil, err
}
@ -320,15 +329,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
}
}
vh, err := b.buildVirtualHost(options, "catch-all", "*")
vh, err := b.buildVirtualHost(cfg, "catch-all", "*")
if err != nil {
return nil, err
}
virtualHosts = append(virtualHosts, vh)
var grpcClientTimeout *durationpb.Duration
if options.GRPCClientTimeout != 0 {
grpcClientTimeout = durationpb.New(options.GRPCClientTimeout)
if cfg.Options.GRPCClientTimeout != 0 {
grpcClientTimeout = durationpb.New(cfg.Options.GRPCClientTimeout)
} else {
grpcClientTimeout = durationpb.New(30 * time.Second)
}
@ -346,15 +355,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
filters = append(filters, HTTPRouterFilter())
var maxStreamDuration *durationpb.Duration
if options.WriteTimeout > 0 {
maxStreamDuration = durationpb.New(options.WriteTimeout)
if cfg.Options.WriteTimeout > 0 {
maxStreamDuration = durationpb.New(cfg.Options.WriteTimeout)
}
rc, err := b.buildRouteConfiguration("main", virtualHosts)
if err != nil {
return nil, err
}
tracingProvider, err := buildTracingHTTP(options)
tracingProvider, err := buildTracingHTTP(cfg)
if err != nil {
return nil, err
}
@ -362,27 +371,27 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
AlwaysSetRequestIdInResponse: true,
CodecType: options.GetCodecType().ToEnvoy(),
CodecType: cfg.Options.GetCodecType().ToEnvoy(),
StatPrefix: "ingress",
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
RouteConfig: rc,
},
HttpFilters: filters,
AccessLog: buildAccessLogs(options),
AccessLog: buildAccessLogs(cfg),
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
IdleTimeout: durationpb.New(options.IdleTimeout),
IdleTimeout: durationpb.New(cfg.Options.IdleTimeout),
MaxStreamDuration: maxStreamDuration,
},
RequestTimeout: durationpb.New(options.ReadTimeout),
RequestTimeout: durationpb.New(cfg.Options.ReadTimeout),
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100},
RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
Provider: tracingProvider,
},
// See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for
UseRemoteAddress: &wrappers.BoolValue{Value: true},
SkipXffAppend: options.SkipXffAppend,
XffNumTrustedHops: options.XffNumTrustedHops,
LocalReplyConfig: b.buildLocalReplyConfig(options),
SkipXffAppend: cfg.Options.SkipXffAppend,
XffNumTrustedHops: cfg.Options.XffNumTrustedHops,
LocalReplyConfig: b.buildLocalReplyConfig(cfg),
}), nil
}
@ -420,7 +429,10 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() (*envoy_config_liste
}), nil
}
func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
func (b *Builder) buildGRPCListener(
ctx context.Context,
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
filter, err := b.buildGRPCHTTPConnectionManagerFilter()
if err != nil {
return nil, err
@ -437,7 +449,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return li, nil
}
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr,
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter},
@ -518,7 +530,10 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener
}), nil
}
func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) {
func (b *Builder) buildRouteConfiguration(
name string,
virtualHosts []*envoy_config_route_v3.VirtualHost,
) (*envoy_config_route_v3.RouteConfiguration, error) {
return &envoy_config_route_v3.RouteConfiguration{
Name: name,
VirtualHosts: virtualHosts,
@ -527,7 +542,8 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
}, nil
}
func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
func (b *Builder) buildDownstreamTLSContext(
ctx context.Context,
cfg *config.Config,
domain string,
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
@ -570,7 +586,8 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
}
}
func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
func (b *Builder) buildDownstreamValidationContext(
ctx context.Context,
cfg *config.Config,
domain string,
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
@ -580,7 +597,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
needsClientCert = true
}
if !needsClientCert {
for _, p := range getPoliciesForDomain(cfg.Options, domain) {
for _, p := range getPoliciesForDomain(cfg, domain) {
if p.TLSDownstreamClientCA != "" {
needsClientCert = true
break
@ -613,19 +630,23 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
return vc
}
func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) {
func getRouteableDomainsForTLSServerName(
cfg *config.Config,
addr string,
tlsServerName string,
) ([]string, error) {
allDomains := sets.NewSorted[string]()
if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
if addr == cfg.Options.Addr {
domains, err := cfg.Options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
if err != nil {
return nil, err
}
allDomains.Add(domains...)
}
if addr == options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
if addr == cfg.Options.GetGRPCAddr() {
domains, err := cfg.Options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
if err != nil {
return nil, err
}
@ -635,19 +656,22 @@ func getRouteableDomainsForTLSServerName(options *config.Options, addr string, t
return allDomains.ToSlice(), nil
}
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
func getAllRouteableDomains(
cfg *config.Config,
addr string,
) ([]string, error) {
allDomains := sets.NewSorted[string]()
if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomains()
if addr == cfg.Options.Addr {
domains, err := cfg.Options.GetAllRouteableHTTPDomains()
if err != nil {
return nil, err
}
allDomains.Add(domains...)
}
if addr == options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomains()
if addr == cfg.Options.GetGRPCAddr() {
domains, err := cfg.Options.GetAllRouteableGRPCDomains()
if err != nil {
return nil, err
}
@ -657,8 +681,11 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
return allDomains.ToSlice(), nil
}
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
allDomains, err := getAllRouteableDomains(options, addr)
func getAllTLSDomains(
cfg *config.Config,
addr string,
) ([]string, error) {
allDomains, err := getAllRouteableDomains(cfg, addr)
if err != nil {
return nil, err
}
@ -711,9 +738,12 @@ func hostMatchesDomain(u *url.URL, host string) bool {
return h1 == h2 && p1 == p2
}
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy {
func getPoliciesForDomain(
cfg *config.Config,
domain string,
) []config.Policy {
var policies []config.Policy
for _, p := range options.GetAllPolicies() {
for _, p := range cfg.Options.GetAllPolicies() {
if p.Source != nil && p.Source.URL.Hostname() == domain {
policies = append(policies, p)
}

View file

@ -26,13 +26,11 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
li, err := b.buildMetricsListener(&config.Config{
Options: &config.Options{
MetricsAddr: "127.0.0.1:9902",
MetricsCertificate: aExampleComCert,
MetricsCertificateKey: aExampleComKey,
},
})
li, err := b.buildMetricsListener(config.New(&config.Options{
MetricsAddr: "127.0.0.1:9902",
MetricsCertificate: aExampleComCert,
MetricsCertificateKey: aExampleComKey,
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -115,7 +113,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
options := config.NewDefaultOptions()
options.SkipXffAppend = true
options.XffNumTrustedHops = 1
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*")
filter, err := b.buildMainHTTPConnectionManagerFilter(config.New(options), []string{"example.com"}, "*")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{
"name": "envoy.filters.network.http_connection_manager",
@ -544,10 +542,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
t.Run("no-validation", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert,
Key: aExampleComKey,
}}, "a.example.com")
}), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {
@ -577,11 +575,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext)
})
t.Run("client-ca", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert,
Key: aExampleComKey,
ClientCA: "TEST",
}}, "a.example.com")
}), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {
@ -614,7 +612,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext)
})
t.Run("policy-client-ca", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert,
Key: aExampleComKey,
Policies: []config.Policy{
@ -623,7 +621,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
TLSDownstreamClientCA: "TEST",
},
},
}}, "a.example.com")
}), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {
@ -656,11 +654,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext)
})
t.Run("http1", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert,
Key: aExampleComKey,
CodecType: config.CodecTypeHTTP1,
}}, "a.example.com")
}), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {
@ -690,11 +688,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext)
})
t.Run("http2", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert,
Key: aExampleComKey,
CodecType: config.CodecTypeHTTP2,
}}, "a.example.com")
}), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": {
@ -739,9 +737,10 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
},
}
cfg := config.New(options)
t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"a.example.com",
@ -756,7 +755,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("grpc", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{
"authorize.example.com:9001",
@ -765,9 +764,9 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("both", func(t *testing.T) {
newOptions := *options
newOptions.GRPCAddr = newOptions.Addr
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000")
newCfg := cfg.Clone()
newCfg.Options.GRPCAddr = cfg.Options.Addr
actual, err := getAllRouteableDomains(newCfg, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"a.example.com",
@ -786,7 +785,7 @@ func Test_getAllDomains(t *testing.T) {
})
t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{
"a.example.com",
@ -797,7 +796,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual)
})
t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{
"authorize.example.com",
@ -831,10 +830,10 @@ func Test_buildRouteConfiguration(t *testing.T) {
func Test_requireProxyProtocol(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
t.Run("required", func(t *testing.T) {
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
UseProxyProtocol: true,
InsecureServer: true,
}})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
{
@ -846,10 +845,10 @@ func Test_requireProxyProtocol(t *testing.T) {
]`, li.GetListenerFilters())
})
t.Run("not required", func(t *testing.T) {
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
UseProxyProtocol: false,
InsecureServer: true,
}})
}))
require.NoError(t, err)
assert.Len(t, li.GetListenerFilters(), 0)
})

View file

@ -14,7 +14,9 @@ import (
"github.com/pomerium/pomerium/config"
)
func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
func (b *Builder) buildOutboundListener(
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
outboundPort, err := strconv.Atoi(cfg.OutboundPort)
if err != nil {
return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err)

View file

@ -47,12 +47,15 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
}}, nil
}
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPomeriumHTTPRoutes(
cfg *config.Config,
domain string,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
// if this is the pomerium proxy in front of the the authenticate service, don't add
// these routes since they will be handled by authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
if err != nil {
return nil, err
}
@ -69,27 +72,27 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
)
// per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
if !hasPublicPolicyMatchingURL(cfg, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
}
}
// if we're handling authentication, add the oauth2 callback url
authenticateURL, err := options.GetInternalAuthenticateURL()
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
if config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
routes = append(routes,
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
b.buildControlPlanePathRoute(cfg.Options.AuthenticateCallbackPath, false),
b.buildControlPlanePathRoute("/", false),
)
}
// if we're the proxy and this is the forward-auth url
forwardAuthURL, err := options.GetForwardAuthURL()
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil {
return nil, err
}
if config.IsProxy(options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
if config.IsProxy(cfg.Options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
// disable ext_authz and pass request to proxy handlers that enable authN flow
r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
if err != nil {
@ -227,10 +230,13 @@ func getClusterStatsName(policy *config.Policy) string {
return ""
}
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPolicyRoutes(
cfg *config.Config,
domain string,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route
for i, p := range options.GetAllPolicies() {
for i, p := range cfg.Options.GetAllPolicies() {
policy := p
if !hostMatchesDomain(policy.Source.URL, domain) {
continue
@ -242,7 +248,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
Match: match,
Metadata: &envoy_config_core_v3.Metadata{},
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
RequestHeadersToRemove: getRequestHeadersToRemove(options, &policy),
RequestHeadersToRemove: getRequestHeadersToRemove(cfg, &policy),
ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders),
}
if policy.Redirect != nil {
@ -252,7 +258,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
}
envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action}
} else {
action, err := b.buildPolicyRouteRouteAction(options, &policy)
action, err := b.buildPolicyRouteRouteAction(cfg, &policy)
if err != nil {
return nil, err
}
@ -264,7 +270,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
}
// disable authentication entirely when the proxy is fronting authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
if err != nil {
return nil, err
}
@ -275,7 +281,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
} else {
luaMetadata["remove_pomerium_cookie"] = &structpb.Value{
Kind: &structpb.Value_StringValue{
StringValue: options.CookieName,
StringValue: cfg.Options.CookieName,
},
}
luaMetadata["remove_pomerium_authorization"] = &structpb.Value{
@ -350,13 +356,16 @@ func (b *Builder) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*env
return action, nil
}
func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) {
func (b *Builder) buildPolicyRouteRouteAction(
cfg *config.Config,
policy *config.Policy,
) (*envoy_config_route_v3.RouteAction, error) {
clusterName := getClusterID(policy)
// kubernetes requests are sent to the http control plane to be reproxied
if policy.IsForKubernetes() {
clusterName = httpCluster
}
routeTimeout := getRouteTimeout(options, policy)
routeTimeout := getRouteTimeout(cfg, policy)
idleTimeout := getRouteIdleTimeout(policy)
prefixRewrite, regexRewrite := getRewriteOptions(policy)
upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{
@ -464,13 +473,16 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
return match
}
func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string {
func getRequestHeadersToRemove(
cfg *config.Config,
policy *config.Policy,
) []string {
requestHeadersToRemove := policy.RemoveRequestHeaders
if !policy.PassIdentityHeaders {
requestHeadersToRemove = append(requestHeadersToRemove,
httputil.HeaderPomeriumJWTAssertion,
httputil.HeaderPomeriumJWTAssertionFor)
for headerName := range options.JWTClaimsHeaders {
for headerName := range cfg.Options.JWTClaimsHeaders {
requestHeadersToRemove = append(requestHeadersToRemove, headerName)
}
}
@ -482,7 +494,10 @@ func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) [
return requestHeadersToRemove
}
func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration {
func getRouteTimeout(
cfg *config.Config,
policy *config.Policy,
) *durationpb.Duration {
var routeTimeout *durationpb.Duration
if policy.UpstreamTimeout != nil {
routeTimeout = durationpb.New(*policy.UpstreamTimeout)
@ -490,7 +505,7 @@ func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb
// a non-zero value would conflict with idleTimeout and/or websocket / tcp calls
routeTimeout = durationpb.New(0)
} else {
routeTimeout = durationpb.New(options.DefaultUpstreamTimeout)
routeTimeout = durationpb.New(cfg.Options.DefaultUpstreamTimeout)
}
return routeTimeout
}
@ -564,8 +579,11 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3.
}
}
func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) bool {
for _, policy := range options.GetAllPolicies() {
func hasPublicPolicyMatchingURL(
cfg *config.Config,
requestURL url.URL,
) bool {
for _, policy := range cfg.Options.GetAllPolicies() {
if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) {
return true
}
@ -573,13 +591,16 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
return false
}
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) {
authenticateURL, err := options.GetAuthenticateURL()
func isProxyFrontingAuthenticate(
cfg *config.Config,
domain string,
) (bool, error) {
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil {
return false, err
}
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
if !config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
return true, nil
}

View file

@ -84,7 +84,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateCallbackPath: "/oauth2/callback",
ForwardAuthURLString: "https://forward-auth.example.com",
}
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
cfg := config.New(options)
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
@ -106,7 +107,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback",
}
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
cfg := config.New(options)
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, "null", routes)
})
@ -122,8 +124,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
To: mustParseWeightedURLs(t, "https://to.example.com"),
}},
}
cfg := config.New(options)
_ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
@ -150,8 +153,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AllowPublicUnauthenticatedAccess: true,
}},
}
cfg := config.New(options)
_ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[
@ -242,7 +246,7 @@ func TestTimeouts(t *testing.T) {
for _, tc := range testCases {
b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{
@ -253,7 +257,7 @@ func TestTimeouts(t *testing.T) {
IdleTimeout: getDuration(tc.idle),
AllowWebsockets: tc.allowWebsockets,
}},
}, "example.com")
}), "example.com")
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
continue
}
@ -295,7 +299,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
ten := time.Second * 10
b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{
@ -357,7 +361,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten,
},
},
}, "example.com")
}), "example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
@ -724,7 +728,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
`, routes)
t.Run("fronting-authenticate", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.example.com",
Services: "proxy",
CookieName: "pomerium",
@ -735,7 +739,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
PassIdentityHeaders: true,
},
},
}, "authenticate.example.com")
}), "authenticate.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
@ -791,7 +795,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
})
t.Run("tcp", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{
@ -805,7 +809,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten,
},
},
}, "example.com:22")
}), "example.com:22")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
@ -905,7 +909,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
})
t.Run("remove-pomerium-headers", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.example.com",
Services: "proxy",
CookieName: "pomerium",
@ -918,7 +922,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")},
},
},
}, "from.example.com")
}), "from.example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
@ -980,7 +984,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
}(getClusterID)
getClusterID = policyNameFunc()
b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{
@ -1022,7 +1026,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
HostPathRegexRewriteSubstitution: "\\1",
},
},
}, "example.com")
}), "example.com")
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `

View file

@ -14,8 +14,10 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil"
)
func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Cluster, error) {
tracingOptions, err := config.NewTracingOptions(options)
func buildTracingCluster(
cfg *config.Config,
) (*envoy_config_cluster_v3.Cluster, error) {
tracingOptions, err := config.NewTracingOptions(cfg)
if err != nil {
return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err)
}
@ -24,8 +26,8 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
case trace.DatadogTracingProviderName:
addr, _ := parseAddress("127.0.0.1:8126")
if options.TracingDatadogAddress != "" {
addr, err = parseAddress(options.TracingDatadogAddress)
if cfg.Options.TracingDatadogAddress != "" {
addr, err = parseAddress(cfg.Options.TracingDatadogAddress)
if err != nil {
return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err)
}
@ -94,8 +96,10 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
}
}
func buildTracingHTTP(options *config.Options) (*envoy_config_trace_v3.Tracing_Http, error) {
tracingOptions, err := config.NewTracingOptions(options)
func buildTracingHTTP(
cfg *config.Config,
) (*envoy_config_trace_v3.Tracing_Http, error) {
tracingOptions, err := config.NewTracingOptions(cfg)
if err != nil {
return nil, fmt.Errorf("invalid tracing config: %w", err)
}

View file

@ -11,9 +11,9 @@ import (
func TestBuildTracingCluster(t *testing.T) {
t.Run("datadog", func(t *testing.T) {
c, err := buildTracingCluster(&config.Options{
c, err := buildTracingCluster(config.New(&config.Options{
TracingProvider: "datadog",
})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -38,10 +38,10 @@ func TestBuildTracingCluster(t *testing.T) {
}
`, c)
c, err = buildTracingCluster(&config.Options{
c, err = buildTracingCluster(config.New(&config.Options{
TracingProvider: "datadog",
TracingDatadogAddress: "example.com:8126",
})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -67,10 +67,10 @@ func TestBuildTracingCluster(t *testing.T) {
`, c)
})
t.Run("zipkin", func(t *testing.T) {
c, err := buildTracingCluster(&config.Options{
c, err := buildTracingCluster(config.New(&config.Options{
TracingProvider: "zipkin",
ZipkinEndpoint: "https://example.com/api/v2/spans",
})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -99,9 +99,9 @@ func TestBuildTracingCluster(t *testing.T) {
func TestBuildTracingHTTP(t *testing.T) {
t.Run("datadog", func(t *testing.T) {
h, err := buildTracingHTTP(&config.Options{
h, err := buildTracingHTTP(config.New(&config.Options{
TracingProvider: "datadog",
})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
@ -115,10 +115,10 @@ func TestBuildTracingHTTP(t *testing.T) {
`, h)
})
t.Run("zipkin", func(t *testing.T) {
h, err := buildTracingHTTP(&config.Options{
h, err := buildTracingHTTP(config.New(&config.Options{
TracingProvider: "zipkin",
ZipkinEndpoint: "https://example.com/api/v2/spans",
})
}))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{

View file

@ -18,28 +18,28 @@ import (
type TracingOptions = trace.TracingOptions
// NewTracingOptions builds a new TracingOptions from core Options
func NewTracingOptions(o *Options) (*TracingOptions, error) {
func NewTracingOptions(cfg *Config) (*TracingOptions, error) {
tracingOpts := TracingOptions{
Provider: o.TracingProvider,
Service: telemetry.ServiceName(o.Services),
JaegerAgentEndpoint: o.TracingJaegerAgentEndpoint,
SampleRate: o.TracingSampleRate,
Provider: cfg.Options.TracingProvider,
Service: telemetry.ServiceName(cfg.Options.Services),
JaegerAgentEndpoint: cfg.Options.TracingJaegerAgentEndpoint,
SampleRate: cfg.Options.TracingSampleRate,
}
switch o.TracingProvider {
switch cfg.Options.TracingProvider {
case trace.DatadogTracingProviderName:
tracingOpts.DatadogAddress = o.TracingDatadogAddress
tracingOpts.DatadogAddress = cfg.Options.TracingDatadogAddress
case trace.JaegerTracingProviderName:
if o.TracingJaegerCollectorEndpoint != "" {
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(o.TracingJaegerCollectorEndpoint)
if cfg.Options.TracingJaegerCollectorEndpoint != "" {
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.TracingJaegerCollectorEndpoint)
if err != nil {
return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err)
}
tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint
tracingOpts.JaegerAgentEndpoint = o.TracingJaegerAgentEndpoint
tracingOpts.JaegerAgentEndpoint = cfg.Options.TracingJaegerAgentEndpoint
}
case trace.ZipkinTracingProviderName:
zipkinEndpoint, err := urlutil.ParseAndValidateURL(o.ZipkinEndpoint)
zipkinEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.ZipkinEndpoint)
if err != nil {
return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err)
}
@ -47,7 +47,7 @@ func NewTracingOptions(o *Options) (*TracingOptions, error) {
case "":
return &TracingOptions{}, nil
default:
return nil, fmt.Errorf("config: provider %s unknown", o.TracingProvider)
return nil, fmt.Errorf("config: provider %s unknown", cfg.Options.TracingProvider)
}
return &tracingOpts, nil
@ -88,7 +88,7 @@ func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock()
defer mgr.mu.Unlock()
traceOpts, err := NewTracingOptions(cfg.Options)
traceOpts, err := NewTracingOptions(cfg)
if err != nil {
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
return

View file

@ -68,7 +68,7 @@ func Test_NewTracingOptions(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewTracingOptions(tt.opts)
got, err := NewTracingOptions(&Config{Options: tt.opts})
assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value")
assert.Empty(t, cmp.Diff(tt.want, got))
})

View file

@ -28,7 +28,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.opts.Provider = "google"
_, err := New(&config.Config{Options: &tt.opts}, events.New())
_, err := New(config.New(&tt.opts), events.New())
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return

View file

@ -236,8 +236,8 @@ func TestConfig(t *testing.T) {
}
_ = p1.Validate()
mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{
Options: &config.Options{
mgr, err := newManager(ctx,
config.NewStaticSource(config.New(&config.Options{
AutocertOptions: config.AutocertOptions{
Enable: true,
UseStaging: true,
@ -247,11 +247,11 @@ func TestConfig(t *testing.T) {
},
HTTPRedirectAddr: addr,
Policies: []config.Policy{p1},
},
}), certmagic.ACMEIssuer{
CA: srv.URL + "/acme/directory",
TestCA: srv.URL + "/acme/directory",
}, time.Millisecond*100)
})),
certmagic.ACMEIssuer{
CA: srv.URL + "/acme/directory",
TestCA: srv.URL + "/acme/directory",
}, time.Millisecond*100)
if !assert.NoError(t, err) {
return
}
@ -304,16 +304,14 @@ func TestRedirect(t *testing.T) {
addr := li.Addr().String()
_ = li.Close()
src := config.NewStaticSource(&config.Config{
Options: &config.Options{
HTTPRedirectAddr: addr,
SetResponseHeaders: map[string]string{
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
src := config.NewStaticSource(config.New(&config.Options{
HTTPRedirectAddr: addr,
SetResponseHeaders: map[string]string{
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
})
}))
_, err = New(src)
if !assert.NoError(t, err) {
return

View file

@ -28,7 +28,7 @@ import (
type Handler struct {
mu sync.RWMutex
key []byte
options *config.Options
cfg *config.Config
policies map[uint64]config.Policy
}
@ -83,7 +83,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
}
h.mu.RLock()
options := h.options
options := h.cfg.Options
policy, ok := h.policies[policyID]
h.mu.RUnlock()
@ -132,7 +132,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
defer h.mu.Unlock()
h.key, _ = cfg.Options.GetSharedKey()
h.options = cfg.Options
h.cfg = cfg
h.policies = make(map[uint64]config.Policy)
for _, p := range cfg.Options.GetAllPolicies() {
id, err := p.RouteID()

View file

@ -49,15 +49,13 @@ func TestMiddleware(t *testing.T) {
srv2 := httptest.NewServer(h.Middleware(next))
defer srv2.Close()
cfg := &config.Config{
Options: &config.Options{
SharedKey: cryptutil.NewBase64Key(),
Policies: []config.Policy{{
To: config.WeightedURLs{{URL: *u}},
KubernetesServiceAccountToken: "ABCD",
}},
},
}
cfg := config.New(&config.Options{
SharedKey: cryptutil.NewBase64Key(),
Policies: []config.Policy{{
To: config.WeightedURLs{{URL: *u}},
KubernetesServiceAccountToken: "ABCD",
}},
})
h.Update(context.Background(), cfg)
policyID, _ := cfg.Options.Policies[0].RouteID()

View file

@ -80,11 +80,11 @@ func TestProxy_ForwardAuth(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
p, err := New(config.New(tt.options))
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(ctx, &config.Config{Options: tt.options})
p.OnConfigChange(ctx, config.New(tt.options))
state := p.state.Load()
state.sessionStore = tt.sessionStore
signer, err := jws.NewHS256Signer(nil)

View file

@ -58,7 +58,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load()
var redirectURL *url.URL
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
if err != nil {
return httputil.NewError(http.StatusInternalServerError, err)
}

View file

@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) {
if err != nil {
t.Fatal(err)
}
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(config.New(opts))
if err != nil {
t.Fatal(err)
}
@ -70,7 +70,7 @@ func TestProxy_userInfo(t *testing.T) {
if err != nil {
t.Fatal(err)
}
proxy, err := New(&config.Config{Options: opts})
proxy, err := New(config.New(opts))
if err != nil {
t.Fatal(err)
}
@ -102,7 +102,7 @@ func TestProxy_SignOut(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
opts := testOptions(t)
p, err := New(&config.Config{Options: opts})
p, err := New(config.New(opts))
if err != nil {
t.Fatal(err)
}
@ -236,11 +236,11 @@ func TestProxy_Callback(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
p, err := New(config.New(tt.options))
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
p.OnConfigChange(context.Background(), config.New(tt.options))
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore
@ -350,7 +350,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
p, err := New(config.New(tt.options))
if err != nil {
t.Fatal(err)
}
@ -482,11 +482,11 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options})
p, err := New(config.New(tt.options))
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
p.OnConfigChange(context.Background(), config.New(tt.options))
state := p.state.Load()
state.encoder = tt.cipher
state.sessionStore = tt.sessionStore

View file

@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error {
// Proxy stores all the information associated with proxying a request.
type Proxy struct {
state *atomicutil.Value[*proxyState]
currentOptions *atomicutil.Value[*config.Options]
currentRouter *atomicutil.Value[*mux.Router]
state *atomicutil.Value[*proxyState]
currentConfig *atomicutil.Value[*config.Config]
currentRouter *atomicutil.Value[*mux.Router]
}
// New takes a Proxy service from options and a validation function.
@ -65,13 +65,13 @@ func New(cfg *config.Config) (*Proxy, error) {
}
p := &Proxy{
state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
state: atomicutil.NewValue(state),
currentConfig: atomicutil.NewValue(cfg),
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
}
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(len(p.currentOptions.Load().GetAllPolicies()))
return int64(len(p.currentConfig.Load().Options.GetAllPolicies()))
})
return p, nil
@ -88,8 +88,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return
}
p.currentOptions.Store(cfg.Options)
if err := p.setHandlers(cfg.Options); err != nil {
p.currentConfig.Store(cfg)
if err := p.setHandlers(cfg); err != nil {
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
}
if state, err := newProxyStateFromConfig(cfg); err != nil {
@ -99,8 +99,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
}
}
func (p *Proxy) setHandlers(opts *config.Options) error {
if len(opts.GetAllPolicies()) == 0 {
func (p *Proxy) setHandlers(cfg *config.Config) error {
if len(cfg.Options.GetAllPolicies()) == 0 {
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
}
r := httputil.NewRouter()
@ -113,7 +113,7 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
// dashboard handlers are registered to all routes
r = p.registerDashboardHandlers(r)
forwardAuthURL, err := opts.GetForwardAuthURL()
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil {
return err
}

View file

@ -99,7 +99,7 @@ func TestNew(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := New(&config.Config{Options: tt.opts})
got, err := New(config.New(tt.opts))
if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return
@ -194,12 +194,12 @@ func Test_UpdateOptions(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.originalOptions})
p, err := New(config.New(tt.originalOptions))
if err != nil {
t.Fatal(err)
}
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
p.OnConfigChange(context.Background(), config.New(tt.updatedOptions))
r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder()
p.ServeHTTP(w, r)
@ -212,5 +212,5 @@ func Test_UpdateOptions(t *testing.T) {
// Test nil
var p *Proxy
p.OnConfigChange(context.Background(), &config.Config{})
p.OnConfigChange(context.Background(), config.New(nil))
}