config: use config.Config instead of config.Options everywhere

This commit is contained in:
Caleb Doxsey 2022-08-26 15:39:50 -06:00
parent 5f51510e91
commit 1b80e8a6c2
40 changed files with 484 additions and 412 deletions

View file

@ -39,18 +39,18 @@ func ValidateOptions(o *config.Options) error {
// Authenticate contains data required to run the authenticate service. // Authenticate contains data required to run the authenticate service.
type Authenticate struct { type Authenticate struct {
cfg *authenticateConfig cfg *authenticateConfig
options *atomicutil.Value[*config.Options] currentConfig *atomicutil.Value[*config.Config]
state *atomicutil.Value[*authenticateState] state *atomicutil.Value[*authenticateState]
webauthn *webauthn.Handler webauthn *webauthn.Handler
} }
// New validates and creates a new authenticate service from a set of Options. // New validates and creates a new authenticate service from a set of Options.
func New(cfg *config.Config, options ...Option) (*Authenticate, error) { func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(options...), cfg: getAuthenticateConfig(options...),
options: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(cfg),
state: atomicutil.NewValue(newAuthenticateState()), state: atomicutil.NewValue(newAuthenticateState()),
} }
a.webauthn = webauthn.New(a.getWebauthnState) a.webauthn = webauthn.New(a.getWebauthnState)
@ -69,7 +69,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
return return
} }
a.options.Store(cfg.Options) a.currentConfig.Store(cfg)
if state, err := newAuthenticateStateFromConfig(cfg); err != nil { if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
log.Error(ctx).Err(err).Msg("authenticate: failed to update state") log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
} else { } else {

View file

@ -113,7 +113,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
_, err := New(&config.Config{Options: tt.opts}) _, err := New(config.New(tt.opts))
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -6,7 +6,7 @@ import (
) )
type authenticateConfig struct { type authenticateConfig struct {
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error) getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)
} }
// An Option customizes the Authenticate config. // An Option customizes the Authenticate config.
@ -22,7 +22,7 @@ func getAuthenticateConfig(options ...Option) *authenticateConfig {
} }
// WithGetIdentityProvider sets the getIdentityProvider function in the config. // WithGetIdentityProvider sets the getIdentityProvider function in the config.
func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option { func WithGetIdentityProvider(getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)) Option {
return func(cfg *authenticateConfig) { return func(cfg *authenticateConfig) {
cfg.getIdentityProvider = getIdentityProvider cfg.getIdentityProvider = getIdentityProvider
} }

View file

@ -47,12 +47,12 @@ func (a *Authenticate) Mount(r *mux.Router) {
r.StrictSlash(true) r.StrictSlash(true)
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
r.Use(func(h http.Handler) http.Handler { r.Use(func(h http.Handler) http.Handler {
options := a.options.Load() cfg := a.currentConfig.Load()
state := a.state.Load() state := a.state.Load()
csrfKey := fmt.Sprintf("%s_csrf", options.CookieName) csrfKey := fmt.Sprintf("%s_csrf", cfg.Options.CookieName)
return csrf.Protect( return csrf.Protect(
state.cookieSecret, state.cookieSecret,
csrf.Secure(options.CookieSecure), csrf.Secure(cfg.Options.CookieSecure),
csrf.Path("/"), csrf.Path("/"),
csrf.UnsafePaths( csrf.UnsafePaths(
[]string{ []string{
@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
// check for an HMAC'd URL. If none is found, show a confirmation page. // check for an HMAC'd URL. If none is found, show a confirmation page.
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey) err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
if err != nil { if err != nil {
authenticateURL, err := a.options.Load().GetAuthenticateURL() authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
if err != nil { if err != nil {
return err return err
} }
@ -275,9 +275,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut") ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
defer span.End() defer span.End()
options := a.options.Load() cfg := a.currentConfig.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil { if err != nil {
return err return err
} }
@ -285,7 +285,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
rawIDToken := a.revokeSession(ctx, w, r) rawIDToken := a.revokeSession(ctx, w, r)
redirectString := "" redirectString := ""
signOutURL, err := a.options.Load().GetSignOutRedirectURL() signOutURL, err := cfg.Options.GetSignOutRedirectURL()
if err != nil { if err != nil {
return err return err
} }
@ -330,10 +330,10 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
return httputil.NewError(http.StatusUnauthorized, err) return httputil.NewError(http.StatusUnauthorized, err)
} }
options := a.options.Load() cfg := a.currentConfig.Load()
state := a.state.Load() state := a.state.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil { if err != nil {
return err return err
} }
@ -381,7 +381,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback") ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
defer span.End() defer span.End()
options := a.options.Load() cfg := a.currentConfig.Load()
state := a.state.Load() state := a.state.Load()
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6 // Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
@ -430,7 +430,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
} }
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID) idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
idp, err := a.cfg.getIdentityProvider(options, idpID) idp, err := a.cfg.getIdentityProvider(cfg, idpID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -513,7 +513,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) { func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
state := a.state.Load() state := a.state.Load()
authenticateURL, err := a.options.Load().GetAuthenticateURL() authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
if err != nil { if err != nil {
return handlers.UserInfoData{}, err return handlers.UserInfoData{}, err
} }
@ -569,7 +569,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
WebAuthnRequestOptions: requestOptions, WebAuthnRequestOptions: requestOptions,
WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()), WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()),
BrandingOptions: a.options.Load().BrandingOptions, BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
}, nil }, nil
} }
@ -581,14 +581,14 @@ func (a *Authenticate) saveSessionToDataBroker(
accessToken *oauth2.Token, accessToken *oauth2.Token,
) error { ) error {
state := a.state.Load() state := a.state.Load()
options := a.options.Load() cfg := a.currentConfig.Load()
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil { if err != nil {
return err return err
} }
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire)) sessionExpiry := timestamppb.New(time.Now().Add(cfg.Options.CookieExpire))
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time()) idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
s := &session.Session{ s := &session.Session{
@ -648,13 +648,13 @@ func (a *Authenticate) saveSessionToDataBroker(
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns // databroker. If successful, it returns the original `id_token` of the session, if failed, returns
// and empty string. // and empty string.
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string { func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
options := a.options.Load() cfg := a.currentConfig.Load()
state := a.state.Load() state := a.state.Load()
// clear the user's local session no matter what // clear the user's local session no matter what
defer state.sessionStore.ClearSession(w, r) defer state.sessionStore.ClearSession(w, r)
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID)) idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
if err != nil { if err != nil {
return "" return ""
} }
@ -708,6 +708,7 @@ func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*di
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) { func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
state := a.state.Load() state := a.state.Load()
cfg := a.currentConfig.Load()
s, _, err := a.getCurrentSession(ctx) s, _, err := a.getCurrentSession(ctx)
if err != nil { if err != nil {
@ -719,17 +720,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
return nil, err return nil, err
} }
authenticateURL, err := a.options.Load().GetAuthenticateURL() authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL() internalAuthenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains() pomeriumDomains, err := cfg.Options.GetAllRouteableHTTPDomains()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -744,7 +745,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
SessionState: ss, SessionState: ss,
SessionStore: state.sessionStore, SessionStore: state.sessionStore,
RelyingParty: state.webauthnRelyingParty, RelyingParty: state.webauthnRelyingParty,
BrandingOptions: a.options.Load().BrandingOptions, BrandingOptions: cfg.Options.BrandingOptions,
}, nil }, nil
} }

View file

@ -49,10 +49,9 @@ func testAuthenticate() *Authenticate {
redirectURL: redirectURL, redirectURL: redirectURL,
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
}) })
auth.options = config.NewAtomicOptions() auth.currentConfig = atomicutil.NewValue(config.New(&config.Options{
auth.options.Store(&config.Options{
SharedKey: cryptutil.NewBase64Key(), SharedKey: cryptutil.NewBase64Key(),
}) }))
return &auth return &auth
} }
@ -148,7 +147,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key()) sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: atomicutil.NewValue(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
@ -168,10 +167,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
currentConfig: atomicutil.NewValue(config.New(&config.Options{
options: config.NewAtomicOptions(), SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
})),
} }
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
uri := &url.URL{Scheme: tt.scheme, Host: tt.host} uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
queryString := uri.Query() queryString := uri.Query()
@ -304,7 +303,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
ctrl := gomock.NewController(t) ctrl := gomock.NewController(t)
defer ctrl.Finish() defer ctrl.Finish()
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: atomicutil.NewValue(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
@ -325,12 +324,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
options: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(config.New(&config.Options{
} SignOutRedirectURLString: tt.signoutRedirectURL,
if tt.signoutRedirectURL != "" { })),
opts := a.options.Load()
opts.SignOutRedirectURLString = tt.signoutRedirectURL
a.options.Store(opts)
} }
u, _ := url.Parse("/sign_out") u, _ := url.Parse("/sign_out")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
@ -417,7 +413,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
} }
authURL, _ := url.Parse(tt.authenticateURL) authURL, _ := url.Parse(tt.authenticateURL)
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: atomicutil.NewValue(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
@ -435,7 +431,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
cookieCipher: aead, cookieCipher: aead,
encryptedEncoder: signer, encryptedEncoder: signer,
}), }),
options: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(config.New(nil)),
} }
u, _ := url.Parse("/oauthGet") u, _ := url.Parse("/oauthGet")
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
@ -552,7 +548,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
a := &Authenticate{ a := &Authenticate{
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) { cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
return tt.provider, nil return tt.provider, nil
})), })),
state: atomicutil.NewValue(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
@ -573,7 +569,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
options: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(config.New(nil)),
} }
r := httptest.NewRequest("GET", "/", nil) r := httptest.NewRequest("GET", "/", nil)
state, err := tt.session.LoadSession(r) state, err := tt.session.LoadSession(r)
@ -604,7 +600,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
func TestJwksEndpoint(t *testing.T) { func TestJwksEndpoint(t *testing.T) {
o := newTestOptions(t) o := newTestOptions(t)
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo=" o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
auth, err := New(&config.Config{Options: o}) auth, err := New(config.New(o))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
return return
@ -632,12 +628,11 @@ func TestAuthenticate_userInfo(t *testing.T) {
a.state = atomicutil.NewValue(&authenticateState{ a.state = atomicutil.NewValue(&authenticateState{
cookieSecret: cryptutil.NewKey(), cookieSecret: cryptutil.NewKey(),
}) })
a.options = config.NewAtomicOptions() a.currentConfig = atomicutil.NewValue(config.New(&config.Options{
a.options.Store(&config.Options{
SharedKey: cryptutil.NewBase64Key(), SharedKey: cryptutil.NewBase64Key(),
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateInternalURLString: "https://authenticate.service.cluster.local", AuthenticateInternalURLString: "https://authenticate.service.cluster.local",
}) }))
err := a.userInfo(w, r) err := a.userInfo(w, r)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, http.StatusFound, w.Code) assert.Equal(t, http.StatusFound, w.Code)
@ -687,13 +682,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
o := config.NewAtomicOptions()
o.Store(&config.Options{
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
SharedKey: "SHARED KEY",
})
a := &Authenticate{ a := &Authenticate{
options: o,
state: atomicutil.NewValue(&authenticateState{ state: atomicutil.NewValue(&authenticateState{
sessionStore: tt.sessionStore, sessionStore: tt.sessionStore,
encryptedEncoder: signer, encryptedEncoder: signer,
@ -711,6 +700,10 @@ func TestAuthenticate_userInfo(t *testing.T) {
}, },
directoryClient: new(mockDirectoryServiceClient), directoryClient: new(mockDirectoryServiceClient),
}), }),
currentConfig: atomicutil.NewValue(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
SharedKey: "SHARED KEY",
})),
} }
a.webauthn = webauthn.New(a.getWebauthnState) a.webauthn = webauthn.New(a.getWebauthnState)
r := httptest.NewRequest(tt.method, tt.url.String(), nil) r := httptest.NewRequest(tt.method, tt.url.String(), nil)

View file

@ -7,8 +7,8 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) { func defaultGetIdentityProvider(cfg *config.Config, idpID string) (identity.Authenticator, error) {
authenticateURL, err := options.GetAuthenticateURL() authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -17,9 +17,9 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
if err != nil { if err != nil {
return nil, err return nil, err
} }
redirectURL.Path = options.AuthenticateCallbackPath redirectURL.Path = cfg.Options.AuthenticateCallbackPath
idp, err := options.GetIdentityProviderForID(idpID) idp, err := cfg.Options.GetIdentityProviderForID(idpID)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -34,14 +34,14 @@ func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Han
} }
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request { func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
options := a.options.Load() cfg := a.currentConfig.Load()
externalURL, err := options.GetAuthenticateURL() externalURL, err := cfg.Options.GetAuthenticateURL()
if err != nil { if err != nil {
return r return r
} }
internalURL, err := options.GetInternalAuthenticateURL() internalURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return r return r
} }

View file

@ -25,11 +25,11 @@ import (
// Authorize struct holds // Authorize struct holds
type Authorize struct { type Authorize struct {
state *atomicutil.Value[*authorizeState] state *atomicutil.Value[*authorizeState]
store *store.Store store *store.Store
currentOptions *atomicutil.Value[*config.Options] currentConfig *atomicutil.Value[*config.Config]
accessTracker *AccessTracker accessTracker *AccessTracker
globalCache storage.Cache globalCache storage.Cache
// The stateLock prevents updating the evaluator store simultaneously with an evaluation. // The stateLock prevents updating the evaluator store simultaneously with an evaluation.
// This should provide a consistent view of the data at a given server/record version and // This should provide a consistent view of the data at a given server/record version and
@ -40,9 +40,9 @@ type Authorize struct {
// New validates and creates a new Authorize service from a set of config options. // New validates and creates a new Authorize service from a set of config options.
func New(cfg *config.Config) (*Authorize, error) { func New(cfg *config.Config) (*Authorize, error) {
a := &Authorize{ a := &Authorize{
currentOptions: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(cfg),
store: store.New(), store: store.New(),
globalCache: storage.NewGlobalCache(time.Minute), globalCache: storage.NewGlobalCache(time.Minute),
} }
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod) a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
@ -86,42 +86,42 @@ func validateOptions(o *config.Options) error {
} }
// newPolicyEvaluator returns an policy evaluator. // newPolicyEvaluator returns an policy evaluator.
func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) { func newPolicyEvaluator(cfg *config.Config, store *store.Store) (*evaluator.Evaluator, error) {
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 { metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
return int64(len(opts.GetAllPolicies())) return int64(len(cfg.Options.GetAllPolicies()))
}) })
ctx := context.Background() ctx := context.Background()
_, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator") _, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
defer span.End() defer span.End()
clientCA, err := opts.GetClientCA() clientCA, err := cfg.Options.GetClientCA()
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: invalid client CA: %w", err) return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
} }
authenticateURL, err := opts.GetInternalAuthenticateURL() authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err) return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err)
} }
signingKey, err := opts.GetSigningKey() signingKey, err := cfg.Options.GetSigningKey()
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: invalid signing key: %w", err) return nil, fmt.Errorf("authorize: invalid signing key: %w", err)
} }
return evaluator.New(ctx, store, return evaluator.New(ctx, store,
evaluator.WithPolicies(opts.GetAllPolicies()), evaluator.WithPolicies(cfg.Options.GetAllPolicies()),
evaluator.WithClientCA(clientCA), evaluator.WithClientCA(clientCA),
evaluator.WithSigningKey(signingKey), evaluator.WithSigningKey(signingKey),
evaluator.WithAuthenticateURL(authenticateURL.String()), evaluator.WithAuthenticateURL(authenticateURL.String()),
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()), evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(cfg.Options.GetGoogleCloudServerlessAuthenticationServiceAccount()),
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders), evaluator.WithJWTClaimsHeaders(cfg.Options.JWTClaimsHeaders),
) )
} }
// OnConfigChange updates internal structures based on config.Options // OnConfigChange updates internal structures based on config.Options
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) { func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
a.currentOptions.Store(cfg.Options) a.currentConfig.Store(cfg)
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil { if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
log.Error(ctx).Err(err).Msg("authorize: error updating state") log.Error(ctx).Err(err).Msg("authorize: error updating state")
} else { } else {

View file

@ -74,7 +74,7 @@ func TestNew(t *testing.T) {
tt := tt tt := tt
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
t.Parallel() t.Parallel()
_, err := New(&config.Config{Options: &tt.config}) _, err := New(config.New(&tt.config))
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -105,12 +105,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
SharedKey: tc.SharedKey, SharedKey: tc.SharedKey,
Policies: tc.Policies, Policies: tc.Policies,
} }
a, err := New(&config.Config{Options: o}) a, err := New(config.New(o))
require.NoError(t, err) require.NoError(t, err)
require.NotNil(t, a) require.NotNil(t, a)
oldPe := a.state.Load().evaluator oldPe := a.state.Load().evaluator
cfg := &config.Config{Options: o} cfg := config.New(o)
assertFunc := assert.True assertFunc := assert.True
o.SigningKey = "bad-share-key" o.SigningKey = "bad-share-key"
if tc.expectedChange { if tc.expectedChange {

View file

@ -125,7 +125,7 @@ func (a *Authorize) deniedResponse(
respBody := []byte(reason) respBody := []byte(reason)
respHeader := []*envoy_config_core_v3.HeaderValueOption{} respHeader := []*envoy_config_core_v3.HeaderValueOption{}
forwardAuthURL, _ := a.currentOptions.Load().GetForwardAuthURL() forwardAuthURL, _ := a.currentConfig.Load().Options.GetForwardAuthURL()
if forwardAuthURL == nil { if forwardAuthURL == nil {
// create a http response writer recorder // create a http response writer recorder
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -140,7 +140,7 @@ func (a *Authorize) deniedResponse(
Err: errors.New(reason), Err: errors.New(reason),
DebugURL: debugEndpoint, DebugURL: debugEndpoint,
RequestID: requestid.FromContext(ctx), RequestID: requestid.FromContext(ctx),
BrandingOptions: a.currentOptions.Load().BrandingOptions, BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
} }
httpErr.ErrorResponse(ctx, w, r) httpErr.ErrorResponse(ctx, w, r)
@ -184,7 +184,7 @@ func (a *Authorize) requireLoginResponse(
request *evaluator.Request, request *evaluator.Request,
isForwardAuthVerify bool, isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) { ) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load() opts := a.currentConfig.Load().Options
state := a.state.Load() state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL() authenticateURL, err := opts.GetAuthenticateURL()
if err != nil { if err != nil {
@ -225,7 +225,7 @@ func (a *Authorize) requireWebAuthnResponse(
result *evaluator.Result, result *evaluator.Result,
isForwardAuthVerify bool, isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) { ) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load() opts := a.currentConfig.Load().Options
state := a.state.Load() state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL() authenticateURL, err := opts.GetAuthenticateURL()
if err != nil { if err != nil {
@ -295,7 +295,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's // userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
// session that lives on the authenticate service. // session that lives on the authenticate service.
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) { func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
opts := a.currentOptions.Load() opts := a.currentConfig.Load().Options
authenticateURL, err := opts.GetAuthenticateURL() authenticateURL, err := opts.GetAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -29,7 +29,7 @@ func TestAuthorize_handleResult(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com" opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt}) a, err := New(config.New(opt))
require.NoError(t, err) require.NoError(t, err)
t.Run("user-unauthenticated", func(t *testing.T) { t.Run("user-unauthenticated", func(t *testing.T) {
@ -67,12 +67,12 @@ func TestAuthorize_okResponse(t *testing.T) {
}}, }},
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"), JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
} }
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} cfg := config.New(opt)
a := &Authorize{currentConfig: atomicutil.NewValue(cfg), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(opt)
a.store = store.New() a.store = store.New()
pe, err := newPolicyEvaluator(opt, a.store) pe, err := newPolicyEvaluator(cfg, a.store)
require.NoError(t, err) require.NoError(t, err)
a.state.Load().evaluator = pe a.state.Load().evaluator = pe
@ -123,17 +123,17 @@ func TestAuthorize_okResponse(t *testing.T) {
} }
func TestAuthorize_deniedResponse(t *testing.T) { func TestAuthorize_deniedResponse(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{ Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{ SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"}, Rego: []string{"allow = true"},
}}, }},
}}, }},
}) }))
tests := []struct { tests := []struct {
name string name string
@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com" opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt}) a, err := New(config.New(opt))
require.NoError(t, err) require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) { t.Run("accept empty", func(t *testing.T) {

View file

@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
} }
} }
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder) rawJWT, _ := loadRawSession(hreq, a.currentConfig.Load(), state.encoder)
sessionState, _ := loadSession(state.encoder, rawJWT) sessionState, _ := loadSession(state.encoder, rawJWT)
var s sessionOrServiceAccount var s sessionOrServiceAccount
@ -100,9 +100,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
// isForwardAuth returns if the current request is a forward auth route. // isForwardAuth returns if the current request is a forward auth route.
func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool { func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool {
opts := a.currentOptions.Load() cfg := a.currentConfig.Load()
forwardAuthURL, err := opts.GetForwardAuthURL() forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil || forwardAuthURL == nil { if err != nil || forwardAuthURL == nil {
return false return false
} }
@ -136,9 +136,9 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
} }
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy { func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
options := a.currentOptions.Load() cfg := a.currentConfig.Load()
for _, p := range options.GetAllPolicies() { for _, p := range cfg.Options.GetAllPolicies() {
if p.Matches(requestURL) { if p.Matches(requestURL) {
return &p return &p
} }

View file

@ -47,17 +47,17 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
-----END CERTIFICATE-----` -----END CERTIFICATE-----`
func Test_getEvaluatorRequest(t *testing.T) { func Test_getEvaluatorRequest(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{ Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{ SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"}, Rego: []string{"allow = true"},
}}, }},
}}, }},
}) }))
actual, err := a.getEvaluatorRequestFromCheckRequest( actual, err := a.getEvaluatorRequestFromCheckRequest(
&envoy_service_auth_v3.CheckRequest{ &envoy_service_auth_v3.CheckRequest{
@ -87,7 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
) )
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0], Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{ Session: evaluator.RequestSession{
ID: "SESSION_ID", ID: "SESSION_ID",
}, },
@ -248,8 +248,10 @@ func Test_handleForwardAuth(t *testing.T) {
for _, tc := range tests { for _, tc := range tests {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL}) a.currentConfig.Store(config.New(&config.Options{
ForwardAuthURLString: tc.forwardAuthURL,
}))
got := a.isForwardAuth(tc.checkReq) got := a.isForwardAuth(tc.checkReq)
@ -261,17 +263,17 @@ func Test_handleForwardAuth(t *testing.T) {
} }
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) { func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0}) encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
a.state.Load().encoder = encoder a.state.Load().encoder = encoder
a.currentOptions.Store(&config.Options{ a.currentConfig.Store(config.New(&config.Options{
Policies: []config.Policy{{ Policies: []config.Policy{{
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}}, Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
SubPolicies: []config.SubPolicy{{ SubPolicies: []config.SubPolicy{{
Rego: []string{"allow = true"}, Rego: []string{"allow = true"},
}}, }},
}}, }},
}) }))
actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{ actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{
Attributes: &envoy_service_auth_v3.AttributeContext{ Attributes: &envoy_service_auth_v3.AttributeContext{
@ -296,7 +298,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
}, nil) }, nil)
require.NoError(t, err) require.NoError(t, err)
expect := &evaluator.Request{ expect := &evaluator.Request{
Policy: &a.currentOptions.Load().Policies[0], Policy: &a.currentConfig.Load().Options.Policies[0],
Session: evaluator.RequestSession{}, Session: evaluator.RequestSession{},
HTTP: evaluator.NewRequestHTTP( HTTP: evaluator.NewRequestHTTP(
"GET", "GET",
@ -332,11 +334,13 @@ func TestAuthorize_Check(t *testing.T) {
opt.AuthenticateURLString = "https://authenticate.example.com" opt.AuthenticateURLString = "https://authenticate.example.com"
opt.DataBrokerURLString = "https://databroker.example.com" opt.DataBrokerURLString = "https://databroker.example.com"
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM=" opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
a, err := New(&config.Config{Options: opt}) a, err := New(config.New(opt))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"}) a.currentConfig.Store(config.New(&config.Options{
ForwardAuthURLString: "https://forward-auth.example.com",
}))
cmpOpts := []cmp.Option{ cmpOpts := []cmp.Option{
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}), cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),

View file

@ -13,9 +13,13 @@ import (
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
) )
func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) { func loadRawSession(
req *http.Request,
cfg *config.Config,
encoder encoding.MarshalUnmarshaler,
) ([]byte, error) {
var loaders []sessions.SessionLoader var loaders []sessions.SessionLoader
cookieStore, err := getCookieStore(options, encoder) cookieStore, err := getCookieStore(cfg, encoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -37,7 +41,10 @@ func loadRawSession(req *http.Request, options *config.Options, encoder encoding
return nil, sessions.ErrNoSessionFound return nil, sessions.ErrNoSessionFound
} }
func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.State, error) { func loadSession(
encoder encoding.MarshalUnmarshaler,
rawJWT []byte,
) (*sessions.State, error) {
var s sessions.State var s sessions.State
err := encoder.Unmarshal(rawJWT, &s) err := encoder.Unmarshal(rawJWT, &s)
if err != nil { if err != nil {
@ -46,14 +53,17 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.
return &s, nil return &s, nil
} }
func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) { func getCookieStore(
cfg *config.Config,
encoder encoding.MarshalUnmarshaler,
) (sessions.SessionStore, error) {
cookieStore, err := cookie.NewStore(func() cookie.Options { cookieStore, err := cookie.NewStore(func() cookie.Options {
return cookie.Options{ return cookie.Options{
Name: options.CookieName, Name: cfg.Options.CookieName,
Domain: options.CookieDomain, Domain: cfg.Options.CookieDomain,
Secure: options.CookieSecure, Secure: cfg.Options.CookieSecure,
HTTPOnly: options.CookieHTTPOnly, HTTPOnly: cfg.Options.CookieHTTPOnly,
Expire: options.CookieExpire, Expire: cfg.Options.CookieExpire,
} }
}, encoder) }, encoder)
if err != nil { if err != nil {

View file

@ -32,7 +32,7 @@ func TestLoadSession(t *testing.T) {
}, },
}, },
}) })
raw, err := loadRawSession(req, opts, encoder) raw, err := loadRawSession(req, config.New(opts), encoder)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -36,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
var err error var err error
state.evaluator, err = newPolicyEvaluator(cfg.Options, store) state.evaluator, err = newPolicyEvaluator(cfg, store)
if err != nil { if err != nil {
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err) return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
} }

View file

@ -31,6 +31,16 @@ type Config struct {
MetricsScrapeEndpoints []MetricsScrapeEndpoint MetricsScrapeEndpoints []MetricsScrapeEndpoint
} }
// New creates a new Config.
func New(options *Options) *Config {
if options == nil {
options = NewDefaultOptions()
}
return &Config{
Options: options,
}
}
// Clone creates a clone of the config. // Clone creates a clone of the config.
func (cfg *Config) Clone() *Config { func (cfg *Config) Clone() *Config {
newOptions := new(Options) newOptions := new(Options)

View file

@ -13,11 +13,9 @@ import (
func TestBuilder_BuildBootstrapAdmin(t *testing.T) { func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{ adminCfg, err := b.BuildBootstrapAdmin(config.New(&config.Options{
Options: &config.Options{ EnvoyAdminAddress: "localhost:9901",
EnvoyAdminAddress: "localhost:9901", }))
},
})
assert.NoError(t, err) assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -31,11 +29,9 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
`, adminCfg) `, adminCfg)
}) })
t.Run("bad address", func(t *testing.T) { t.Run("bad address", func(t *testing.T) {
_, err := b.BuildBootstrapAdmin(&config.Config{ _, err := b.BuildBootstrapAdmin(config.New(&config.Options{
Options: &config.Options{ EnvoyAdminAddress: "xyz1234:zyx4321",
EnvoyAdminAddress: "xyz1234:zyx4321", }))
},
})
assert.Error(t, err) assert.Error(t, err)
}) })
} }
@ -111,11 +107,9 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) { func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("valid", func(t *testing.T) { t.Run("valid", func(t *testing.T) {
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{ statsCfg, err := b.BuildBootstrapStatsConfig(config.New(&config.Options{
Options: &config.Options{ Services: "all",
Services: "all", }))
},
})
assert.NoError(t, err) assert.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {

View file

@ -46,22 +46,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
return nil, err return nil, err
} }
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2) controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
if err != nil { if err != nil {
return nil, err return nil, err
} }
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto) controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
if err != nil { if err != nil {
return nil, err return nil, err
} }
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2) authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -70,7 +70,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection() authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
} }
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2) databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -87,7 +87,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
databrokerCluster, databrokerCluster,
} }
tracingCluster, err := buildTracingCluster(cfg.Options) tracingCluster, err := buildTracingCluster(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} else if tracingCluster != nil { } else if tracingCluster != nil {
@ -101,7 +101,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
policy.EnvoyOpts = newDefaultEnvoyClusterConfig() policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
} }
if len(policy.To) > 0 { if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy) cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
if err != nil { if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err) return nil, fmt.Errorf("policy #%d: %w", i, err)
} }
@ -119,16 +119,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
func (b *Builder) buildInternalCluster( func (b *Builder) buildInternalCluster(
ctx context.Context, ctx context.Context,
options *config.Options, cfg *config.Config,
name string, name string,
dsts []*url.URL, dsts []*url.URL,
upstreamProtocol upstreamProtocolConfig, upstreamProtocol upstreamProtocolConfig,
) (*envoy_config_cluster_v3.Cluster, error) { ) (*envoy_config_cluster_v3.Cluster, error) {
cluster := newDefaultEnvoyClusterConfig() cluster := newDefaultEnvoyClusterConfig()
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
var endpoints []Endpoint var endpoints []Endpoint
for _, dst := range dsts { for _, dst := range dsts {
ts, err := b.buildInternalTransportSocket(ctx, options, dst) ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,18 +141,22 @@ func (b *Builder) buildInternalCluster(
return cluster, nil return cluster, nil
} }
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { func (b *Builder) buildPolicyCluster(
ctx context.Context,
cfg *config.Config,
policy *config.Policy,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster) cluster := new(envoy_config_cluster_v3.Cluster)
proto.Merge(cluster, policy.EnvoyOpts) proto.Merge(cluster, policy.EnvoyOpts)
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" { if cfg.Options.EnvoyBindConfigFreebind.IsSet() || cfg.Options.EnvoyBindConfigSourceAddress != "" {
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig) cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
if options.EnvoyBindConfigFreebind.IsSet() { if cfg.Options.EnvoyBindConfigFreebind.IsSet() {
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(options.EnvoyBindConfigFreebind.Bool) cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(cfg.Options.EnvoyBindConfigFreebind.Bool)
} }
if options.EnvoyBindConfigSourceAddress != "" { if cfg.Options.EnvoyBindConfigSourceAddress != "" {
cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{ cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{
Address: options.EnvoyBindConfigSourceAddress, Address: cfg.Options.EnvoyBindConfigSourceAddress,
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{ PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
PortValue: 0, PortValue: 0,
}, },
@ -171,13 +175,13 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
name := getClusterID(policy) name := getClusterID(policy)
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy) endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO { if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO {
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
} }
if policy.EnableGoogleCloudServerlessAuthentication { if policy.EnableGoogleCloudServerlessAuthentication {
@ -193,12 +197,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
func (b *Builder) buildPolicyEndpoints( func (b *Builder) buildPolicyEndpoints(
ctx context.Context, ctx context.Context,
options *config.Options, cfg *config.Config,
policy *config.Policy, policy *config.Policy,
) ([]Endpoint, error) { ) ([]Endpoint, error) {
var endpoints []Endpoint var endpoints []Endpoint
for _, dst := range policy.To { for _, dst := range policy.To {
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL) ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -209,7 +213,7 @@ func (b *Builder) buildPolicyEndpoints(
func (b *Builder) buildInternalTransportSocket( func (b *Builder) buildInternalTransportSocket(
ctx context.Context, ctx context.Context,
options *config.Options, cfg *config.Config,
endpoint *url.URL, endpoint *url.URL,
) (*envoy_config_core_v3.TransportSocket, error) { ) (*envoy_config_core_v3.TransportSocket, error) {
if endpoint.Scheme != "https" { if endpoint.Scheme != "https" {
@ -218,10 +222,10 @@ func (b *Builder) buildInternalTransportSocket(
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{ MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName), b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
}, },
} }
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {
@ -234,7 +238,7 @@ func (b *Builder) buildInternalTransportSocket(
ValidationContext: validationContext, ValidationContext: validationContext,
}, },
}, },
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName), Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
} }
tlsConfig := marshalAny(tlsContext) tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{ return &envoy_config_core_v3.TransportSocket{
@ -247,7 +251,7 @@ func (b *Builder) buildInternalTransportSocket(
func (b *Builder) buildPolicyTransportSocket( func (b *Builder) buildPolicyTransportSocket(
ctx context.Context, ctx context.Context,
options *config.Options, cfg *config.Config,
policy *config.Policy, policy *config.Policy,
dst url.URL, dst url.URL,
) (*envoy_config_core_v3.TransportSocket, error) { ) (*envoy_config_core_v3.TransportSocket, error) {
@ -257,7 +261,7 @@ func (b *Builder) buildPolicyTransportSocket(
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst) vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -318,7 +322,7 @@ func (b *Builder) buildPolicyTransportSocket(
func (b *Builder) buildPolicyValidationContext( func (b *Builder) buildPolicyValidationContext(
ctx context.Context, ctx context.Context,
options *config.Options, cfg *config.Config,
policy *config.Policy, policy *config.Policy,
dst url.URL, dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { ) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
@ -343,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
} }
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else { } else {
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else { } else {

View file

@ -37,14 +37,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename() combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}, *mustParseURL(t, "http://example.com")) }, *mustParseURL(t, "http://example.com"))
require.NoError(t, err) require.NoError(t, err)
assert.Nil(t, ts) assert.Nil(t, ts)
}) })
t.Run("host as sni", func(t *testing.T) { t.Run("host as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
require.NoError(t, err) require.NoError(t, err)
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("tls_server_name as sni", func(t *testing.T) { t.Run("tls_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSServerName: "use-this-name.example.com", TLSServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("tls_upstream_server_name as sni", func(t *testing.T) { t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSUpstreamServerName: "use-this-name.example.com", TLSUpstreamServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("tls_skip_verify", func(t *testing.T) { t.Run("tls_skip_verify", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSSkipVerify: true, TLSSkipVerify: true,
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("custom ca", func(t *testing.T) { t.Run("custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts) `, ts)
}) })
t.Run("options custom ca", func(t *testing.T) { t.Run("options custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o2), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
require.NoError(t, err) require.NoError(t, err)
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
}) })
t.Run("client certificate", func(t *testing.T) { t.Run("client certificate", func(t *testing.T) {
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"), To: mustParseWeightedURLs(t, "https://example.com"),
ClientCertificate: clientCert, ClientCertificate: clientCert,
}, *mustParseURL(t, "https://example.com")) }, *mustParseURL(t, "https://example.com"))
@ -438,7 +438,7 @@ func Test_buildCluster(t *testing.T) {
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions() o1 := config.NewDefaultOptions()
t.Run("insecure", func(t *testing.T) { t.Run("insecure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"), To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -495,7 +495,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("secure", func(t *testing.T) { t.Run("secure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, To: mustParseWeightedURLs(t,
"https://example.com", "https://example.com",
"https://example.com", "https://example.com",
@ -663,7 +663,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("ip addresses", func(t *testing.T) { t.Run("ip addresses", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"), To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("weights", func(t *testing.T) { t.Run("weights", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"), To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -775,7 +775,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("localhost", func(t *testing.T) { t.Run("localhost", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://localhost"), To: mustParseWeightedURLs(t, "http://localhost"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -821,7 +821,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster) `, cluster)
}) })
t.Run("outlier", func(t *testing.T) { t.Run("outlier", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"), To: mustParseWeightedURLs(t, "http://example.com"),
}) })
require.NoError(t, err) require.NoError(t, err)
@ -904,7 +904,7 @@ func Test_bindConfig(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("no bind config", func(t *testing.T) { t.Run("no bind config", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{ cluster, err := b.buildPolicyCluster(ctx, config.New(nil), &config.Policy{
From: "https://from.example.com", From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"), To: mustParseWeightedURLs(t, "https://to.example.com"),
}) })
@ -912,9 +912,9 @@ func Test_bindConfig(t *testing.T) {
assert.Nil(t, cluster.UpstreamBindConfig) assert.Nil(t, cluster.UpstreamBindConfig)
}) })
t.Run("freebind", func(t *testing.T) { t.Run("freebind", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{ cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
EnvoyBindConfigFreebind: null.BoolFrom(true), EnvoyBindConfigFreebind: null.BoolFrom(true),
}, &config.Policy{ }), &config.Policy{
From: "https://from.example.com", From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"), To: mustParseWeightedURLs(t, "https://to.example.com"),
}) })
@ -930,9 +930,9 @@ func Test_bindConfig(t *testing.T) {
`, cluster.UpstreamBindConfig) `, cluster.UpstreamBindConfig)
}) })
t.Run("source address", func(t *testing.T) { t.Run("source address", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{ cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
EnvoyBindConfigSourceAddress: "192.168.0.1", EnvoyBindConfigSourceAddress: "192.168.0.1",
}, &config.Policy{ }), &config.Policy{
From: "https://from.example.com", From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"), To: mustParseWeightedURLs(t, "https://to.example.com"),
}) })

View file

@ -76,10 +76,10 @@ func newDefaultEnvoyClusterConfig() *envoy_config_cluster_v3.Cluster {
} }
} }
func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog { func buildAccessLogs(cfg *config.Config) []*envoy_config_accesslog_v3.AccessLog {
lvl := options.ProxyLogLevel lvl := cfg.Options.ProxyLogLevel
if lvl == "" { if lvl == "" {
lvl = options.LogLevel lvl = cfg.Options.LogLevel
} }
if lvl == "" { if lvl == "" {
lvl = "debug" lvl = "debug"

View file

@ -10,7 +10,7 @@ import (
) )
func (b *Builder) buildVirtualHost( func (b *Builder) buildVirtualHost(
options *config.Options, cfg *config.Config,
name string, name string,
domain string, domain string,
) (*envoy_config_route_v3.VirtualHost, error) { ) (*envoy_config_route_v3.VirtualHost, error) {
@ -20,15 +20,15 @@ func (b *Builder) buildVirtualHost(
} }
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, domain) rs, err := b.buildPomeriumHTTPRoutes(cfg, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
vh.Routes = append(vh.Routes, rs...) vh.Routes = append(vh.Routes, rs...)
// if we're the proxy or authenticate service, add our global headers // if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders()) vh.ResponseHeadersToAdd = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
} }
return vh, nil return vh, nil
@ -37,13 +37,13 @@ func (b *Builder) buildVirtualHost(
// buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies // buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies
// coming directly from envoy // coming directly from envoy
func (b *Builder) buildLocalReplyConfig( func (b *Builder) buildLocalReplyConfig(
options *config.Options, cfg *config.Config,
) *envoy_http_connection_manager.LocalReplyConfig { ) *envoy_http_connection_manager.LocalReplyConfig {
// add global headers for HSTS headers (#2110) // add global headers for HSTS headers (#2110)
var headers []*envoy_config_core_v3.HeaderValueOption var headers []*envoy_config_core_v3.HeaderValueOption
// if we're the proxy or authenticate service, add our global headers // if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
headers = toEnvoyHeaders(options.GetSetResponseHeaders()) headers = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
} }
return &envoy_http_connection_manager.LocalReplyConfig{ return &envoy_http_connection_manager.LocalReplyConfig{

View file

@ -53,7 +53,10 @@ func init() {
} }
// BuildListeners builds envoy listeners from the given config. // BuildListeners builds envoy listeners from the given config.
func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) { func (b *Builder) BuildListeners(
ctx context.Context,
cfg *config.Config,
) ([]*envoy_config_listener_v3.Listener, error) {
var listeners []*envoy_config_listener_v3.Listener var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) { if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
@ -89,19 +92,22 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
return listeners, nil return listeners, nil
} }
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { func (b *Builder) buildMainListener(
ctx context.Context,
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{} listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
if cfg.Options.UseProxyProtocol { if cfg.Options.UseProxyProtocol {
listenerFilters = append(listenerFilters, ProxyProtocolFilter()) listenerFilters = append(listenerFilters, ProxyProtocolFilter())
} }
if cfg.Options.InsecureServer { if cfg.Options.InsecureServer {
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr) allDomains, err := getAllRouteableDomains(cfg, cfg.Options.Addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "") filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, allDomains, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -118,9 +124,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
} }
listenerFilters = append(listenerFilters, TLSInspectorFilter()) listenerFilters = append(listenerFilters, TLSInspectorFilter())
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr, chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain) filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, httpDomains, tlsDomain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -155,7 +161,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
return li, nil return li, nil
} }
func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { func (b *Builder) buildMetricsListener(
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
filter, err := b.buildMetricsHTTPConnectionManagerFilter() filter, err := b.buildMetricsHTTPConnectionManagerFilter()
if err != nil { if err != nil {
return nil, err return nil, err
@ -235,22 +243,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
} }
func (b *Builder) buildFilterChains( func (b *Builder) buildFilterChains(
options *config.Options, addr string, cfg *config.Config,
addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) { ) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains, err := getAllRouteableDomains(options, addr) allDomains, err := getAllRouteableDomains(cfg, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tlsDomains, err := getAllTLSDomains(options, addr) tlsDomains, err := getAllTLSDomains(cfg, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
var chains []*envoy_config_listener_v3.FilterChain var chains []*envoy_config_listener_v3.FilterChain
for _, domain := range tlsDomains { for _, domain := range tlsDomains {
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain) routeableDomains, err := getRouteableDomainsForTLSServerName(cfg, addr, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -273,31 +282,31 @@ func (b *Builder) buildFilterChains(
} }
func (b *Builder) buildMainHTTPConnectionManagerFilter( func (b *Builder) buildMainHTTPConnectionManagerFilter(
options *config.Options, cfg *config.Config,
domains []string, domains []string,
tlsDomain string, tlsDomain string,
) (*envoy_config_listener_v3.Filter, error) { ) (*envoy_config_listener_v3.Filter, error) {
authorizeURLs, err := options.GetInternalAuthorizeURLs() authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
dataBrokerURLs, err := options.GetInternalDataBrokerURLs() dataBrokerURLs, err := cfg.Options.GetInternalDataBrokerURLs()
if err != nil { if err != nil {
return nil, err return nil, err
} }
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range domains { for _, domain := range domains {
vh, err := b.buildVirtualHost(options, domain, domain) vh, err := b.buildVirtualHost(cfg, domain, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if options.Addr == options.GetGRPCAddr() { if cfg.Options.Addr == cfg.Options.GetGRPCAddr() {
// if this is a gRPC service domain and we're supposed to handle that, add those routes // if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) || if (config.IsAuthorize(cfg.Options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) { (config.IsDataBroker(cfg.Options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
rs, err := b.buildGRPCRoutes() rs, err := b.buildGRPCRoutes()
if err != nil { if err != nil {
return nil, err return nil, err
@ -307,8 +316,8 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
} }
// if we're the proxy, add all the policy routes // if we're the proxy, add all the policy routes
if config.IsProxy(options.Services) { if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildPolicyRoutes(options, domain) rs, err := b.buildPolicyRoutes(cfg, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -320,15 +329,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
} }
} }
vh, err := b.buildVirtualHost(options, "catch-all", "*") vh, err := b.buildVirtualHost(cfg, "catch-all", "*")
if err != nil { if err != nil {
return nil, err return nil, err
} }
virtualHosts = append(virtualHosts, vh) virtualHosts = append(virtualHosts, vh)
var grpcClientTimeout *durationpb.Duration var grpcClientTimeout *durationpb.Duration
if options.GRPCClientTimeout != 0 { if cfg.Options.GRPCClientTimeout != 0 {
grpcClientTimeout = durationpb.New(options.GRPCClientTimeout) grpcClientTimeout = durationpb.New(cfg.Options.GRPCClientTimeout)
} else { } else {
grpcClientTimeout = durationpb.New(30 * time.Second) grpcClientTimeout = durationpb.New(30 * time.Second)
} }
@ -346,15 +355,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
filters = append(filters, HTTPRouterFilter()) filters = append(filters, HTTPRouterFilter())
var maxStreamDuration *durationpb.Duration var maxStreamDuration *durationpb.Duration
if options.WriteTimeout > 0 { if cfg.Options.WriteTimeout > 0 {
maxStreamDuration = durationpb.New(options.WriteTimeout) maxStreamDuration = durationpb.New(cfg.Options.WriteTimeout)
} }
rc, err := b.buildRouteConfiguration("main", virtualHosts) rc, err := b.buildRouteConfiguration("main", virtualHosts)
if err != nil { if err != nil {
return nil, err return nil, err
} }
tracingProvider, err := buildTracingHTTP(options) tracingProvider, err := buildTracingHTTP(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -362,27 +371,27 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{ return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
AlwaysSetRequestIdInResponse: true, AlwaysSetRequestIdInResponse: true,
CodecType: options.GetCodecType().ToEnvoy(), CodecType: cfg.Options.GetCodecType().ToEnvoy(),
StatPrefix: "ingress", StatPrefix: "ingress",
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{ RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
RouteConfig: rc, RouteConfig: rc,
}, },
HttpFilters: filters, HttpFilters: filters,
AccessLog: buildAccessLogs(options), AccessLog: buildAccessLogs(cfg),
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{ CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
IdleTimeout: durationpb.New(options.IdleTimeout), IdleTimeout: durationpb.New(cfg.Options.IdleTimeout),
MaxStreamDuration: maxStreamDuration, MaxStreamDuration: maxStreamDuration,
}, },
RequestTimeout: durationpb.New(options.ReadTimeout), RequestTimeout: durationpb.New(cfg.Options.ReadTimeout),
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{ Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100}, RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
Provider: tracingProvider, Provider: tracingProvider,
}, },
// See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for // See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for
UseRemoteAddress: &wrappers.BoolValue{Value: true}, UseRemoteAddress: &wrappers.BoolValue{Value: true},
SkipXffAppend: options.SkipXffAppend, SkipXffAppend: cfg.Options.SkipXffAppend,
XffNumTrustedHops: options.XffNumTrustedHops, XffNumTrustedHops: cfg.Options.XffNumTrustedHops,
LocalReplyConfig: b.buildLocalReplyConfig(options), LocalReplyConfig: b.buildLocalReplyConfig(cfg),
}), nil }), nil
} }
@ -420,7 +429,10 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() (*envoy_config_liste
}), nil }), nil
} }
func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { func (b *Builder) buildGRPCListener(
ctx context.Context,
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
filter, err := b.buildGRPCHTTPConnectionManagerFilter() filter, err := b.buildGRPCHTTPConnectionManagerFilter()
if err != nil { if err != nil {
return nil, err return nil, err
@ -437,7 +449,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
return li, nil return li, nil
} }
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr, chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) { func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
filterChain := &envoy_config_listener_v3.FilterChain{ filterChain := &envoy_config_listener_v3.FilterChain{
Filters: []*envoy_config_listener_v3.Filter{filter}, Filters: []*envoy_config_listener_v3.Filter{filter},
@ -518,7 +530,10 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener
}), nil }), nil
} }
func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) { func (b *Builder) buildRouteConfiguration(
name string,
virtualHosts []*envoy_config_route_v3.VirtualHost,
) (*envoy_config_route_v3.RouteConfiguration, error) {
return &envoy_config_route_v3.RouteConfiguration{ return &envoy_config_route_v3.RouteConfiguration{
Name: name, Name: name,
VirtualHosts: virtualHosts, VirtualHosts: virtualHosts,
@ -527,7 +542,8 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
}, nil }, nil
} }
func (b *Builder) buildDownstreamTLSContext(ctx context.Context, func (b *Builder) buildDownstreamTLSContext(
ctx context.Context,
cfg *config.Config, cfg *config.Config,
domain string, domain string,
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { ) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
@ -570,7 +586,8 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
} }
} }
func (b *Builder) buildDownstreamValidationContext(ctx context.Context, func (b *Builder) buildDownstreamValidationContext(
ctx context.Context,
cfg *config.Config, cfg *config.Config,
domain string, domain string,
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { ) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
@ -580,7 +597,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
needsClientCert = true needsClientCert = true
} }
if !needsClientCert { if !needsClientCert {
for _, p := range getPoliciesForDomain(cfg.Options, domain) { for _, p := range getPoliciesForDomain(cfg, domain) {
if p.TLSDownstreamClientCA != "" { if p.TLSDownstreamClientCA != "" {
needsClientCert = true needsClientCert = true
break break
@ -613,19 +630,23 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
return vc return vc
} }
func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) { func getRouteableDomainsForTLSServerName(
cfg *config.Config,
addr string,
tlsServerName string,
) ([]string, error) {
allDomains := sets.NewSorted[string]() allDomains := sets.NewSorted[string]()
if addr == options.Addr { if addr == cfg.Options.Addr {
domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName) domains, err := cfg.Options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
allDomains.Add(domains...) allDomains.Add(domains...)
} }
if addr == options.GetGRPCAddr() { if addr == cfg.Options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName) domains, err := cfg.Options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -635,19 +656,22 @@ func getRouteableDomainsForTLSServerName(options *config.Options, addr string, t
return allDomains.ToSlice(), nil return allDomains.ToSlice(), nil
} }
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { func getAllRouteableDomains(
cfg *config.Config,
addr string,
) ([]string, error) {
allDomains := sets.NewSorted[string]() allDomains := sets.NewSorted[string]()
if addr == options.Addr { if addr == cfg.Options.Addr {
domains, err := options.GetAllRouteableHTTPDomains() domains, err := cfg.Options.GetAllRouteableHTTPDomains()
if err != nil { if err != nil {
return nil, err return nil, err
} }
allDomains.Add(domains...) allDomains.Add(domains...)
} }
if addr == options.GetGRPCAddr() { if addr == cfg.Options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomains() domains, err := cfg.Options.GetAllRouteableGRPCDomains()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -657,8 +681,11 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
return allDomains.ToSlice(), nil return allDomains.ToSlice(), nil
} }
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { func getAllTLSDomains(
allDomains, err := getAllRouteableDomains(options, addr) cfg *config.Config,
addr string,
) ([]string, error) {
allDomains, err := getAllRouteableDomains(cfg, addr)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -711,9 +738,12 @@ func hostMatchesDomain(u *url.URL, host string) bool {
return h1 == h2 && p1 == p2 return h1 == h2 && p1 == p2
} }
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy { func getPoliciesForDomain(
cfg *config.Config,
domain string,
) []config.Policy {
var policies []config.Policy var policies []config.Policy
for _, p := range options.GetAllPolicies() { for _, p := range cfg.Options.GetAllPolicies() {
if p.Source != nil && p.Source.URL.Hostname() == domain { if p.Source != nil && p.Source.URL.Hostname() == domain {
policies = append(policies, p) policies = append(policies, p)
} }

View file

@ -26,13 +26,11 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
li, err := b.buildMetricsListener(&config.Config{ li, err := b.buildMetricsListener(config.New(&config.Options{
Options: &config.Options{ MetricsAddr: "127.0.0.1:9902",
MetricsAddr: "127.0.0.1:9902", MetricsCertificate: aExampleComCert,
MetricsCertificate: aExampleComCert, MetricsCertificateKey: aExampleComKey,
MetricsCertificateKey: aExampleComKey, }))
},
})
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -115,7 +113,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
options := config.NewDefaultOptions() options := config.NewDefaultOptions()
options.SkipXffAppend = true options.SkipXffAppend = true
options.XffNumTrustedHops = 1 options.XffNumTrustedHops = 1
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*") filter, err := b.buildMainHTTPConnectionManagerFilter(config.New(options), []string{"example.com"}, "*")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"name": "envoy.filters.network.http_connection_manager", "name": "envoy.filters.network.http_connection_manager",
@ -544,10 +542,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
t.Run("no-validation", func(t *testing.T) { t.Run("no-validation", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert, Cert: aExampleComCert,
Key: aExampleComKey, Key: aExampleComKey,
}}, "a.example.com") }), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": { "commonTlsContext": {
@ -577,11 +575,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext) }`, downstreamTLSContext)
}) })
t.Run("client-ca", func(t *testing.T) { t.Run("client-ca", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert, Cert: aExampleComCert,
Key: aExampleComKey, Key: aExampleComKey,
ClientCA: "TEST", ClientCA: "TEST",
}}, "a.example.com") }), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": { "commonTlsContext": {
@ -614,7 +612,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext) }`, downstreamTLSContext)
}) })
t.Run("policy-client-ca", func(t *testing.T) { t.Run("policy-client-ca", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert, Cert: aExampleComCert,
Key: aExampleComKey, Key: aExampleComKey,
Policies: []config.Policy{ Policies: []config.Policy{
@ -623,7 +621,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
TLSDownstreamClientCA: "TEST", TLSDownstreamClientCA: "TEST",
}, },
}, },
}}, "a.example.com") }), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": { "commonTlsContext": {
@ -656,11 +654,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext) }`, downstreamTLSContext)
}) })
t.Run("http1", func(t *testing.T) { t.Run("http1", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert, Cert: aExampleComCert,
Key: aExampleComKey, Key: aExampleComKey,
CodecType: config.CodecTypeHTTP1, CodecType: config.CodecTypeHTTP1,
}}, "a.example.com") }), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": { "commonTlsContext": {
@ -690,11 +688,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
}`, downstreamTLSContext) }`, downstreamTLSContext)
}) })
t.Run("http2", func(t *testing.T) { t.Run("http2", func(t *testing.T) {
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
Cert: aExampleComCert, Cert: aExampleComCert,
Key: aExampleComKey, Key: aExampleComKey,
CodecType: config.CodecTypeHTTP2, CodecType: config.CodecTypeHTTP2,
}}, "a.example.com") }), "a.example.com")
testutil.AssertProtoJSONEqual(t, `{ testutil.AssertProtoJSONEqual(t, `{
"commonTlsContext": { "commonTlsContext": {
@ -739,9 +737,10 @@ func Test_getAllDomains(t *testing.T) {
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}}, {Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
}, },
} }
cfg := config.New(options)
t.Run("routable", func(t *testing.T) { t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000") actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -756,7 +755,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001") actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com:9001", "authorize.example.com:9001",
@ -765,9 +764,9 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("both", func(t *testing.T) { t.Run("both", func(t *testing.T) {
newOptions := *options newCfg := cfg.Clone()
newOptions.GRPCAddr = newOptions.Addr newCfg.Options.GRPCAddr = cfg.Options.Addr
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000") actual, err := getAllRouteableDomains(newCfg, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -786,7 +785,7 @@ func Test_getAllDomains(t *testing.T) {
}) })
t.Run("tls", func(t *testing.T) { t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9000") actual, err := getAllTLSDomains(cfg, "127.0.0.1:9000")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
@ -797,7 +796,7 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual, err := getAllTLSDomains(options, "127.0.0.1:9001") actual, err := getAllTLSDomains(cfg, "127.0.0.1:9001")
require.NoError(t, err) require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com", "authorize.example.com",
@ -831,10 +830,10 @@ func Test_buildRouteConfiguration(t *testing.T) {
func Test_requireProxyProtocol(t *testing.T) { func Test_requireProxyProtocol(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", nil, nil) b := New("local-grpc", "local-http", "local-metrics", nil, nil)
t.Run("required", func(t *testing.T) { t.Run("required", func(t *testing.T) {
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{ li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
UseProxyProtocol: true, UseProxyProtocol: true,
InsecureServer: true, InsecureServer: true,
}}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
{ {
@ -846,10 +845,10 @@ func Test_requireProxyProtocol(t *testing.T) {
]`, li.GetListenerFilters()) ]`, li.GetListenerFilters())
}) })
t.Run("not required", func(t *testing.T) { t.Run("not required", func(t *testing.T) {
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{ li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
UseProxyProtocol: false, UseProxyProtocol: false,
InsecureServer: true, InsecureServer: true,
}}) }))
require.NoError(t, err) require.NoError(t, err)
assert.Len(t, li.GetListenerFilters(), 0) assert.Len(t, li.GetListenerFilters(), 0)
}) })

View file

@ -14,7 +14,9 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
) )
func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) { func (b *Builder) buildOutboundListener(
cfg *config.Config,
) (*envoy_config_listener_v3.Listener, error) {
outboundPort, err := strconv.Atoi(cfg.OutboundPort) outboundPort, err := strconv.Atoi(cfg.OutboundPort)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err) return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err)

View file

@ -47,12 +47,15 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
}}, nil }}, nil
} }
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { func (b *Builder) buildPomeriumHTTPRoutes(
cfg *config.Config,
domain string,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
// if this is the pomerium proxy in front of the the authenticate service, don't add // if this is the pomerium proxy in front of the the authenticate service, don't add
// these routes since they will be handled by authenticate // these routes since they will be handled by authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -69,27 +72,27 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false), b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
) )
// per #837, only add robots.txt if there are no unauthenticated routes // per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) { if !hasPublicPolicyMatchingURL(cfg, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false)) routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
} }
} }
// if we're handling authentication, add the oauth2 callback url // if we're handling authentication, add the oauth2 callback url
authenticateURL, err := options.GetInternalAuthenticateURL() authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { if config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
routes = append(routes, routes = append(routes,
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false), b.buildControlPlanePathRoute(cfg.Options.AuthenticateCallbackPath, false),
b.buildControlPlanePathRoute("/", false), b.buildControlPlanePathRoute("/", false),
) )
} }
// if we're the proxy and this is the forward-auth url // if we're the proxy and this is the forward-auth url
forwardAuthURL, err := options.GetForwardAuthURL() forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.IsProxy(options.Services) && hostMatchesDomain(forwardAuthURL, domain) { if config.IsProxy(cfg.Options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
// disable ext_authz and pass request to proxy handlers that enable authN flow // disable ext_authz and pass request to proxy handlers that enable authN flow
r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}) r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
if err != nil { if err != nil {
@ -227,10 +230,13 @@ func getClusterStatsName(policy *config.Policy) string {
return "" return ""
} }
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { func (b *Builder) buildPolicyRoutes(
cfg *config.Config,
domain string,
) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, p := range options.GetAllPolicies() { for i, p := range cfg.Options.GetAllPolicies() {
policy := p policy := p
if !hostMatchesDomain(policy.Source.URL, domain) { if !hostMatchesDomain(policy.Source.URL, domain) {
continue continue
@ -242,7 +248,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
Match: match, Match: match,
Metadata: &envoy_config_core_v3.Metadata{}, Metadata: &envoy_config_core_v3.Metadata{},
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders), RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
RequestHeadersToRemove: getRequestHeadersToRemove(options, &policy), RequestHeadersToRemove: getRequestHeadersToRemove(cfg, &policy),
ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders), ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders),
} }
if policy.Redirect != nil { if policy.Redirect != nil {
@ -252,7 +258,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
} }
envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action} envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action}
} else { } else {
action, err := b.buildPolicyRouteRouteAction(options, &policy) action, err := b.buildPolicyRouteRouteAction(cfg, &policy)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -264,7 +270,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
} }
// disable authentication entirely when the proxy is fronting authenticate // disable authentication entirely when the proxy is fronting authenticate
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain) isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -275,7 +281,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
} else { } else {
luaMetadata["remove_pomerium_cookie"] = &structpb.Value{ luaMetadata["remove_pomerium_cookie"] = &structpb.Value{
Kind: &structpb.Value_StringValue{ Kind: &structpb.Value_StringValue{
StringValue: options.CookieName, StringValue: cfg.Options.CookieName,
}, },
} }
luaMetadata["remove_pomerium_authorization"] = &structpb.Value{ luaMetadata["remove_pomerium_authorization"] = &structpb.Value{
@ -350,13 +356,16 @@ func (b *Builder) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*env
return action, nil return action, nil
} }
func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) { func (b *Builder) buildPolicyRouteRouteAction(
cfg *config.Config,
policy *config.Policy,
) (*envoy_config_route_v3.RouteAction, error) {
clusterName := getClusterID(policy) clusterName := getClusterID(policy)
// kubernetes requests are sent to the http control plane to be reproxied // kubernetes requests are sent to the http control plane to be reproxied
if policy.IsForKubernetes() { if policy.IsForKubernetes() {
clusterName = httpCluster clusterName = httpCluster
} }
routeTimeout := getRouteTimeout(options, policy) routeTimeout := getRouteTimeout(cfg, policy)
idleTimeout := getRouteIdleTimeout(policy) idleTimeout := getRouteIdleTimeout(policy)
prefixRewrite, regexRewrite := getRewriteOptions(policy) prefixRewrite, regexRewrite := getRewriteOptions(policy)
upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{ upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{
@ -464,13 +473,16 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
return match return match
} }
func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string { func getRequestHeadersToRemove(
cfg *config.Config,
policy *config.Policy,
) []string {
requestHeadersToRemove := policy.RemoveRequestHeaders requestHeadersToRemove := policy.RemoveRequestHeaders
if !policy.PassIdentityHeaders { if !policy.PassIdentityHeaders {
requestHeadersToRemove = append(requestHeadersToRemove, requestHeadersToRemove = append(requestHeadersToRemove,
httputil.HeaderPomeriumJWTAssertion, httputil.HeaderPomeriumJWTAssertion,
httputil.HeaderPomeriumJWTAssertionFor) httputil.HeaderPomeriumJWTAssertionFor)
for headerName := range options.JWTClaimsHeaders { for headerName := range cfg.Options.JWTClaimsHeaders {
requestHeadersToRemove = append(requestHeadersToRemove, headerName) requestHeadersToRemove = append(requestHeadersToRemove, headerName)
} }
} }
@ -482,7 +494,10 @@ func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) [
return requestHeadersToRemove return requestHeadersToRemove
} }
func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration { func getRouteTimeout(
cfg *config.Config,
policy *config.Policy,
) *durationpb.Duration {
var routeTimeout *durationpb.Duration var routeTimeout *durationpb.Duration
if policy.UpstreamTimeout != nil { if policy.UpstreamTimeout != nil {
routeTimeout = durationpb.New(*policy.UpstreamTimeout) routeTimeout = durationpb.New(*policy.UpstreamTimeout)
@ -490,7 +505,7 @@ func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb
// a non-zero value would conflict with idleTimeout and/or websocket / tcp calls // a non-zero value would conflict with idleTimeout and/or websocket / tcp calls
routeTimeout = durationpb.New(0) routeTimeout = durationpb.New(0)
} else { } else {
routeTimeout = durationpb.New(options.DefaultUpstreamTimeout) routeTimeout = durationpb.New(cfg.Options.DefaultUpstreamTimeout)
} }
return routeTimeout return routeTimeout
} }
@ -564,8 +579,11 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3.
} }
} }
func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) bool { func hasPublicPolicyMatchingURL(
for _, policy := range options.GetAllPolicies() { cfg *config.Config,
requestURL url.URL,
) bool {
for _, policy := range cfg.Options.GetAllPolicies() {
if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) { if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) {
return true return true
} }
@ -573,13 +591,16 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
return false return false
} }
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) { func isProxyFrontingAuthenticate(
authenticateURL, err := options.GetAuthenticateURL() cfg *config.Config,
domain string,
) (bool, error) {
authenticateURL, err := cfg.Options.GetAuthenticateURL()
if err != nil { if err != nil {
return false, err return false, err
} }
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { if !config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
return true, nil return true, nil
} }

View file

@ -84,7 +84,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
ForwardAuthURLString: "https://forward-auth.example.com", ForwardAuthURLString: "https://forward-auth.example.com",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") cfg := config.New(options)
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -106,7 +107,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com") cfg := config.New(options)
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, "null", routes) testutil.AssertProtoJSONEqual(t, "null", routes)
}) })
@ -122,8 +124,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
To: mustParseWeightedURLs(t, "https://to.example.com"), To: mustParseWeightedURLs(t, "https://to.example.com"),
}}, }},
} }
cfg := config.New(options)
_ = options.Policies[0].Validate() _ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com") routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -150,8 +153,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AllowPublicUnauthenticatedAccess: true, AllowPublicUnauthenticatedAccess: true,
}}, }},
} }
cfg := config.New(options)
_ = options.Policies[0].Validate() _ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com") routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -242,7 +246,7 @@ func TestTimeouts(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
b := &Builder{filemgr: filemgr.NewManager()} b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium", CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3, DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{ Policies: []config.Policy{
@ -253,7 +257,7 @@ func TestTimeouts(t *testing.T) {
IdleTimeout: getDuration(tc.idle), IdleTimeout: getDuration(tc.idle),
AllowWebsockets: tc.allowWebsockets, AllowWebsockets: tc.allowWebsockets,
}}, }},
}, "example.com") }), "example.com")
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) { if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
continue continue
} }
@ -295,7 +299,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
ten := time.Second * 10 ten := time.Second * 10
b := &Builder{filemgr: filemgr.NewManager()} b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium", CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3, DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{ Policies: []config.Policy{
@ -357,7 +361,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten, UpstreamTimeout: &ten,
}, },
}, },
}, "example.com") }), "example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -724,7 +728,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
`, routes) `, routes)
t.Run("fronting-authenticate", func(t *testing.T) { t.Run("fronting-authenticate", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
Services: "proxy", Services: "proxy",
CookieName: "pomerium", CookieName: "pomerium",
@ -735,7 +739,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
PassIdentityHeaders: true, PassIdentityHeaders: true,
}, },
}, },
}, "authenticate.example.com") }), "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -791,7 +795,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
}) })
t.Run("tcp", func(t *testing.T) { t.Run("tcp", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium", CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3, DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{ Policies: []config.Policy{
@ -805,7 +809,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten, UpstreamTimeout: &ten,
}, },
}, },
}, "example.com:22") }), "example.com:22")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -905,7 +909,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
}) })
t.Run("remove-pomerium-headers", func(t *testing.T) { t.Run("remove-pomerium-headers", func(t *testing.T) {
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
Services: "proxy", Services: "proxy",
CookieName: "pomerium", CookieName: "pomerium",
@ -918,7 +922,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")}, Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")},
}, },
}, },
}, "from.example.com") }), "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -980,7 +984,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
}(getClusterID) }(getClusterID)
getClusterID = policyNameFunc() getClusterID = policyNameFunc()
b := &Builder{filemgr: filemgr.NewManager()} b := &Builder{filemgr: filemgr.NewManager()}
routes, err := b.buildPolicyRoutes(&config.Options{ routes, err := b.buildPolicyRoutes(config.New(&config.Options{
CookieName: "pomerium", CookieName: "pomerium",
DefaultUpstreamTimeout: time.Second * 3, DefaultUpstreamTimeout: time.Second * 3,
Policies: []config.Policy{ Policies: []config.Policy{
@ -1022,7 +1026,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
HostPathRegexRewriteSubstitution: "\\1", HostPathRegexRewriteSubstitution: "\\1",
}, },
}, },
}, "example.com") }), "example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `

View file

@ -14,8 +14,10 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil" "github.com/pomerium/pomerium/pkg/protoutil"
) )
func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Cluster, error) { func buildTracingCluster(
tracingOptions, err := config.NewTracingOptions(options) cfg *config.Config,
) (*envoy_config_cluster_v3.Cluster, error) {
tracingOptions, err := config.NewTracingOptions(cfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err) return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err)
} }
@ -24,8 +26,8 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
case trace.DatadogTracingProviderName: case trace.DatadogTracingProviderName:
addr, _ := parseAddress("127.0.0.1:8126") addr, _ := parseAddress("127.0.0.1:8126")
if options.TracingDatadogAddress != "" { if cfg.Options.TracingDatadogAddress != "" {
addr, err = parseAddress(options.TracingDatadogAddress) addr, err = parseAddress(cfg.Options.TracingDatadogAddress)
if err != nil { if err != nil {
return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err) return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err)
} }
@ -94,8 +96,10 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
} }
} }
func buildTracingHTTP(options *config.Options) (*envoy_config_trace_v3.Tracing_Http, error) { func buildTracingHTTP(
tracingOptions, err := config.NewTracingOptions(options) cfg *config.Config,
) (*envoy_config_trace_v3.Tracing_Http, error) {
tracingOptions, err := config.NewTracingOptions(cfg)
if err != nil { if err != nil {
return nil, fmt.Errorf("invalid tracing config: %w", err) return nil, fmt.Errorf("invalid tracing config: %w", err)
} }

View file

@ -11,9 +11,9 @@ import (
func TestBuildTracingCluster(t *testing.T) { func TestBuildTracingCluster(t *testing.T) {
t.Run("datadog", func(t *testing.T) { t.Run("datadog", func(t *testing.T) {
c, err := buildTracingCluster(&config.Options{ c, err := buildTracingCluster(config.New(&config.Options{
TracingProvider: "datadog", TracingProvider: "datadog",
}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -38,10 +38,10 @@ func TestBuildTracingCluster(t *testing.T) {
} }
`, c) `, c)
c, err = buildTracingCluster(&config.Options{ c, err = buildTracingCluster(config.New(&config.Options{
TracingProvider: "datadog", TracingProvider: "datadog",
TracingDatadogAddress: "example.com:8126", TracingDatadogAddress: "example.com:8126",
}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -67,10 +67,10 @@ func TestBuildTracingCluster(t *testing.T) {
`, c) `, c)
}) })
t.Run("zipkin", func(t *testing.T) { t.Run("zipkin", func(t *testing.T) {
c, err := buildTracingCluster(&config.Options{ c, err := buildTracingCluster(config.New(&config.Options{
TracingProvider: "zipkin", TracingProvider: "zipkin",
ZipkinEndpoint: "https://example.com/api/v2/spans", ZipkinEndpoint: "https://example.com/api/v2/spans",
}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -99,9 +99,9 @@ func TestBuildTracingCluster(t *testing.T) {
func TestBuildTracingHTTP(t *testing.T) { func TestBuildTracingHTTP(t *testing.T) {
t.Run("datadog", func(t *testing.T) { t.Run("datadog", func(t *testing.T) {
h, err := buildTracingHTTP(&config.Options{ h, err := buildTracingHTTP(config.New(&config.Options{
TracingProvider: "datadog", TracingProvider: "datadog",
}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
@ -115,10 +115,10 @@ func TestBuildTracingHTTP(t *testing.T) {
`, h) `, h)
}) })
t.Run("zipkin", func(t *testing.T) { t.Run("zipkin", func(t *testing.T) {
h, err := buildTracingHTTP(&config.Options{ h, err := buildTracingHTTP(config.New(&config.Options{
TracingProvider: "zipkin", TracingProvider: "zipkin",
ZipkinEndpoint: "https://example.com/api/v2/spans", ZipkinEndpoint: "https://example.com/api/v2/spans",
}) }))
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {

View file

@ -18,28 +18,28 @@ import (
type TracingOptions = trace.TracingOptions type TracingOptions = trace.TracingOptions
// NewTracingOptions builds a new TracingOptions from core Options // NewTracingOptions builds a new TracingOptions from core Options
func NewTracingOptions(o *Options) (*TracingOptions, error) { func NewTracingOptions(cfg *Config) (*TracingOptions, error) {
tracingOpts := TracingOptions{ tracingOpts := TracingOptions{
Provider: o.TracingProvider, Provider: cfg.Options.TracingProvider,
Service: telemetry.ServiceName(o.Services), Service: telemetry.ServiceName(cfg.Options.Services),
JaegerAgentEndpoint: o.TracingJaegerAgentEndpoint, JaegerAgentEndpoint: cfg.Options.TracingJaegerAgentEndpoint,
SampleRate: o.TracingSampleRate, SampleRate: cfg.Options.TracingSampleRate,
} }
switch o.TracingProvider { switch cfg.Options.TracingProvider {
case trace.DatadogTracingProviderName: case trace.DatadogTracingProviderName:
tracingOpts.DatadogAddress = o.TracingDatadogAddress tracingOpts.DatadogAddress = cfg.Options.TracingDatadogAddress
case trace.JaegerTracingProviderName: case trace.JaegerTracingProviderName:
if o.TracingJaegerCollectorEndpoint != "" { if cfg.Options.TracingJaegerCollectorEndpoint != "" {
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(o.TracingJaegerCollectorEndpoint) jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.TracingJaegerCollectorEndpoint)
if err != nil { if err != nil {
return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err) return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err)
} }
tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint
tracingOpts.JaegerAgentEndpoint = o.TracingJaegerAgentEndpoint tracingOpts.JaegerAgentEndpoint = cfg.Options.TracingJaegerAgentEndpoint
} }
case trace.ZipkinTracingProviderName: case trace.ZipkinTracingProviderName:
zipkinEndpoint, err := urlutil.ParseAndValidateURL(o.ZipkinEndpoint) zipkinEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.ZipkinEndpoint)
if err != nil { if err != nil {
return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err) return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err)
} }
@ -47,7 +47,7 @@ func NewTracingOptions(o *Options) (*TracingOptions, error) {
case "": case "":
return &TracingOptions{}, nil return &TracingOptions{}, nil
default: default:
return nil, fmt.Errorf("config: provider %s unknown", o.TracingProvider) return nil, fmt.Errorf("config: provider %s unknown", cfg.Options.TracingProvider)
} }
return &tracingOpts, nil return &tracingOpts, nil
@ -88,7 +88,7 @@ func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
mgr.mu.Lock() mgr.mu.Lock()
defer mgr.mu.Unlock() defer mgr.mu.Unlock()
traceOpts, err := NewTracingOptions(cfg.Options) traceOpts, err := NewTracingOptions(cfg)
if err != nil { if err != nil {
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options") log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
return return

View file

@ -68,7 +68,7 @@ func Test_NewTracingOptions(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := NewTracingOptions(tt.opts) got, err := NewTracingOptions(&Config{Options: tt.opts})
assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value") assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value")
assert.Empty(t, cmp.Diff(tt.want, got)) assert.Empty(t, cmp.Diff(tt.want, got))
}) })

View file

@ -28,7 +28,7 @@ func TestNew(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
tt.opts.Provider = "google" tt.opts.Provider = "google"
_, err := New(&config.Config{Options: &tt.opts}, events.New()) _, err := New(config.New(&tt.opts), events.New())
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return

View file

@ -236,8 +236,8 @@ func TestConfig(t *testing.T) {
} }
_ = p1.Validate() _ = p1.Validate()
mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{ mgr, err := newManager(ctx,
Options: &config.Options{ config.NewStaticSource(config.New(&config.Options{
AutocertOptions: config.AutocertOptions{ AutocertOptions: config.AutocertOptions{
Enable: true, Enable: true,
UseStaging: true, UseStaging: true,
@ -247,11 +247,11 @@ func TestConfig(t *testing.T) {
}, },
HTTPRedirectAddr: addr, HTTPRedirectAddr: addr,
Policies: []config.Policy{p1}, Policies: []config.Policy{p1},
}, })),
}), certmagic.ACMEIssuer{ certmagic.ACMEIssuer{
CA: srv.URL + "/acme/directory", CA: srv.URL + "/acme/directory",
TestCA: srv.URL + "/acme/directory", TestCA: srv.URL + "/acme/directory",
}, time.Millisecond*100) }, time.Millisecond*100)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return
} }
@ -304,16 +304,14 @@ func TestRedirect(t *testing.T) {
addr := li.Addr().String() addr := li.Addr().String()
_ = li.Close() _ = li.Close()
src := config.NewStaticSource(&config.Config{ src := config.NewStaticSource(config.New(&config.Options{
Options: &config.Options{ HTTPRedirectAddr: addr,
HTTPRedirectAddr: addr, SetResponseHeaders: map[string]string{
SetResponseHeaders: map[string]string{ "X-Frame-Options": "SAMEORIGIN",
"X-Frame-Options": "SAMEORIGIN", "X-XSS-Protection": "1; mode=block",
"X-XSS-Protection": "1; mode=block", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
},
}, },
}) }))
_, err = New(src) _, err = New(src)
if !assert.NoError(t, err) { if !assert.NoError(t, err) {
return return

View file

@ -28,7 +28,7 @@ import (
type Handler struct { type Handler struct {
mu sync.RWMutex mu sync.RWMutex
key []byte key []byte
options *config.Options cfg *config.Config
policies map[uint64]config.Policy policies map[uint64]config.Policy
} }
@ -83,7 +83,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
} }
h.mu.RLock() h.mu.RLock()
options := h.options options := h.cfg.Options
policy, ok := h.policies[policyID] policy, ok := h.policies[policyID]
h.mu.RUnlock() h.mu.RUnlock()
@ -132,7 +132,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
defer h.mu.Unlock() defer h.mu.Unlock()
h.key, _ = cfg.Options.GetSharedKey() h.key, _ = cfg.Options.GetSharedKey()
h.options = cfg.Options h.cfg = cfg
h.policies = make(map[uint64]config.Policy) h.policies = make(map[uint64]config.Policy)
for _, p := range cfg.Options.GetAllPolicies() { for _, p := range cfg.Options.GetAllPolicies() {
id, err := p.RouteID() id, err := p.RouteID()

View file

@ -49,15 +49,13 @@ func TestMiddleware(t *testing.T) {
srv2 := httptest.NewServer(h.Middleware(next)) srv2 := httptest.NewServer(h.Middleware(next))
defer srv2.Close() defer srv2.Close()
cfg := &config.Config{ cfg := config.New(&config.Options{
Options: &config.Options{ SharedKey: cryptutil.NewBase64Key(),
SharedKey: cryptutil.NewBase64Key(), Policies: []config.Policy{{
Policies: []config.Policy{{ To: config.WeightedURLs{{URL: *u}},
To: config.WeightedURLs{{URL: *u}}, KubernetesServiceAccountToken: "ABCD",
KubernetesServiceAccountToken: "ABCD", }},
}}, })
},
}
h.Update(context.Background(), cfg) h.Update(context.Background(), cfg)
policyID, _ := cfg.Options.Policies[0].RouteID() policyID, _ := cfg.Options.Policies[0].RouteID()

View file

@ -80,11 +80,11 @@ func TestProxy_ForwardAuth(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options}) p, err := New(config.New(tt.options))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(ctx, &config.Config{Options: tt.options}) p.OnConfigChange(ctx, config.New(tt.options))
state := p.state.Load() state := p.state.Load()
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore
signer, err := jws.NewHS256Signer(nil) signer, err := jws.NewHS256Signer(nil)

View file

@ -58,7 +58,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
state := p.state.Load() state := p.state.Load()
var redirectURL *url.URL var redirectURL *url.URL
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL() signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, err) return httputil.NewError(http.StatusInternalServerError, err)
} }

View file

@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(config.New(opts))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -70,7 +70,7 @@ func TestProxy_userInfo(t *testing.T) {
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
proxy, err := New(&config.Config{Options: opts}) proxy, err := New(config.New(opts))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -102,7 +102,7 @@ func TestProxy_SignOut(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
opts := testOptions(t) opts := testOptions(t)
p, err := New(&config.Config{Options: opts}) p, err := New(config.New(opts))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -236,11 +236,11 @@ func TestProxy_Callback(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options}) p, err := New(config.New(tt.options))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) p.OnConfigChange(context.Background(), config.New(tt.options))
state := p.state.Load() state := p.state.Load()
state.encoder = tt.cipher state.encoder = tt.cipher
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore
@ -350,7 +350,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options}) p, err := New(config.New(tt.options))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -482,11 +482,11 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.options}) p, err := New(config.New(tt.options))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options}) p.OnConfigChange(context.Background(), config.New(tt.options))
state := p.state.Load() state := p.state.Load()
state.encoder = tt.cipher state.encoder = tt.cipher
state.sessionStore = tt.sessionStore state.sessionStore = tt.sessionStore

View file

@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error {
// Proxy stores all the information associated with proxying a request. // Proxy stores all the information associated with proxying a request.
type Proxy struct { type Proxy struct {
state *atomicutil.Value[*proxyState] state *atomicutil.Value[*proxyState]
currentOptions *atomicutil.Value[*config.Options] currentConfig *atomicutil.Value[*config.Config]
currentRouter *atomicutil.Value[*mux.Router] currentRouter *atomicutil.Value[*mux.Router]
} }
// New takes a Proxy service from options and a validation function. // New takes a Proxy service from options and a validation function.
@ -65,13 +65,13 @@ func New(cfg *config.Config) (*Proxy, error) {
} }
p := &Proxy{ p := &Proxy{
state: atomicutil.NewValue(state), state: atomicutil.NewValue(state),
currentOptions: config.NewAtomicOptions(), currentConfig: atomicutil.NewValue(cfg),
currentRouter: atomicutil.NewValue(httputil.NewRouter()), currentRouter: atomicutil.NewValue(httputil.NewRouter()),
} }
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 { metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
return int64(len(p.currentOptions.Load().GetAllPolicies())) return int64(len(p.currentConfig.Load().Options.GetAllPolicies()))
}) })
return p, nil return p, nil
@ -88,8 +88,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
return return
} }
p.currentOptions.Store(cfg.Options) p.currentConfig.Store(cfg)
if err := p.setHandlers(cfg.Options); err != nil { if err := p.setHandlers(cfg); err != nil {
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings") log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
} }
if state, err := newProxyStateFromConfig(cfg); err != nil { if state, err := newProxyStateFromConfig(cfg); err != nil {
@ -99,8 +99,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
} }
} }
func (p *Proxy) setHandlers(opts *config.Options) error { func (p *Proxy) setHandlers(cfg *config.Config) error {
if len(opts.GetAllPolicies()) == 0 { if len(cfg.Options.GetAllPolicies()) == 0 {
log.Warn(context.TODO()).Msg("proxy: configuration has no policies") log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
} }
r := httputil.NewRouter() r := httputil.NewRouter()
@ -113,7 +113,7 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
// dashboard handlers are registered to all routes // dashboard handlers are registered to all routes
r = p.registerDashboardHandlers(r) r = p.registerDashboardHandlers(r)
forwardAuthURL, err := opts.GetForwardAuthURL() forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
if err != nil { if err != nil {
return err return err
} }

View file

@ -99,7 +99,7 @@ func TestNew(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := New(&config.Config{Options: tt.opts}) got, err := New(config.New(tt.opts))
if (err != nil) != tt.wantErr { if (err != nil) != tt.wantErr {
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
return return
@ -194,12 +194,12 @@ func Test_UpdateOptions(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
p, err := New(&config.Config{Options: tt.originalOptions}) p, err := New(config.New(tt.originalOptions))
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions}) p.OnConfigChange(context.Background(), config.New(tt.updatedOptions))
r := httptest.NewRequest("GET", tt.host, nil) r := httptest.NewRequest("GET", tt.host, nil)
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.ServeHTTP(w, r) p.ServeHTTP(w, r)
@ -212,5 +212,5 @@ func Test_UpdateOptions(t *testing.T) {
// Test nil // Test nil
var p *Proxy var p *Proxy
p.OnConfigChange(context.Background(), &config.Config{}) p.OnConfigChange(context.Background(), config.New(nil))
} }