mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-23 21:18:09 +02:00
config: use config.Config instead of config.Options everywhere
This commit is contained in:
parent
5f51510e91
commit
1b80e8a6c2
40 changed files with 484 additions and 412 deletions
|
@ -40,7 +40,7 @@ func ValidateOptions(o *config.Options) error {
|
|||
// Authenticate contains data required to run the authenticate service.
|
||||
type Authenticate struct {
|
||||
cfg *authenticateConfig
|
||||
options *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
state *atomicutil.Value[*authenticateState]
|
||||
webauthn *webauthn.Handler
|
||||
}
|
||||
|
@ -49,7 +49,7 @@ type Authenticate struct {
|
|||
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(options...),
|
||||
options: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(cfg),
|
||||
state: atomicutil.NewValue(newAuthenticateState()),
|
||||
}
|
||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||
|
@ -69,7 +69,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
return
|
||||
}
|
||||
|
||||
a.options.Store(cfg.Options)
|
||||
a.currentConfig.Store(cfg)
|
||||
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
||||
} else {
|
||||
|
|
|
@ -113,7 +113,7 @@ func TestNew(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
_, err := New(&config.Config{Options: tt.opts})
|
||||
_, err := New(config.New(tt.opts))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
type authenticateConfig struct {
|
||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
||||
getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)
|
||||
}
|
||||
|
||||
// An Option customizes the Authenticate config.
|
||||
|
@ -22,7 +22,7 @@ func getAuthenticateConfig(options ...Option) *authenticateConfig {
|
|||
}
|
||||
|
||||
// WithGetIdentityProvider sets the getIdentityProvider function in the config.
|
||||
func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option {
|
||||
func WithGetIdentityProvider(getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)) Option {
|
||||
return func(cfg *authenticateConfig) {
|
||||
cfg.getIdentityProvider = getIdentityProvider
|
||||
}
|
||||
|
|
|
@ -47,12 +47,12 @@ func (a *Authenticate) Mount(r *mux.Router) {
|
|||
r.StrictSlash(true)
|
||||
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||
r.Use(func(h http.Handler) http.Handler {
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
state := a.state.Load()
|
||||
csrfKey := fmt.Sprintf("%s_csrf", options.CookieName)
|
||||
csrfKey := fmt.Sprintf("%s_csrf", cfg.Options.CookieName)
|
||||
return csrf.Protect(
|
||||
state.cookieSecret,
|
||||
csrf.Secure(options.CookieSecure),
|
||||
csrf.Secure(cfg.Options.CookieSecure),
|
||||
csrf.Path("/"),
|
||||
csrf.UnsafePaths(
|
||||
[]string{
|
||||
|
@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
||||
if err != nil {
|
||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
||||
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -275,9 +275,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||
defer span.End()
|
||||
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -285,7 +285,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
|||
rawIDToken := a.revokeSession(ctx, w, r)
|
||||
|
||||
redirectString := ""
|
||||
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
|
||||
signOutURL, err := cfg.Options.GetSignOutRedirectURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -330,10 +330,10 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
|||
return httputil.NewError(http.StatusUnauthorized, err)
|
||||
}
|
||||
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
state := a.state.Load()
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -381,7 +381,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||
defer span.End()
|
||||
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
state := a.state.Load()
|
||||
|
||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||
|
@ -430,7 +430,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
|||
}
|
||||
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, idpID)
|
||||
idp, err := a.cfg.getIdentityProvider(cfg, idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -513,7 +513,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
|
|||
func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
|
||||
state := a.state.Load()
|
||||
|
||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
||||
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return handlers.UserInfoData{}, err
|
||||
}
|
||||
|
@ -569,7 +569,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
|
|||
WebAuthnRequestOptions: requestOptions,
|
||||
WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()),
|
||||
|
||||
BrandingOptions: a.options.Load().BrandingOptions,
|
||||
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
@ -581,14 +581,14 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
accessToken *oauth2.Token,
|
||||
) error {
|
||||
state := a.state.Load()
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire))
|
||||
sessionExpiry := timestamppb.New(time.Now().Add(cfg.Options.CookieExpire))
|
||||
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
|
||||
|
||||
s := &session.Session{
|
||||
|
@ -648,13 +648,13 @@ func (a *Authenticate) saveSessionToDataBroker(
|
|||
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
||||
// and empty string.
|
||||
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
state := a.state.Load()
|
||||
|
||||
// clear the user's local session no matter what
|
||||
defer state.sessionStore.ClearSession(w, r)
|
||||
|
||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
|
@ -708,6 +708,7 @@ func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*di
|
|||
|
||||
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
|
||||
state := a.state.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
s, _, err := a.getCurrentSession(ctx)
|
||||
if err != nil {
|
||||
|
@ -719,17 +720,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
|||
return nil, err
|
||||
}
|
||||
|
||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
||||
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL()
|
||||
internalAuthenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
|
||||
pomeriumDomains, err := cfg.Options.GetAllRouteableHTTPDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -744,7 +745,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
|||
SessionState: ss,
|
||||
SessionStore: state.sessionStore,
|
||||
RelyingParty: state.webauthnRelyingParty,
|
||||
BrandingOptions: a.options.Load().BrandingOptions,
|
||||
BrandingOptions: cfg.Options.BrandingOptions,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -49,10 +49,9 @@ func testAuthenticate() *Authenticate {
|
|||
redirectURL: redirectURL,
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
})
|
||||
auth.options = config.NewAtomicOptions()
|
||||
auth.options.Store(&config.Options{
|
||||
auth.currentConfig = atomicutil.NewValue(config.New(&config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
})
|
||||
}))
|
||||
return &auth
|
||||
}
|
||||
|
||||
|
@ -148,7 +147,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
|
@ -168,10 +167,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
|||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
|
||||
options: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||
SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
|
||||
})),
|
||||
}
|
||||
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
|
||||
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
||||
|
||||
queryString := uri.Query()
|
||||
|
@ -304,7 +303,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
ctrl := gomock.NewController(t)
|
||||
defer ctrl.Finish()
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
|
@ -325,12 +324,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
|||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
}
|
||||
if tt.signoutRedirectURL != "" {
|
||||
opts := a.options.Load()
|
||||
opts.SignOutRedirectURLString = tt.signoutRedirectURL
|
||||
a.options.Store(opts)
|
||||
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||
SignOutRedirectURLString: tt.signoutRedirectURL,
|
||||
})),
|
||||
}
|
||||
u, _ := url.Parse("/sign_out")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
|
@ -417,7 +413,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
}
|
||||
authURL, _ := url.Parse(tt.authenticateURL)
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
|
@ -435,7 +431,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
|||
cookieCipher: aead,
|
||||
encryptedEncoder: signer,
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(config.New(nil)),
|
||||
}
|
||||
u, _ := url.Parse("/oauthGet")
|
||||
params, _ := url.ParseQuery(u.RawQuery)
|
||||
|
@ -552,7 +548,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
a := &Authenticate{
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||
return tt.provider, nil
|
||||
})),
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
|
@ -573,7 +569,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
options: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(config.New(nil)),
|
||||
}
|
||||
r := httptest.NewRequest("GET", "/", nil)
|
||||
state, err := tt.session.LoadSession(r)
|
||||
|
@ -604,7 +600,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
|||
func TestJwksEndpoint(t *testing.T) {
|
||||
o := newTestOptions(t)
|
||||
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||
auth, err := New(&config.Config{Options: o})
|
||||
auth, err := New(config.New(o))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
return
|
||||
|
@ -632,12 +628,11 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
a.state = atomicutil.NewValue(&authenticateState{
|
||||
cookieSecret: cryptutil.NewKey(),
|
||||
})
|
||||
a.options = config.NewAtomicOptions()
|
||||
a.options.Store(&config.Options{
|
||||
a.currentConfig = atomicutil.NewValue(config.New(&config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthenticateInternalURLString: "https://authenticate.service.cluster.local",
|
||||
})
|
||||
}))
|
||||
err := a.userInfo(w, r)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, http.StatusFound, w.Code)
|
||||
|
@ -687,13 +682,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
o := config.NewAtomicOptions()
|
||||
o.Store(&config.Options{
|
||||
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
||||
SharedKey: "SHARED KEY",
|
||||
})
|
||||
a := &Authenticate{
|
||||
options: o,
|
||||
state: atomicutil.NewValue(&authenticateState{
|
||||
sessionStore: tt.sessionStore,
|
||||
encryptedEncoder: signer,
|
||||
|
@ -711,6 +700,10 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
|||
},
|
||||
directoryClient: new(mockDirectoryServiceClient),
|
||||
}),
|
||||
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
||||
SharedKey: "SHARED KEY",
|
||||
})),
|
||||
}
|
||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
|
||||
|
|
|
@ -7,8 +7,8 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
func defaultGetIdentityProvider(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -17,9 +17,9 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
redirectURL.Path = options.AuthenticateCallbackPath
|
||||
redirectURL.Path = cfg.Options.AuthenticateCallbackPath
|
||||
|
||||
idp, err := options.GetIdentityProviderForID(idpID)
|
||||
idp, err := cfg.Options.GetIdentityProviderForID(idpID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -34,14 +34,14 @@ func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Han
|
|||
}
|
||||
|
||||
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
|
||||
options := a.options.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
externalURL, err := options.GetAuthenticateURL()
|
||||
externalURL, err := cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
||||
internalURL, err := options.GetInternalAuthenticateURL()
|
||||
internalURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return r
|
||||
}
|
||||
|
|
|
@ -27,7 +27,7 @@ import (
|
|||
type Authorize struct {
|
||||
state *atomicutil.Value[*authorizeState]
|
||||
store *store.Store
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
accessTracker *AccessTracker
|
||||
globalCache storage.Cache
|
||||
|
||||
|
@ -40,7 +40,7 @@ type Authorize struct {
|
|||
// New validates and creates a new Authorize service from a set of config options.
|
||||
func New(cfg *config.Config) (*Authorize, error) {
|
||||
a := &Authorize{
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(cfg),
|
||||
store: store.New(),
|
||||
globalCache: storage.NewGlobalCache(time.Minute),
|
||||
}
|
||||
|
@ -86,42 +86,42 @@ func validateOptions(o *config.Options) error {
|
|||
}
|
||||
|
||||
// newPolicyEvaluator returns an policy evaluator.
|
||||
func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) {
|
||||
func newPolicyEvaluator(cfg *config.Config, store *store.Store) (*evaluator.Evaluator, error) {
|
||||
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
||||
return int64(len(opts.GetAllPolicies()))
|
||||
return int64(len(cfg.Options.GetAllPolicies()))
|
||||
})
|
||||
ctx := context.Background()
|
||||
_, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
|
||||
defer span.End()
|
||||
|
||||
clientCA, err := opts.GetClientCA()
|
||||
clientCA, err := cfg.Options.GetClientCA()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
|
||||
}
|
||||
|
||||
authenticateURL, err := opts.GetInternalAuthenticateURL()
|
||||
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err)
|
||||
}
|
||||
|
||||
signingKey, err := opts.GetSigningKey()
|
||||
signingKey, err := cfg.Options.GetSigningKey()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: invalid signing key: %w", err)
|
||||
}
|
||||
|
||||
return evaluator.New(ctx, store,
|
||||
evaluator.WithPolicies(opts.GetAllPolicies()),
|
||||
evaluator.WithPolicies(cfg.Options.GetAllPolicies()),
|
||||
evaluator.WithClientCA(clientCA),
|
||||
evaluator.WithSigningKey(signingKey),
|
||||
evaluator.WithAuthenticateURL(authenticateURL.String()),
|
||||
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()),
|
||||
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
|
||||
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(cfg.Options.GetGoogleCloudServerlessAuthenticationServiceAccount()),
|
||||
evaluator.WithJWTClaimsHeaders(cfg.Options.JWTClaimsHeaders),
|
||||
)
|
||||
}
|
||||
|
||||
// OnConfigChange updates internal structures based on config.Options
|
||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||
a.currentOptions.Store(cfg.Options)
|
||||
a.currentConfig.Store(cfg)
|
||||
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("authorize: error updating state")
|
||||
} else {
|
||||
|
|
|
@ -74,7 +74,7 @@ func TestNew(t *testing.T) {
|
|||
tt := tt
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
_, err := New(&config.Config{Options: &tt.config})
|
||||
_, err := New(config.New(&tt.config))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -105,12 +105,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
|||
SharedKey: tc.SharedKey,
|
||||
Policies: tc.Policies,
|
||||
}
|
||||
a, err := New(&config.Config{Options: o})
|
||||
a, err := New(config.New(o))
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, a)
|
||||
|
||||
oldPe := a.state.Load().evaluator
|
||||
cfg := &config.Config{Options: o}
|
||||
cfg := config.New(o)
|
||||
assertFunc := assert.True
|
||||
o.SigningKey = "bad-share-key"
|
||||
if tc.expectedChange {
|
||||
|
|
|
@ -125,7 +125,7 @@ func (a *Authorize) deniedResponse(
|
|||
respBody := []byte(reason)
|
||||
respHeader := []*envoy_config_core_v3.HeaderValueOption{}
|
||||
|
||||
forwardAuthURL, _ := a.currentOptions.Load().GetForwardAuthURL()
|
||||
forwardAuthURL, _ := a.currentConfig.Load().Options.GetForwardAuthURL()
|
||||
if forwardAuthURL == nil {
|
||||
// create a http response writer recorder
|
||||
w := httptest.NewRecorder()
|
||||
|
@ -140,7 +140,7 @@ func (a *Authorize) deniedResponse(
|
|||
Err: errors.New(reason),
|
||||
DebugURL: debugEndpoint,
|
||||
RequestID: requestid.FromContext(ctx),
|
||||
BrandingOptions: a.currentOptions.Load().BrandingOptions,
|
||||
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||
}
|
||||
httpErr.ErrorResponse(ctx, w, r)
|
||||
|
||||
|
@ -184,7 +184,7 @@ func (a *Authorize) requireLoginResponse(
|
|||
request *evaluator.Request,
|
||||
isForwardAuthVerify bool,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
authenticateURL, err := opts.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
|
@ -225,7 +225,7 @@ func (a *Authorize) requireWebAuthnResponse(
|
|||
result *evaluator.Result,
|
||||
isForwardAuthVerify bool,
|
||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
state := a.state.Load()
|
||||
authenticateURL, err := opts.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
|
@ -295,7 +295,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
|
|||
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
||||
// session that lives on the authenticate service.
|
||||
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
||||
opts := a.currentOptions.Load()
|
||||
opts := a.currentConfig.Load().Options
|
||||
authenticateURL, err := opts.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -29,7 +29,7 @@ func TestAuthorize_handleResult(t *testing.T) {
|
|||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||
a, err := New(&config.Config{Options: opt})
|
||||
a, err := New(config.New(opt))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("user-unauthenticated", func(t *testing.T) {
|
||||
|
@ -67,12 +67,12 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}},
|
||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||
}
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
cfg := config.New(opt)
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(cfg), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(opt)
|
||||
a.store = store.New()
|
||||
pe, err := newPolicyEvaluator(opt, a.store)
|
||||
pe, err := newPolicyEvaluator(cfg, a.store)
|
||||
require.NoError(t, err)
|
||||
a.state.Load().evaluator = pe
|
||||
|
||||
|
@ -123,17 +123,17 @@ func TestAuthorize_okResponse(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a.currentConfig.Store(config.New(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) {
|
|||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||
a, err := New(&config.Config{Options: opt})
|
||||
a, err := New(config.New(opt))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("accept empty", func(t *testing.T) {
|
||||
|
|
|
@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
}
|
||||
}
|
||||
|
||||
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
|
||||
rawJWT, _ := loadRawSession(hreq, a.currentConfig.Load(), state.encoder)
|
||||
sessionState, _ := loadSession(state.encoder, rawJWT)
|
||||
|
||||
var s sessionOrServiceAccount
|
||||
|
@ -100,9 +100,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
|||
|
||||
// isForwardAuth returns if the current request is a forward auth route.
|
||||
func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool {
|
||||
opts := a.currentOptions.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
forwardAuthURL, err := opts.GetForwardAuthURL()
|
||||
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||
if err != nil || forwardAuthURL == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -136,9 +136,9 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
|||
}
|
||||
|
||||
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
|
||||
options := a.currentOptions.Load()
|
||||
cfg := a.currentConfig.Load()
|
||||
|
||||
for _, p := range options.GetAllPolicies() {
|
||||
for _, p := range cfg.Options.GetAllPolicies() {
|
||||
if p.Matches(requestURL) {
|
||||
return &p
|
||||
}
|
||||
|
|
|
@ -47,17 +47,17 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
|||
-----END CERTIFICATE-----`
|
||||
|
||||
func Test_getEvaluatorRequest(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a.currentConfig.Store(config.New(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(
|
||||
&envoy_service_auth_v3.CheckRequest{
|
||||
|
@ -87,7 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
|||
)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
Session: evaluator.RequestSession{
|
||||
ID: "SESSION_ID",
|
||||
},
|
||||
|
@ -248,8 +248,10 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
for _, tc := range tests {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL})
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a.currentConfig.Store(config.New(&config.Options{
|
||||
ForwardAuthURLString: tc.forwardAuthURL,
|
||||
}))
|
||||
|
||||
got := a.isForwardAuth(tc.checkReq)
|
||||
|
||||
|
@ -261,17 +263,17 @@ func Test_handleForwardAuth(t *testing.T) {
|
|||
}
|
||||
|
||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
||||
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||
a.state.Load().encoder = encoder
|
||||
a.currentOptions.Store(&config.Options{
|
||||
a.currentConfig.Store(config.New(&config.Options{
|
||||
Policies: []config.Policy{{
|
||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||
SubPolicies: []config.SubPolicy{{
|
||||
Rego: []string{"allow = true"},
|
||||
}},
|
||||
}},
|
||||
})
|
||||
}))
|
||||
|
||||
actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{
|
||||
Attributes: &envoy_service_auth_v3.AttributeContext{
|
||||
|
@ -296,7 +298,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
|||
}, nil)
|
||||
require.NoError(t, err)
|
||||
expect := &evaluator.Request{
|
||||
Policy: &a.currentOptions.Load().Policies[0],
|
||||
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||
Session: evaluator.RequestSession{},
|
||||
HTTP: evaluator.NewRequestHTTP(
|
||||
"GET",
|
||||
|
@ -332,11 +334,13 @@ func TestAuthorize_Check(t *testing.T) {
|
|||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||
a, err := New(&config.Config{Options: opt})
|
||||
a, err := New(config.New(opt))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
|
||||
a.currentConfig.Store(config.New(&config.Options{
|
||||
ForwardAuthURLString: "https://forward-auth.example.com",
|
||||
}))
|
||||
|
||||
cmpOpts := []cmp.Option{
|
||||
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
|
||||
|
|
|
@ -13,9 +13,13 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
||||
func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) {
|
||||
func loadRawSession(
|
||||
req *http.Request,
|
||||
cfg *config.Config,
|
||||
encoder encoding.MarshalUnmarshaler,
|
||||
) ([]byte, error) {
|
||||
var loaders []sessions.SessionLoader
|
||||
cookieStore, err := getCookieStore(options, encoder)
|
||||
cookieStore, err := getCookieStore(cfg, encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -37,7 +41,10 @@ func loadRawSession(req *http.Request, options *config.Options, encoder encoding
|
|||
return nil, sessions.ErrNoSessionFound
|
||||
}
|
||||
|
||||
func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.State, error) {
|
||||
func loadSession(
|
||||
encoder encoding.MarshalUnmarshaler,
|
||||
rawJWT []byte,
|
||||
) (*sessions.State, error) {
|
||||
var s sessions.State
|
||||
err := encoder.Unmarshal(rawJWT, &s)
|
||||
if err != nil {
|
||||
|
@ -46,14 +53,17 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.
|
|||
return &s, nil
|
||||
}
|
||||
|
||||
func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
||||
func getCookieStore(
|
||||
cfg *config.Config,
|
||||
encoder encoding.MarshalUnmarshaler,
|
||||
) (sessions.SessionStore, error) {
|
||||
cookieStore, err := cookie.NewStore(func() cookie.Options {
|
||||
return cookie.Options{
|
||||
Name: options.CookieName,
|
||||
Domain: options.CookieDomain,
|
||||
Secure: options.CookieSecure,
|
||||
HTTPOnly: options.CookieHTTPOnly,
|
||||
Expire: options.CookieExpire,
|
||||
Name: cfg.Options.CookieName,
|
||||
Domain: cfg.Options.CookieDomain,
|
||||
Secure: cfg.Options.CookieSecure,
|
||||
HTTPOnly: cfg.Options.CookieHTTPOnly,
|
||||
Expire: cfg.Options.CookieExpire,
|
||||
}
|
||||
}, encoder)
|
||||
if err != nil {
|
||||
|
|
|
@ -32,7 +32,7 @@ func TestLoadSession(t *testing.T) {
|
|||
},
|
||||
},
|
||||
})
|
||||
raw, err := loadRawSession(req, opts, encoder)
|
||||
raw, err := loadRawSession(req, config.New(opts), encoder)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -36,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
|
|||
|
||||
var err error
|
||||
|
||||
state.evaluator, err = newPolicyEvaluator(cfg.Options, store)
|
||||
state.evaluator, err = newPolicyEvaluator(cfg, store)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||
}
|
||||
|
|
|
@ -31,6 +31,16 @@ type Config struct {
|
|||
MetricsScrapeEndpoints []MetricsScrapeEndpoint
|
||||
}
|
||||
|
||||
// New creates a new Config.
|
||||
func New(options *Options) *Config {
|
||||
if options == nil {
|
||||
options = NewDefaultOptions()
|
||||
}
|
||||
return &Config{
|
||||
Options: options,
|
||||
}
|
||||
}
|
||||
|
||||
// Clone creates a clone of the config.
|
||||
func (cfg *Config) Clone() *Config {
|
||||
newOptions := new(Options)
|
||||
|
|
|
@ -13,11 +13,9 @@ import (
|
|||
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
||||
Options: &config.Options{
|
||||
adminCfg, err := b.BuildBootstrapAdmin(config.New(&config.Options{
|
||||
EnvoyAdminAddress: "localhost:9901",
|
||||
},
|
||||
})
|
||||
}))
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -31,11 +29,9 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
|||
`, adminCfg)
|
||||
})
|
||||
t.Run("bad address", func(t *testing.T) {
|
||||
_, err := b.BuildBootstrapAdmin(&config.Config{
|
||||
Options: &config.Options{
|
||||
_, err := b.BuildBootstrapAdmin(config.New(&config.Options{
|
||||
EnvoyAdminAddress: "xyz1234:zyx4321",
|
||||
},
|
||||
})
|
||||
}))
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
@ -111,11 +107,9 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
|
|||
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
t.Run("valid", func(t *testing.T) {
|
||||
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{
|
||||
Options: &config.Options{
|
||||
statsCfg, err := b.BuildBootstrapStatsConfig(config.New(&config.Options{
|
||||
Services: "all",
|
||||
},
|
||||
})
|
||||
}))
|
||||
assert.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
|
|
@ -46,22 +46,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
return nil, err
|
||||
}
|
||||
|
||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
|
||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
|
||||
}
|
||||
|
||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -87,7 +87,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
databrokerCluster,
|
||||
}
|
||||
|
||||
tracingCluster, err := buildTracingCluster(cfg.Options)
|
||||
tracingCluster, err := buildTracingCluster(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
} else if tracingCluster != nil {
|
||||
|
@ -101,7 +101,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
||||
}
|
||||
if len(policy.To) > 0 {
|
||||
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
|
||||
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
||||
}
|
||||
|
@ -119,16 +119,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
|
||||
func (b *Builder) buildInternalCluster(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
name string,
|
||||
dsts []*url.URL,
|
||||
upstreamProtocol upstreamProtocolConfig,
|
||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := newDefaultEnvoyClusterConfig()
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range dsts {
|
||||
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
|
||||
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -141,18 +141,22 @@ func (b *Builder) buildInternalCluster(
|
|||
return cluster, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
func (b *Builder) buildPolicyCluster(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := new(envoy_config_cluster_v3.Cluster)
|
||||
proto.Merge(cluster, policy.EnvoyOpts)
|
||||
|
||||
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
|
||||
if cfg.Options.EnvoyBindConfigFreebind.IsSet() || cfg.Options.EnvoyBindConfigSourceAddress != "" {
|
||||
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
||||
if options.EnvoyBindConfigFreebind.IsSet() {
|
||||
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(options.EnvoyBindConfigFreebind.Bool)
|
||||
if cfg.Options.EnvoyBindConfigFreebind.IsSet() {
|
||||
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(cfg.Options.EnvoyBindConfigFreebind.Bool)
|
||||
}
|
||||
if options.EnvoyBindConfigSourceAddress != "" {
|
||||
if cfg.Options.EnvoyBindConfigSourceAddress != "" {
|
||||
cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{
|
||||
Address: options.EnvoyBindConfigSourceAddress,
|
||||
Address: cfg.Options.EnvoyBindConfigSourceAddress,
|
||||
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
|
||||
PortValue: 0,
|
||||
},
|
||||
|
@ -171,13 +175,13 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
|||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||
|
||||
name := getClusterID(policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO {
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||
}
|
||||
|
||||
if policy.EnableGoogleCloudServerlessAuthentication {
|
||||
|
@ -193,12 +197,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
|||
|
||||
func (b *Builder) buildPolicyEndpoints(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) ([]Endpoint, error) {
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range policy.To {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -209,7 +213,7 @@ func (b *Builder) buildPolicyEndpoints(
|
|||
|
||||
func (b *Builder) buildInternalTransportSocket(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
endpoint *url.URL,
|
||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
if endpoint.Scheme != "https" {
|
||||
|
@ -218,10 +222,10 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
|
||||
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
||||
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
||||
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
|
||||
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
||||
},
|
||||
}
|
||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
||||
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
|
@ -234,7 +238,7 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
ValidationContext: validationContext,
|
||||
},
|
||||
},
|
||||
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
|
||||
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
|
||||
}
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
return &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -247,7 +251,7 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
|
||||
func (b *Builder) buildPolicyTransportSocket(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
dst url.URL,
|
||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
|
@ -257,7 +261,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
|||
|
||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||
|
||||
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
|
||||
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -318,7 +322,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
|||
|
||||
func (b *Builder) buildPolicyValidationContext(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
dst url.URL,
|
||||
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
||||
|
@ -343,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
|
|||
}
|
||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
||||
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
|
|
|
@ -37,14 +37,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
||||
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
}, *mustParseURL(t, "http://example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ts)
|
||||
})
|
||||
t.Run("host as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
require.NoError(t, err)
|
||||
|
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_server_name as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSServerName: "use-this-name.example.com",
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSUpstreamServerName: "use-this-name.example.com",
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_skip_verify", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSSkipVerify: true,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("custom ca", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("options custom ca", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o2), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
require.NoError(t, err)
|
||||
|
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
})
|
||||
t.Run("client certificate", func(t *testing.T) {
|
||||
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
ClientCertificate: clientCert,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -438,7 +438,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||
o1 := config.NewDefaultOptions()
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -495,7 +495,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("secure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t,
|
||||
"https://example.com",
|
||||
"https://example.com",
|
||||
|
@ -663,7 +663,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("ip addresses", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("weights", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -775,7 +775,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("localhost", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://localhost"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -821,7 +821,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("outlier", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -904,7 +904,7 @@ func Test_bindConfig(t *testing.T) {
|
|||
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
t.Run("no bind config", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
|
||||
cluster, err := b.buildPolicyCluster(ctx, config.New(nil), &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
@ -912,9 +912,9 @@ func Test_bindConfig(t *testing.T) {
|
|||
assert.Nil(t, cluster.UpstreamBindConfig)
|
||||
})
|
||||
t.Run("freebind", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
||||
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
|
||||
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
||||
}, &config.Policy{
|
||||
}), &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
@ -930,9 +930,9 @@ func Test_bindConfig(t *testing.T) {
|
|||
`, cluster.UpstreamBindConfig)
|
||||
})
|
||||
t.Run("source address", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
||||
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
|
||||
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
||||
}, &config.Policy{
|
||||
}), &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
|
|
@ -76,10 +76,10 @@ func newDefaultEnvoyClusterConfig() *envoy_config_cluster_v3.Cluster {
|
|||
}
|
||||
}
|
||||
|
||||
func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
|
||||
lvl := options.ProxyLogLevel
|
||||
func buildAccessLogs(cfg *config.Config) []*envoy_config_accesslog_v3.AccessLog {
|
||||
lvl := cfg.Options.ProxyLogLevel
|
||||
if lvl == "" {
|
||||
lvl = options.LogLevel
|
||||
lvl = cfg.Options.LogLevel
|
||||
}
|
||||
if lvl == "" {
|
||||
lvl = "debug"
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
)
|
||||
|
||||
func (b *Builder) buildVirtualHost(
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
name string,
|
||||
domain string,
|
||||
) (*envoy_config_route_v3.VirtualHost, error) {
|
||||
|
@ -20,15 +20,15 @@ func (b *Builder) buildVirtualHost(
|
|||
}
|
||||
|
||||
// these routes match /.pomerium/... and similar paths
|
||||
rs, err := b.buildPomeriumHTTPRoutes(options, domain)
|
||||
rs, err := b.buildPomeriumHTTPRoutes(cfg, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
vh.Routes = append(vh.Routes, rs...)
|
||||
|
||||
// if we're the proxy or authenticate service, add our global headers
|
||||
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
|
||||
vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders())
|
||||
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
|
||||
vh.ResponseHeadersToAdd = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
|
||||
}
|
||||
|
||||
return vh, nil
|
||||
|
@ -37,13 +37,13 @@ func (b *Builder) buildVirtualHost(
|
|||
// buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies
|
||||
// coming directly from envoy
|
||||
func (b *Builder) buildLocalReplyConfig(
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
) *envoy_http_connection_manager.LocalReplyConfig {
|
||||
// add global headers for HSTS headers (#2110)
|
||||
var headers []*envoy_config_core_v3.HeaderValueOption
|
||||
// if we're the proxy or authenticate service, add our global headers
|
||||
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
|
||||
headers = toEnvoyHeaders(options.GetSetResponseHeaders())
|
||||
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
|
||||
headers = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
|
||||
}
|
||||
|
||||
return &envoy_http_connection_manager.LocalReplyConfig{
|
||||
|
|
|
@ -53,7 +53,10 @@ func init() {
|
|||
}
|
||||
|
||||
// BuildListeners builds envoy listeners from the given config.
|
||||
func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) {
|
||||
func (b *Builder) BuildListeners(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
) ([]*envoy_config_listener_v3.Listener, error) {
|
||||
var listeners []*envoy_config_listener_v3.Listener
|
||||
|
||||
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
|
||||
|
@ -89,19 +92,22 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
|
|||
return listeners, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
||||
func (b *Builder) buildMainListener(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_listener_v3.Listener, error) {
|
||||
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
|
||||
if cfg.Options.UseProxyProtocol {
|
||||
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
|
||||
}
|
||||
|
||||
if cfg.Options.InsecureServer {
|
||||
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
|
||||
allDomains, err := getAllRouteableDomains(cfg, cfg.Options.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "")
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, allDomains, "")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -118,9 +124,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
|||
}
|
||||
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
||||
|
||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr,
|
||||
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
|
||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, httpDomains, tlsDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -155,7 +161,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
|||
return li, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
||||
func (b *Builder) buildMetricsListener(
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_listener_v3.Listener, error) {
|
||||
filter, err := b.buildMetricsHTTPConnectionManagerFilter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -235,22 +243,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
|
|||
}
|
||||
|
||||
func (b *Builder) buildFilterChains(
|
||||
options *config.Options, addr string,
|
||||
cfg *config.Config,
|
||||
addr string,
|
||||
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
||||
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
||||
allDomains, err := getAllRouteableDomains(options, addr)
|
||||
allDomains, err := getAllRouteableDomains(cfg, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsDomains, err := getAllTLSDomains(options, addr)
|
||||
tlsDomains, err := getAllTLSDomains(cfg, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chains []*envoy_config_listener_v3.FilterChain
|
||||
for _, domain := range tlsDomains {
|
||||
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
|
||||
routeableDomains, err := getRouteableDomainsForTLSServerName(cfg, addr, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -273,31 +282,31 @@ func (b *Builder) buildFilterChains(
|
|||
}
|
||||
|
||||
func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
domains []string,
|
||||
tlsDomain string,
|
||||
) (*envoy_config_listener_v3.Filter, error) {
|
||||
authorizeURLs, err := options.GetInternalAuthorizeURLs()
|
||||
authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
|
||||
dataBrokerURLs, err := cfg.Options.GetInternalDataBrokerURLs()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
||||
for _, domain := range domains {
|
||||
vh, err := b.buildVirtualHost(options, domain, domain)
|
||||
vh, err := b.buildVirtualHost(cfg, domain, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if options.Addr == options.GetGRPCAddr() {
|
||||
if cfg.Options.Addr == cfg.Options.GetGRPCAddr() {
|
||||
// if this is a gRPC service domain and we're supposed to handle that, add those routes
|
||||
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
|
||||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
|
||||
if (config.IsAuthorize(cfg.Options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
|
||||
(config.IsDataBroker(cfg.Options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
|
||||
rs, err := b.buildGRPCRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -307,8 +316,8 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
}
|
||||
|
||||
// if we're the proxy, add all the policy routes
|
||||
if config.IsProxy(options.Services) {
|
||||
rs, err := b.buildPolicyRoutes(options, domain)
|
||||
if config.IsProxy(cfg.Options.Services) {
|
||||
rs, err := b.buildPolicyRoutes(cfg, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -320,15 +329,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
}
|
||||
}
|
||||
|
||||
vh, err := b.buildVirtualHost(options, "catch-all", "*")
|
||||
vh, err := b.buildVirtualHost(cfg, "catch-all", "*")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
virtualHosts = append(virtualHosts, vh)
|
||||
|
||||
var grpcClientTimeout *durationpb.Duration
|
||||
if options.GRPCClientTimeout != 0 {
|
||||
grpcClientTimeout = durationpb.New(options.GRPCClientTimeout)
|
||||
if cfg.Options.GRPCClientTimeout != 0 {
|
||||
grpcClientTimeout = durationpb.New(cfg.Options.GRPCClientTimeout)
|
||||
} else {
|
||||
grpcClientTimeout = durationpb.New(30 * time.Second)
|
||||
}
|
||||
|
@ -346,15 +355,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
filters = append(filters, HTTPRouterFilter())
|
||||
|
||||
var maxStreamDuration *durationpb.Duration
|
||||
if options.WriteTimeout > 0 {
|
||||
maxStreamDuration = durationpb.New(options.WriteTimeout)
|
||||
if cfg.Options.WriteTimeout > 0 {
|
||||
maxStreamDuration = durationpb.New(cfg.Options.WriteTimeout)
|
||||
}
|
||||
|
||||
rc, err := b.buildRouteConfiguration("main", virtualHosts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tracingProvider, err := buildTracingHTTP(options)
|
||||
tracingProvider, err := buildTracingHTTP(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -362,27 +371,27 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||
AlwaysSetRequestIdInResponse: true,
|
||||
|
||||
CodecType: options.GetCodecType().ToEnvoy(),
|
||||
CodecType: cfg.Options.GetCodecType().ToEnvoy(),
|
||||
StatPrefix: "ingress",
|
||||
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
||||
RouteConfig: rc,
|
||||
},
|
||||
HttpFilters: filters,
|
||||
AccessLog: buildAccessLogs(options),
|
||||
AccessLog: buildAccessLogs(cfg),
|
||||
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
|
||||
IdleTimeout: durationpb.New(options.IdleTimeout),
|
||||
IdleTimeout: durationpb.New(cfg.Options.IdleTimeout),
|
||||
MaxStreamDuration: maxStreamDuration,
|
||||
},
|
||||
RequestTimeout: durationpb.New(options.ReadTimeout),
|
||||
RequestTimeout: durationpb.New(cfg.Options.ReadTimeout),
|
||||
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
|
||||
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100},
|
||||
RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
|
||||
Provider: tracingProvider,
|
||||
},
|
||||
// See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for
|
||||
UseRemoteAddress: &wrappers.BoolValue{Value: true},
|
||||
SkipXffAppend: options.SkipXffAppend,
|
||||
XffNumTrustedHops: options.XffNumTrustedHops,
|
||||
LocalReplyConfig: b.buildLocalReplyConfig(options),
|
||||
SkipXffAppend: cfg.Options.SkipXffAppend,
|
||||
XffNumTrustedHops: cfg.Options.XffNumTrustedHops,
|
||||
LocalReplyConfig: b.buildLocalReplyConfig(cfg),
|
||||
}), nil
|
||||
}
|
||||
|
||||
|
@ -420,7 +429,10 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() (*envoy_config_liste
|
|||
}), nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
||||
func (b *Builder) buildGRPCListener(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_listener_v3.Listener, error) {
|
||||
filter, err := b.buildGRPCHTTPConnectionManagerFilter()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -437,7 +449,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
|
|||
return li, nil
|
||||
}
|
||||
|
||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr,
|
||||
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
|
||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||
|
@ -518,7 +530,10 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener
|
|||
}), nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) {
|
||||
func (b *Builder) buildRouteConfiguration(
|
||||
name string,
|
||||
virtualHosts []*envoy_config_route_v3.VirtualHost,
|
||||
) (*envoy_config_route_v3.RouteConfiguration, error) {
|
||||
return &envoy_config_route_v3.RouteConfiguration{
|
||||
Name: name,
|
||||
VirtualHosts: virtualHosts,
|
||||
|
@ -527,7 +542,8 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
||||
func (b *Builder) buildDownstreamTLSContext(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
|
@ -570,7 +586,8 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
|||
}
|
||||
}
|
||||
|
||||
func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
||||
func (b *Builder) buildDownstreamValidationContext(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
|
||||
|
@ -580,7 +597,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
|||
needsClientCert = true
|
||||
}
|
||||
if !needsClientCert {
|
||||
for _, p := range getPoliciesForDomain(cfg.Options, domain) {
|
||||
for _, p := range getPoliciesForDomain(cfg, domain) {
|
||||
if p.TLSDownstreamClientCA != "" {
|
||||
needsClientCert = true
|
||||
break
|
||||
|
@ -613,19 +630,23 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
|||
return vc
|
||||
}
|
||||
|
||||
func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) {
|
||||
func getRouteableDomainsForTLSServerName(
|
||||
cfg *config.Config,
|
||||
addr string,
|
||||
tlsServerName string,
|
||||
) ([]string, error) {
|
||||
allDomains := sets.NewSorted[string]()
|
||||
|
||||
if addr == options.Addr {
|
||||
domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
|
||||
if addr == cfg.Options.Addr {
|
||||
domains, err := cfg.Options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allDomains.Add(domains...)
|
||||
}
|
||||
|
||||
if addr == options.GetGRPCAddr() {
|
||||
domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
|
||||
if addr == cfg.Options.GetGRPCAddr() {
|
||||
domains, err := cfg.Options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -635,19 +656,22 @@ func getRouteableDomainsForTLSServerName(options *config.Options, addr string, t
|
|||
return allDomains.ToSlice(), nil
|
||||
}
|
||||
|
||||
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
||||
func getAllRouteableDomains(
|
||||
cfg *config.Config,
|
||||
addr string,
|
||||
) ([]string, error) {
|
||||
allDomains := sets.NewSorted[string]()
|
||||
|
||||
if addr == options.Addr {
|
||||
domains, err := options.GetAllRouteableHTTPDomains()
|
||||
if addr == cfg.Options.Addr {
|
||||
domains, err := cfg.Options.GetAllRouteableHTTPDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allDomains.Add(domains...)
|
||||
}
|
||||
|
||||
if addr == options.GetGRPCAddr() {
|
||||
domains, err := options.GetAllRouteableGRPCDomains()
|
||||
if addr == cfg.Options.GetGRPCAddr() {
|
||||
domains, err := cfg.Options.GetAllRouteableGRPCDomains()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -657,8 +681,11 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
|
|||
return allDomains.ToSlice(), nil
|
||||
}
|
||||
|
||||
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
||||
allDomains, err := getAllRouteableDomains(options, addr)
|
||||
func getAllTLSDomains(
|
||||
cfg *config.Config,
|
||||
addr string,
|
||||
) ([]string, error) {
|
||||
allDomains, err := getAllRouteableDomains(cfg, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -711,9 +738,12 @@ func hostMatchesDomain(u *url.URL, host string) bool {
|
|||
return h1 == h2 && p1 == p2
|
||||
}
|
||||
|
||||
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy {
|
||||
func getPoliciesForDomain(
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) []config.Policy {
|
||||
var policies []config.Policy
|
||||
for _, p := range options.GetAllPolicies() {
|
||||
for _, p := range cfg.Options.GetAllPolicies() {
|
||||
if p.Source != nil && p.Source.URL.Hostname() == domain {
|
||||
policies = append(policies, p)
|
||||
}
|
||||
|
|
|
@ -26,13 +26,11 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
|
|||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
||||
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
li, err := b.buildMetricsListener(&config.Config{
|
||||
Options: &config.Options{
|
||||
li, err := b.buildMetricsListener(config.New(&config.Options{
|
||||
MetricsAddr: "127.0.0.1:9902",
|
||||
MetricsCertificate: aExampleComCert,
|
||||
MetricsCertificateKey: aExampleComKey,
|
||||
},
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -115,7 +113,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
|||
options := config.NewDefaultOptions()
|
||||
options.SkipXffAppend = true
|
||||
options.XffNumTrustedHops = 1
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*")
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(config.New(options), []string{"example.com"}, "*")
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"name": "envoy.filters.network.http_connection_manager",
|
||||
|
@ -544,10 +542,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
||||
|
||||
t.Run("no-validation", func(t *testing.T) {
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
}}, "a.example.com")
|
||||
}), "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
@ -577,11 +575,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
}`, downstreamTLSContext)
|
||||
})
|
||||
t.Run("client-ca", func(t *testing.T) {
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
ClientCA: "TEST",
|
||||
}}, "a.example.com")
|
||||
}), "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
@ -614,7 +612,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
}`, downstreamTLSContext)
|
||||
})
|
||||
t.Run("policy-client-ca", func(t *testing.T) {
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
Policies: []config.Policy{
|
||||
|
@ -623,7 +621,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
TLSDownstreamClientCA: "TEST",
|
||||
},
|
||||
},
|
||||
}}, "a.example.com")
|
||||
}), "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
@ -656,11 +654,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
}`, downstreamTLSContext)
|
||||
})
|
||||
t.Run("http1", func(t *testing.T) {
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
CodecType: config.CodecTypeHTTP1,
|
||||
}}, "a.example.com")
|
||||
}), "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
@ -690,11 +688,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
|||
}`, downstreamTLSContext)
|
||||
})
|
||||
t.Run("http2", func(t *testing.T) {
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||
Cert: aExampleComCert,
|
||||
Key: aExampleComKey,
|
||||
CodecType: config.CodecTypeHTTP2,
|
||||
}}, "a.example.com")
|
||||
}), "a.example.com")
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"commonTlsContext": {
|
||||
|
@ -739,9 +737,10 @@ func Test_getAllDomains(t *testing.T) {
|
|||
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
||||
},
|
||||
}
|
||||
cfg := config.New(options)
|
||||
t.Run("routable", func(t *testing.T) {
|
||||
t.Run("http", func(t *testing.T) {
|
||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
|
||||
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"a.example.com",
|
||||
|
@ -756,7 +755,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
t.Run("grpc", func(t *testing.T) {
|
||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
|
||||
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9001")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"authorize.example.com:9001",
|
||||
|
@ -765,9 +764,9 @@ func Test_getAllDomains(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
t.Run("both", func(t *testing.T) {
|
||||
newOptions := *options
|
||||
newOptions.GRPCAddr = newOptions.Addr
|
||||
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000")
|
||||
newCfg := cfg.Clone()
|
||||
newCfg.Options.GRPCAddr = cfg.Options.Addr
|
||||
actual, err := getAllRouteableDomains(newCfg, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"a.example.com",
|
||||
|
@ -786,7 +785,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
})
|
||||
t.Run("tls", func(t *testing.T) {
|
||||
t.Run("http", func(t *testing.T) {
|
||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
|
||||
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"a.example.com",
|
||||
|
@ -797,7 +796,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
t.Run("grpc", func(t *testing.T) {
|
||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
|
||||
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9001")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"authorize.example.com",
|
||||
|
@ -831,10 +830,10 @@ func Test_buildRouteConfiguration(t *testing.T) {
|
|||
func Test_requireProxyProtocol(t *testing.T) {
|
||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
||||
t.Run("required", func(t *testing.T) {
|
||||
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
||||
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
|
||||
UseProxyProtocol: true,
|
||||
InsecureServer: true,
|
||||
}})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
{
|
||||
|
@ -846,10 +845,10 @@ func Test_requireProxyProtocol(t *testing.T) {
|
|||
]`, li.GetListenerFilters())
|
||||
})
|
||||
t.Run("not required", func(t *testing.T) {
|
||||
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
||||
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
|
||||
UseProxyProtocol: false,
|
||||
InsecureServer: true,
|
||||
}})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, li.GetListenerFilters(), 0)
|
||||
})
|
||||
|
|
|
@ -14,7 +14,9 @@ import (
|
|||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
||||
func (b *Builder) buildOutboundListener(
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_listener_v3.Listener, error) {
|
||||
outboundPort, err := strconv.Atoi(cfg.OutboundPort)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err)
|
||||
|
|
|
@ -47,12 +47,15 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
|
|||
}}, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
||||
func (b *Builder) buildPomeriumHTTPRoutes(
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
|
||||
// if this is the pomerium proxy in front of the the authenticate service, don't add
|
||||
// these routes since they will be handled by authenticate
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -69,27 +72,27 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
|
|||
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
|
||||
)
|
||||
// per #837, only add robots.txt if there are no unauthenticated routes
|
||||
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
|
||||
if !hasPublicPolicyMatchingURL(cfg, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
|
||||
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
|
||||
}
|
||||
}
|
||||
// if we're handling authentication, add the oauth2 callback url
|
||||
authenticateURL, err := options.GetInternalAuthenticateURL()
|
||||
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
if config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
routes = append(routes,
|
||||
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
|
||||
b.buildControlPlanePathRoute(cfg.Options.AuthenticateCallbackPath, false),
|
||||
b.buildControlPlanePathRoute("/", false),
|
||||
)
|
||||
}
|
||||
// if we're the proxy and this is the forward-auth url
|
||||
forwardAuthURL, err := options.GetForwardAuthURL()
|
||||
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config.IsProxy(options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
|
||||
if config.IsProxy(cfg.Options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
|
||||
// disable ext_authz and pass request to proxy handlers that enable authN flow
|
||||
r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
|
||||
if err != nil {
|
||||
|
@ -227,10 +230,13 @@ func getClusterStatsName(policy *config.Policy) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
||||
func (b *Builder) buildPolicyRoutes(
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
|
||||
for i, p := range options.GetAllPolicies() {
|
||||
for i, p := range cfg.Options.GetAllPolicies() {
|
||||
policy := p
|
||||
if !hostMatchesDomain(policy.Source.URL, domain) {
|
||||
continue
|
||||
|
@ -242,7 +248,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
|||
Match: match,
|
||||
Metadata: &envoy_config_core_v3.Metadata{},
|
||||
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
|
||||
RequestHeadersToRemove: getRequestHeadersToRemove(options, &policy),
|
||||
RequestHeadersToRemove: getRequestHeadersToRemove(cfg, &policy),
|
||||
ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders),
|
||||
}
|
||||
if policy.Redirect != nil {
|
||||
|
@ -252,7 +258,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
|||
}
|
||||
envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action}
|
||||
} else {
|
||||
action, err := b.buildPolicyRouteRouteAction(options, &policy)
|
||||
action, err := b.buildPolicyRouteRouteAction(cfg, &policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -264,7 +270,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
|||
}
|
||||
|
||||
// disable authentication entirely when the proxy is fronting authenticate
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -275,7 +281,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
|||
} else {
|
||||
luaMetadata["remove_pomerium_cookie"] = &structpb.Value{
|
||||
Kind: &structpb.Value_StringValue{
|
||||
StringValue: options.CookieName,
|
||||
StringValue: cfg.Options.CookieName,
|
||||
},
|
||||
}
|
||||
luaMetadata["remove_pomerium_authorization"] = &structpb.Value{
|
||||
|
@ -350,13 +356,16 @@ func (b *Builder) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*env
|
|||
return action, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) {
|
||||
func (b *Builder) buildPolicyRouteRouteAction(
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) (*envoy_config_route_v3.RouteAction, error) {
|
||||
clusterName := getClusterID(policy)
|
||||
// kubernetes requests are sent to the http control plane to be reproxied
|
||||
if policy.IsForKubernetes() {
|
||||
clusterName = httpCluster
|
||||
}
|
||||
routeTimeout := getRouteTimeout(options, policy)
|
||||
routeTimeout := getRouteTimeout(cfg, policy)
|
||||
idleTimeout := getRouteIdleTimeout(policy)
|
||||
prefixRewrite, regexRewrite := getRewriteOptions(policy)
|
||||
upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{
|
||||
|
@ -464,13 +473,16 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
|
|||
return match
|
||||
}
|
||||
|
||||
func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string {
|
||||
func getRequestHeadersToRemove(
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) []string {
|
||||
requestHeadersToRemove := policy.RemoveRequestHeaders
|
||||
if !policy.PassIdentityHeaders {
|
||||
requestHeadersToRemove = append(requestHeadersToRemove,
|
||||
httputil.HeaderPomeriumJWTAssertion,
|
||||
httputil.HeaderPomeriumJWTAssertionFor)
|
||||
for headerName := range options.JWTClaimsHeaders {
|
||||
for headerName := range cfg.Options.JWTClaimsHeaders {
|
||||
requestHeadersToRemove = append(requestHeadersToRemove, headerName)
|
||||
}
|
||||
}
|
||||
|
@ -482,7 +494,10 @@ func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) [
|
|||
return requestHeadersToRemove
|
||||
}
|
||||
|
||||
func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration {
|
||||
func getRouteTimeout(
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) *durationpb.Duration {
|
||||
var routeTimeout *durationpb.Duration
|
||||
if policy.UpstreamTimeout != nil {
|
||||
routeTimeout = durationpb.New(*policy.UpstreamTimeout)
|
||||
|
@ -490,7 +505,7 @@ func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb
|
|||
// a non-zero value would conflict with idleTimeout and/or websocket / tcp calls
|
||||
routeTimeout = durationpb.New(0)
|
||||
} else {
|
||||
routeTimeout = durationpb.New(options.DefaultUpstreamTimeout)
|
||||
routeTimeout = durationpb.New(cfg.Options.DefaultUpstreamTimeout)
|
||||
}
|
||||
return routeTimeout
|
||||
}
|
||||
|
@ -564,8 +579,11 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3.
|
|||
}
|
||||
}
|
||||
|
||||
func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) bool {
|
||||
for _, policy := range options.GetAllPolicies() {
|
||||
func hasPublicPolicyMatchingURL(
|
||||
cfg *config.Config,
|
||||
requestURL url.URL,
|
||||
) bool {
|
||||
for _, policy := range cfg.Options.GetAllPolicies() {
|
||||
if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) {
|
||||
return true
|
||||
}
|
||||
|
@ -573,13 +591,16 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
|
|||
return false
|
||||
}
|
||||
|
||||
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) {
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
func isProxyFrontingAuthenticate(
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
) (bool, error) {
|
||||
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
if !config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -84,7 +84,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
|||
AuthenticateCallbackPath: "/oauth2/callback",
|
||||
ForwardAuthURLString: "https://forward-auth.example.com",
|
||||
}
|
||||
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
|
||||
cfg := config.New(options)
|
||||
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
|
@ -106,7 +107,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
|||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthenticateCallbackPath: "/oauth2/callback",
|
||||
}
|
||||
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
|
||||
cfg := config.New(options)
|
||||
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, "null", routes)
|
||||
})
|
||||
|
@ -122,8 +124,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
|||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
}},
|
||||
}
|
||||
cfg := config.New(options)
|
||||
_ = options.Policies[0].Validate()
|
||||
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
|
||||
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
|
@ -150,8 +153,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
|||
AllowPublicUnauthenticatedAccess: true,
|
||||
}},
|
||||
}
|
||||
cfg := config.New(options)
|
||||
_ = options.Policies[0].Validate()
|
||||
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
|
||||
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `[
|
||||
|
@ -242,7 +246,7 @@ func TestTimeouts(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
b := &Builder{filemgr: filemgr.NewManager()}
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
CookieName: "pomerium",
|
||||
DefaultUpstreamTimeout: time.Second * 3,
|
||||
Policies: []config.Policy{
|
||||
|
@ -253,7 +257,7 @@ func TestTimeouts(t *testing.T) {
|
|||
IdleTimeout: getDuration(tc.idle),
|
||||
AllowWebsockets: tc.allowWebsockets,
|
||||
}},
|
||||
}, "example.com")
|
||||
}), "example.com")
|
||||
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
|
||||
continue
|
||||
}
|
||||
|
@ -295,7 +299,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
ten := time.Second * 10
|
||||
|
||||
b := &Builder{filemgr: filemgr.NewManager()}
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
CookieName: "pomerium",
|
||||
DefaultUpstreamTimeout: time.Second * 3,
|
||||
Policies: []config.Policy{
|
||||
|
@ -357,7 +361,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
UpstreamTimeout: &ten,
|
||||
},
|
||||
},
|
||||
}, "example.com")
|
||||
}), "example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -724,7 +728,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
`, routes)
|
||||
|
||||
t.Run("fronting-authenticate", func(t *testing.T) {
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
Services: "proxy",
|
||||
CookieName: "pomerium",
|
||||
|
@ -735,7 +739,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
PassIdentityHeaders: true,
|
||||
},
|
||||
},
|
||||
}, "authenticate.example.com")
|
||||
}), "authenticate.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -791,7 +795,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("tcp", func(t *testing.T) {
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
CookieName: "pomerium",
|
||||
DefaultUpstreamTimeout: time.Second * 3,
|
||||
Policies: []config.Policy{
|
||||
|
@ -805,7 +809,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
UpstreamTimeout: &ten,
|
||||
},
|
||||
},
|
||||
}, "example.com:22")
|
||||
}), "example.com:22")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -905,7 +909,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
})
|
||||
|
||||
t.Run("remove-pomerium-headers", func(t *testing.T) {
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
Services: "proxy",
|
||||
CookieName: "pomerium",
|
||||
|
@ -918,7 +922,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
|||
Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")},
|
||||
},
|
||||
},
|
||||
}, "from.example.com")
|
||||
}), "from.example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
@ -980,7 +984,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
|
|||
}(getClusterID)
|
||||
getClusterID = policyNameFunc()
|
||||
b := &Builder{filemgr: filemgr.NewManager()}
|
||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
||||
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||
CookieName: "pomerium",
|
||||
DefaultUpstreamTimeout: time.Second * 3,
|
||||
Policies: []config.Policy{
|
||||
|
@ -1022,7 +1026,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
|
|||
HostPathRegexRewriteSubstitution: "\\1",
|
||||
},
|
||||
},
|
||||
}, "example.com")
|
||||
}), "example.com")
|
||||
require.NoError(t, err)
|
||||
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
|
|
|
@ -14,8 +14,10 @@ import (
|
|||
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||
)
|
||||
|
||||
func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
tracingOptions, err := config.NewTracingOptions(options)
|
||||
func buildTracingCluster(
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
tracingOptions, err := config.NewTracingOptions(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err)
|
||||
}
|
||||
|
@ -24,8 +26,8 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
|
|||
case trace.DatadogTracingProviderName:
|
||||
addr, _ := parseAddress("127.0.0.1:8126")
|
||||
|
||||
if options.TracingDatadogAddress != "" {
|
||||
addr, err = parseAddress(options.TracingDatadogAddress)
|
||||
if cfg.Options.TracingDatadogAddress != "" {
|
||||
addr, err = parseAddress(cfg.Options.TracingDatadogAddress)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err)
|
||||
}
|
||||
|
@ -94,8 +96,10 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
|
|||
}
|
||||
}
|
||||
|
||||
func buildTracingHTTP(options *config.Options) (*envoy_config_trace_v3.Tracing_Http, error) {
|
||||
tracingOptions, err := config.NewTracingOptions(options)
|
||||
func buildTracingHTTP(
|
||||
cfg *config.Config,
|
||||
) (*envoy_config_trace_v3.Tracing_Http, error) {
|
||||
tracingOptions, err := config.NewTracingOptions(cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid tracing config: %w", err)
|
||||
}
|
||||
|
|
|
@ -11,9 +11,9 @@ import (
|
|||
|
||||
func TestBuildTracingCluster(t *testing.T) {
|
||||
t.Run("datadog", func(t *testing.T) {
|
||||
c, err := buildTracingCluster(&config.Options{
|
||||
c, err := buildTracingCluster(config.New(&config.Options{
|
||||
TracingProvider: "datadog",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -38,10 +38,10 @@ func TestBuildTracingCluster(t *testing.T) {
|
|||
}
|
||||
`, c)
|
||||
|
||||
c, err = buildTracingCluster(&config.Options{
|
||||
c, err = buildTracingCluster(config.New(&config.Options{
|
||||
TracingProvider: "datadog",
|
||||
TracingDatadogAddress: "example.com:8126",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -67,10 +67,10 @@ func TestBuildTracingCluster(t *testing.T) {
|
|||
`, c)
|
||||
})
|
||||
t.Run("zipkin", func(t *testing.T) {
|
||||
c, err := buildTracingCluster(&config.Options{
|
||||
c, err := buildTracingCluster(config.New(&config.Options{
|
||||
TracingProvider: "zipkin",
|
||||
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -99,9 +99,9 @@ func TestBuildTracingCluster(t *testing.T) {
|
|||
|
||||
func TestBuildTracingHTTP(t *testing.T) {
|
||||
t.Run("datadog", func(t *testing.T) {
|
||||
h, err := buildTracingHTTP(&config.Options{
|
||||
h, err := buildTracingHTTP(config.New(&config.Options{
|
||||
TracingProvider: "datadog",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
@ -115,10 +115,10 @@ func TestBuildTracingHTTP(t *testing.T) {
|
|||
`, h)
|
||||
})
|
||||
t.Run("zipkin", func(t *testing.T) {
|
||||
h, err := buildTracingHTTP(&config.Options{
|
||||
h, err := buildTracingHTTP(config.New(&config.Options{
|
||||
TracingProvider: "zipkin",
|
||||
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
||||
})
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `
|
||||
{
|
||||
|
|
|
@ -18,28 +18,28 @@ import (
|
|||
type TracingOptions = trace.TracingOptions
|
||||
|
||||
// NewTracingOptions builds a new TracingOptions from core Options
|
||||
func NewTracingOptions(o *Options) (*TracingOptions, error) {
|
||||
func NewTracingOptions(cfg *Config) (*TracingOptions, error) {
|
||||
tracingOpts := TracingOptions{
|
||||
Provider: o.TracingProvider,
|
||||
Service: telemetry.ServiceName(o.Services),
|
||||
JaegerAgentEndpoint: o.TracingJaegerAgentEndpoint,
|
||||
SampleRate: o.TracingSampleRate,
|
||||
Provider: cfg.Options.TracingProvider,
|
||||
Service: telemetry.ServiceName(cfg.Options.Services),
|
||||
JaegerAgentEndpoint: cfg.Options.TracingJaegerAgentEndpoint,
|
||||
SampleRate: cfg.Options.TracingSampleRate,
|
||||
}
|
||||
|
||||
switch o.TracingProvider {
|
||||
switch cfg.Options.TracingProvider {
|
||||
case trace.DatadogTracingProviderName:
|
||||
tracingOpts.DatadogAddress = o.TracingDatadogAddress
|
||||
tracingOpts.DatadogAddress = cfg.Options.TracingDatadogAddress
|
||||
case trace.JaegerTracingProviderName:
|
||||
if o.TracingJaegerCollectorEndpoint != "" {
|
||||
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(o.TracingJaegerCollectorEndpoint)
|
||||
if cfg.Options.TracingJaegerCollectorEndpoint != "" {
|
||||
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.TracingJaegerCollectorEndpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err)
|
||||
}
|
||||
tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint
|
||||
tracingOpts.JaegerAgentEndpoint = o.TracingJaegerAgentEndpoint
|
||||
tracingOpts.JaegerAgentEndpoint = cfg.Options.TracingJaegerAgentEndpoint
|
||||
}
|
||||
case trace.ZipkinTracingProviderName:
|
||||
zipkinEndpoint, err := urlutil.ParseAndValidateURL(o.ZipkinEndpoint)
|
||||
zipkinEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.ZipkinEndpoint)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err)
|
||||
}
|
||||
|
@ -47,7 +47,7 @@ func NewTracingOptions(o *Options) (*TracingOptions, error) {
|
|||
case "":
|
||||
return &TracingOptions{}, nil
|
||||
default:
|
||||
return nil, fmt.Errorf("config: provider %s unknown", o.TracingProvider)
|
||||
return nil, fmt.Errorf("config: provider %s unknown", cfg.Options.TracingProvider)
|
||||
}
|
||||
|
||||
return &tracingOpts, nil
|
||||
|
@ -88,7 +88,7 @@ func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
|
|||
mgr.mu.Lock()
|
||||
defer mgr.mu.Unlock()
|
||||
|
||||
traceOpts, err := NewTracingOptions(cfg.Options)
|
||||
traceOpts, err := NewTracingOptions(cfg)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
|
||||
return
|
||||
|
|
|
@ -68,7 +68,7 @@ func Test_NewTracingOptions(t *testing.T) {
|
|||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := NewTracingOptions(tt.opts)
|
||||
got, err := NewTracingOptions(&Config{Options: tt.opts})
|
||||
assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value")
|
||||
assert.Empty(t, cmp.Diff(tt.want, got))
|
||||
})
|
||||
|
|
|
@ -28,7 +28,7 @@ func TestNew(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.opts.Provider = "google"
|
||||
_, err := New(&config.Config{Options: &tt.opts}, events.New())
|
||||
_, err := New(config.New(&tt.opts), events.New())
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
|
|
@ -236,8 +236,8 @@ func TestConfig(t *testing.T) {
|
|||
}
|
||||
_ = p1.Validate()
|
||||
|
||||
mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{
|
||||
Options: &config.Options{
|
||||
mgr, err := newManager(ctx,
|
||||
config.NewStaticSource(config.New(&config.Options{
|
||||
AutocertOptions: config.AutocertOptions{
|
||||
Enable: true,
|
||||
UseStaging: true,
|
||||
|
@ -247,8 +247,8 @@ func TestConfig(t *testing.T) {
|
|||
},
|
||||
HTTPRedirectAddr: addr,
|
||||
Policies: []config.Policy{p1},
|
||||
},
|
||||
}), certmagic.ACMEIssuer{
|
||||
})),
|
||||
certmagic.ACMEIssuer{
|
||||
CA: srv.URL + "/acme/directory",
|
||||
TestCA: srv.URL + "/acme/directory",
|
||||
}, time.Millisecond*100)
|
||||
|
@ -304,16 +304,14 @@ func TestRedirect(t *testing.T) {
|
|||
addr := li.Addr().String()
|
||||
_ = li.Close()
|
||||
|
||||
src := config.NewStaticSource(&config.Config{
|
||||
Options: &config.Options{
|
||||
src := config.NewStaticSource(config.New(&config.Options{
|
||||
HTTPRedirectAddr: addr,
|
||||
SetResponseHeaders: map[string]string{
|
||||
"X-Frame-Options": "SAMEORIGIN",
|
||||
"X-XSS-Protection": "1; mode=block",
|
||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||
},
|
||||
},
|
||||
})
|
||||
}))
|
||||
_, err = New(src)
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
|
|
|
@ -28,7 +28,7 @@ import (
|
|||
type Handler struct {
|
||||
mu sync.RWMutex
|
||||
key []byte
|
||||
options *config.Options
|
||||
cfg *config.Config
|
||||
policies map[uint64]config.Policy
|
||||
}
|
||||
|
||||
|
@ -83,7 +83,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
|
|||
}
|
||||
|
||||
h.mu.RLock()
|
||||
options := h.options
|
||||
options := h.cfg.Options
|
||||
policy, ok := h.policies[policyID]
|
||||
h.mu.RUnlock()
|
||||
|
||||
|
@ -132,7 +132,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
|
|||
defer h.mu.Unlock()
|
||||
|
||||
h.key, _ = cfg.Options.GetSharedKey()
|
||||
h.options = cfg.Options
|
||||
h.cfg = cfg
|
||||
h.policies = make(map[uint64]config.Policy)
|
||||
for _, p := range cfg.Options.GetAllPolicies() {
|
||||
id, err := p.RouteID()
|
||||
|
|
|
@ -49,15 +49,13 @@ func TestMiddleware(t *testing.T) {
|
|||
srv2 := httptest.NewServer(h.Middleware(next))
|
||||
defer srv2.Close()
|
||||
|
||||
cfg := &config.Config{
|
||||
Options: &config.Options{
|
||||
cfg := config.New(&config.Options{
|
||||
SharedKey: cryptutil.NewBase64Key(),
|
||||
Policies: []config.Policy{{
|
||||
To: config.WeightedURLs{{URL: *u}},
|
||||
KubernetesServiceAccountToken: "ABCD",
|
||||
}},
|
||||
},
|
||||
}
|
||||
})
|
||||
h.Update(context.Background(), cfg)
|
||||
|
||||
policyID, _ := cfg.Options.Policies[0].RouteID()
|
||||
|
|
|
@ -80,11 +80,11 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
p, err := New(config.New(tt.options))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(ctx, &config.Config{Options: tt.options})
|
||||
p.OnConfigChange(ctx, config.New(tt.options))
|
||||
state := p.state.Load()
|
||||
state.sessionStore = tt.sessionStore
|
||||
signer, err := jws.NewHS256Signer(nil)
|
||||
|
|
|
@ -58,7 +58,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
|
|||
state := p.state.Load()
|
||||
|
||||
var redirectURL *url.URL
|
||||
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
|
||||
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
|
||||
if err != nil {
|
||||
return httputil.NewError(http.StatusInternalServerError, err)
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxy, err := New(&config.Config{Options: opts})
|
||||
proxy, err := New(config.New(opts))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func TestProxy_userInfo(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
proxy, err := New(&config.Config{Options: opts})
|
||||
proxy, err := New(config.New(opts))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -102,7 +102,7 @@ func TestProxy_SignOut(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
opts := testOptions(t)
|
||||
p, err := New(&config.Config{Options: opts})
|
||||
p, err := New(config.New(opts))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -236,11 +236,11 @@ func TestProxy_Callback(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
p, err := New(config.New(tt.options))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
p.OnConfigChange(context.Background(), config.New(tt.options))
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
|
@ -350,7 +350,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
p, err := New(config.New(tt.options))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -482,11 +482,11 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.options})
|
||||
p, err := New(config.New(tt.options))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
||||
p.OnConfigChange(context.Background(), config.New(tt.options))
|
||||
state := p.state.Load()
|
||||
state.encoder = tt.cipher
|
||||
state.sessionStore = tt.sessionStore
|
||||
|
|
|
@ -52,7 +52,7 @@ func ValidateOptions(o *config.Options) error {
|
|||
// Proxy stores all the information associated with proxying a request.
|
||||
type Proxy struct {
|
||||
state *atomicutil.Value[*proxyState]
|
||||
currentOptions *atomicutil.Value[*config.Options]
|
||||
currentConfig *atomicutil.Value[*config.Config]
|
||||
currentRouter *atomicutil.Value[*mux.Router]
|
||||
}
|
||||
|
||||
|
@ -66,12 +66,12 @@ func New(cfg *config.Config) (*Proxy, error) {
|
|||
|
||||
p := &Proxy{
|
||||
state: atomicutil.NewValue(state),
|
||||
currentOptions: config.NewAtomicOptions(),
|
||||
currentConfig: atomicutil.NewValue(cfg),
|
||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||
}
|
||||
|
||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||
return int64(len(p.currentOptions.Load().GetAllPolicies()))
|
||||
return int64(len(p.currentConfig.Load().Options.GetAllPolicies()))
|
||||
})
|
||||
|
||||
return p, nil
|
||||
|
@ -88,8 +88,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
return
|
||||
}
|
||||
|
||||
p.currentOptions.Store(cfg.Options)
|
||||
if err := p.setHandlers(cfg.Options); err != nil {
|
||||
p.currentConfig.Store(cfg)
|
||||
if err := p.setHandlers(cfg); err != nil {
|
||||
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||
}
|
||||
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
||||
|
@ -99,8 +99,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Proxy) setHandlers(opts *config.Options) error {
|
||||
if len(opts.GetAllPolicies()) == 0 {
|
||||
func (p *Proxy) setHandlers(cfg *config.Config) error {
|
||||
if len(cfg.Options.GetAllPolicies()) == 0 {
|
||||
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
|
||||
}
|
||||
r := httputil.NewRouter()
|
||||
|
@ -113,7 +113,7 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
|
|||
// dashboard handlers are registered to all routes
|
||||
r = p.registerDashboardHandlers(r)
|
||||
|
||||
forwardAuthURL, err := opts.GetForwardAuthURL()
|
||||
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -99,7 +99,7 @@ func TestNew(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := New(&config.Config{Options: tt.opts})
|
||||
got, err := New(config.New(tt.opts))
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
|
@ -194,12 +194,12 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
p, err := New(&config.Config{Options: tt.originalOptions})
|
||||
p, err := New(config.New(tt.originalOptions))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
|
||||
p.OnConfigChange(context.Background(), config.New(tt.updatedOptions))
|
||||
r := httptest.NewRequest("GET", tt.host, nil)
|
||||
w := httptest.NewRecorder()
|
||||
p.ServeHTTP(w, r)
|
||||
|
@ -212,5 +212,5 @@ func Test_UpdateOptions(t *testing.T) {
|
|||
|
||||
// Test nil
|
||||
var p *Proxy
|
||||
p.OnConfigChange(context.Background(), &config.Config{})
|
||||
p.OnConfigChange(context.Background(), config.New(nil))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue