mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-23 21:18:09 +02:00
config: use config.Config instead of config.Options everywhere
This commit is contained in:
parent
5f51510e91
commit
1b80e8a6c2
40 changed files with 484 additions and 412 deletions
|
@ -39,18 +39,18 @@ func ValidateOptions(o *config.Options) error {
|
||||||
|
|
||||||
// Authenticate contains data required to run the authenticate service.
|
// Authenticate contains data required to run the authenticate service.
|
||||||
type Authenticate struct {
|
type Authenticate struct {
|
||||||
cfg *authenticateConfig
|
cfg *authenticateConfig
|
||||||
options *atomicutil.Value[*config.Options]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
state *atomicutil.Value[*authenticateState]
|
state *atomicutil.Value[*authenticateState]
|
||||||
webauthn *webauthn.Handler
|
webauthn *webauthn.Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
// New validates and creates a new authenticate service from a set of Options.
|
// New validates and creates a new authenticate service from a set of Options.
|
||||||
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
func New(cfg *config.Config, options ...Option) (*Authenticate, error) {
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(options...),
|
cfg: getAuthenticateConfig(options...),
|
||||||
options: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(cfg),
|
||||||
state: atomicutil.NewValue(newAuthenticateState()),
|
state: atomicutil.NewValue(newAuthenticateState()),
|
||||||
}
|
}
|
||||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||||
|
|
||||||
|
@ -69,7 +69,7 @@ func (a *Authenticate) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
a.options.Store(cfg.Options)
|
a.currentConfig.Store(cfg)
|
||||||
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
if state, err := newAuthenticateStateFromConfig(cfg); err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
log.Error(ctx).Err(err).Msg("authenticate: failed to update state")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -113,7 +113,7 @@ func TestNew(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
_, err := New(&config.Config{Options: tt.opts})
|
_, err := New(config.New(tt.opts))
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
type authenticateConfig struct {
|
type authenticateConfig struct {
|
||||||
getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)
|
getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
// An Option customizes the Authenticate config.
|
// An Option customizes the Authenticate config.
|
||||||
|
@ -22,7 +22,7 @@ func getAuthenticateConfig(options ...Option) *authenticateConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithGetIdentityProvider sets the getIdentityProvider function in the config.
|
// WithGetIdentityProvider sets the getIdentityProvider function in the config.
|
||||||
func WithGetIdentityProvider(getIdentityProvider func(options *config.Options, idpID string) (identity.Authenticator, error)) Option {
|
func WithGetIdentityProvider(getIdentityProvider func(cfg *config.Config, idpID string) (identity.Authenticator, error)) Option {
|
||||||
return func(cfg *authenticateConfig) {
|
return func(cfg *authenticateConfig) {
|
||||||
cfg.getIdentityProvider = getIdentityProvider
|
cfg.getIdentityProvider = getIdentityProvider
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,12 +47,12 @@ func (a *Authenticate) Mount(r *mux.Router) {
|
||||||
r.StrictSlash(true)
|
r.StrictSlash(true)
|
||||||
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
r.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy))
|
||||||
r.Use(func(h http.Handler) http.Handler {
|
r.Use(func(h http.Handler) http.Handler {
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
csrfKey := fmt.Sprintf("%s_csrf", options.CookieName)
|
csrfKey := fmt.Sprintf("%s_csrf", cfg.Options.CookieName)
|
||||||
return csrf.Protect(
|
return csrf.Protect(
|
||||||
state.cookieSecret,
|
state.cookieSecret,
|
||||||
csrf.Secure(options.CookieSecure),
|
csrf.Secure(cfg.Options.CookieSecure),
|
||||||
csrf.Path("/"),
|
csrf.Path("/"),
|
||||||
csrf.UnsafePaths(
|
csrf.UnsafePaths(
|
||||||
[]string{
|
[]string{
|
||||||
|
@ -256,7 +256,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
// check for an HMAC'd URL. If none is found, show a confirmation page.
|
||||||
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
err := middleware.ValidateRequestURL(a.getExternalRequest(r), a.state.Load().sharedKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -275,9 +275,9 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.SignOut")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -285,7 +285,7 @@ func (a *Authenticate) signOutRedirect(w http.ResponseWriter, r *http.Request) e
|
||||||
rawIDToken := a.revokeSession(ctx, w, r)
|
rawIDToken := a.revokeSession(ctx, w, r)
|
||||||
|
|
||||||
redirectString := ""
|
redirectString := ""
|
||||||
signOutURL, err := a.options.Load().GetSignOutRedirectURL()
|
signOutURL, err := cfg.Options.GetSignOutRedirectURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -330,10 +330,10 @@ func (a *Authenticate) reauthenticateOrFail(w http.ResponseWriter, r *http.Reque
|
||||||
return httputil.NewError(http.StatusUnauthorized, err)
|
return httputil.NewError(http.StatusUnauthorized, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -381,7 +381,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
ctx, span := trace.StartSpan(r.Context(), "authenticate.getOAuthCallback")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
// Error Authentication Response: rfc6749#section-4.1.2.1 & OIDC#3.1.2.6
|
||||||
|
@ -430,7 +430,7 @@ func (a *Authenticate) getOAuthCallback(w http.ResponseWriter, r *http.Request)
|
||||||
}
|
}
|
||||||
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
idpID := redirectURL.Query().Get(urlutil.QueryIdentityProviderID)
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, idpID)
|
idp, err := a.cfg.getIdentityProvider(cfg, idpID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -513,7 +513,7 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error {
|
||||||
func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
|
func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData, error) {
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
authenticateURL, err := a.currentConfig.Load().Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return handlers.UserInfoData{}, err
|
return handlers.UserInfoData{}, err
|
||||||
}
|
}
|
||||||
|
@ -569,7 +569,7 @@ func (a *Authenticate) getUserInfoData(r *http.Request) (handlers.UserInfoData,
|
||||||
WebAuthnRequestOptions: requestOptions,
|
WebAuthnRequestOptions: requestOptions,
|
||||||
WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()),
|
WebAuthnURL: urlutil.WebAuthnURL(r, authenticateURL, state.sharedKey, r.URL.Query()),
|
||||||
|
|
||||||
BrandingOptions: a.options.Load().BrandingOptions,
|
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -581,14 +581,14 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
accessToken *oauth2.Token,
|
accessToken *oauth2.Token,
|
||||||
) error {
|
) error {
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
sessionExpiry := timestamppb.New(time.Now().Add(options.CookieExpire))
|
sessionExpiry := timestamppb.New(time.Now().Add(cfg.Options.CookieExpire))
|
||||||
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
|
idTokenIssuedAt := timestamppb.New(sessionState.IssuedAt.Time())
|
||||||
|
|
||||||
s := &session.Session{
|
s := &session.Session{
|
||||||
|
@ -648,13 +648,13 @@ func (a *Authenticate) saveSessionToDataBroker(
|
||||||
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
// databroker. If successful, it returns the original `id_token` of the session, if failed, returns
|
||||||
// and empty string.
|
// and empty string.
|
||||||
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, r *http.Request) string {
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
|
||||||
// clear the user's local session no matter what
|
// clear the user's local session no matter what
|
||||||
defer state.sessionStore.ClearSession(w, r)
|
defer state.sessionStore.ClearSession(w, r)
|
||||||
|
|
||||||
idp, err := a.cfg.getIdentityProvider(options, r.FormValue(urlutil.QueryIdentityProviderID))
|
idp, err := a.cfg.getIdentityProvider(cfg, r.FormValue(urlutil.QueryIdentityProviderID))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
@ -708,6 +708,7 @@ func (a *Authenticate) getDirectoryUser(ctx context.Context, userID string) (*di
|
||||||
|
|
||||||
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
|
func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, error) {
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
s, _, err := a.getCurrentSession(ctx)
|
s, _, err := a.getCurrentSession(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -719,17 +720,17 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticateURL, err := a.options.Load().GetAuthenticateURL()
|
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
internalAuthenticateURL, err := a.options.Load().GetInternalAuthenticateURL()
|
internalAuthenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pomeriumDomains, err := a.options.Load().GetAllRouteableHTTPDomains()
|
pomeriumDomains, err := cfg.Options.GetAllRouteableHTTPDomains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -744,7 +745,7 @@ func (a *Authenticate) getWebauthnState(ctx context.Context) (*webauthn.State, e
|
||||||
SessionState: ss,
|
SessionState: ss,
|
||||||
SessionStore: state.sessionStore,
|
SessionStore: state.sessionStore,
|
||||||
RelyingParty: state.webauthnRelyingParty,
|
RelyingParty: state.webauthnRelyingParty,
|
||||||
BrandingOptions: a.options.Load().BrandingOptions,
|
BrandingOptions: cfg.Options.BrandingOptions,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -49,10 +49,9 @@ func testAuthenticate() *Authenticate {
|
||||||
redirectURL: redirectURL,
|
redirectURL: redirectURL,
|
||||||
cookieSecret: cryptutil.NewKey(),
|
cookieSecret: cryptutil.NewKey(),
|
||||||
})
|
})
|
||||||
auth.options = config.NewAtomicOptions()
|
auth.currentConfig = atomicutil.NewValue(config.New(&config.Options{
|
||||||
auth.options.Store(&config.Options{
|
|
||||||
SharedKey: cryptutil.NewBase64Key(),
|
SharedKey: cryptutil.NewBase64Key(),
|
||||||
})
|
}))
|
||||||
return &auth
|
return &auth
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -148,7 +147,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
sharedCipher, _ := cryptutil.NewAEADCipherFromBase64(cryptutil.NewBase64Key())
|
||||||
|
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||||
return tt.provider, nil
|
return tt.provider, nil
|
||||||
})),
|
})),
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
|
@ -168,10 +167,10 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
},
|
},
|
||||||
directoryClient: new(mockDirectoryServiceClient),
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
|
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||||
options: config.NewAtomicOptions(),
|
SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey()),
|
||||||
|
})),
|
||||||
}
|
}
|
||||||
a.options.Store(&config.Options{SharedKey: base64.StdEncoding.EncodeToString(cryptutil.NewKey())})
|
|
||||||
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
uri := &url.URL{Scheme: tt.scheme, Host: tt.host}
|
||||||
|
|
||||||
queryString := uri.Query()
|
queryString := uri.Query()
|
||||||
|
@ -304,7 +303,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
defer ctrl.Finish()
|
defer ctrl.Finish()
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||||
return tt.provider, nil
|
return tt.provider, nil
|
||||||
})),
|
})),
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
|
@ -325,12 +324,9 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
},
|
},
|
||||||
directoryClient: new(mockDirectoryServiceClient),
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||||
}
|
SignOutRedirectURLString: tt.signoutRedirectURL,
|
||||||
if tt.signoutRedirectURL != "" {
|
})),
|
||||||
opts := a.options.Load()
|
|
||||||
opts.SignOutRedirectURLString = tt.signoutRedirectURL
|
|
||||||
a.options.Store(opts)
|
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/sign_out")
|
u, _ := url.Parse("/sign_out")
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
@ -417,7 +413,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
authURL, _ := url.Parse(tt.authenticateURL)
|
authURL, _ := url.Parse(tt.authenticateURL)
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||||
return tt.provider, nil
|
return tt.provider, nil
|
||||||
})),
|
})),
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
|
@ -435,7 +431,7 @@ func TestAuthenticate_OAuthCallback(t *testing.T) {
|
||||||
cookieCipher: aead,
|
cookieCipher: aead,
|
||||||
encryptedEncoder: signer,
|
encryptedEncoder: signer,
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(config.New(nil)),
|
||||||
}
|
}
|
||||||
u, _ := url.Parse("/oauthGet")
|
u, _ := url.Parse("/oauthGet")
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
|
@ -552,7 +548,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(options *config.Options, idpID string) (identity.Authenticator, error) {
|
cfg: getAuthenticateConfig(WithGetIdentityProvider(func(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||||
return tt.provider, nil
|
return tt.provider, nil
|
||||||
})),
|
})),
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
|
@ -573,7 +569,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
},
|
},
|
||||||
directoryClient: new(mockDirectoryServiceClient),
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
options: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(config.New(nil)),
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest("GET", "/", nil)
|
r := httptest.NewRequest("GET", "/", nil)
|
||||||
state, err := tt.session.LoadSession(r)
|
state, err := tt.session.LoadSession(r)
|
||||||
|
@ -604,7 +600,7 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
func TestJwksEndpoint(t *testing.T) {
|
func TestJwksEndpoint(t *testing.T) {
|
||||||
o := newTestOptions(t)
|
o := newTestOptions(t)
|
||||||
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
o.SigningKey = "LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IY0NBUUVFSUpCMFZkbko1VjEvbVlpYUlIWHhnd2Q0Yzd5YWRTeXMxb3Y0bzA1b0F3ekdvQW9HQ0NxR1NNNDkKQXdFSG9VUURRZ0FFVUc1eENQMEpUVDFINklvbDhqS3VUSVBWTE0wNENnVzlQbEV5cE5SbVdsb29LRVhSOUhUMwpPYnp6aktZaWN6YjArMUt3VjJmTVRFMTh1dy82MXJVQ0JBPT0KLS0tLS1FTkQgRUMgUFJJVkFURSBLRVktLS0tLQo="
|
||||||
auth, err := New(&config.Config{Options: o})
|
auth, err := New(config.New(o))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
return
|
return
|
||||||
|
@ -632,12 +628,11 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
a.state = atomicutil.NewValue(&authenticateState{
|
a.state = atomicutil.NewValue(&authenticateState{
|
||||||
cookieSecret: cryptutil.NewKey(),
|
cookieSecret: cryptutil.NewKey(),
|
||||||
})
|
})
|
||||||
a.options = config.NewAtomicOptions()
|
a.currentConfig = atomicutil.NewValue(config.New(&config.Options{
|
||||||
a.options.Store(&config.Options{
|
|
||||||
SharedKey: cryptutil.NewBase64Key(),
|
SharedKey: cryptutil.NewBase64Key(),
|
||||||
AuthenticateURLString: "https://authenticate.example.com",
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
AuthenticateInternalURLString: "https://authenticate.service.cluster.local",
|
AuthenticateInternalURLString: "https://authenticate.service.cluster.local",
|
||||||
})
|
}))
|
||||||
err := a.userInfo(w, r)
|
err := a.userInfo(w, r)
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
assert.Equal(t, http.StatusFound, w.Code)
|
assert.Equal(t, http.StatusFound, w.Code)
|
||||||
|
@ -687,13 +682,7 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
o := config.NewAtomicOptions()
|
|
||||||
o.Store(&config.Options{
|
|
||||||
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
|
||||||
SharedKey: "SHARED KEY",
|
|
||||||
})
|
|
||||||
a := &Authenticate{
|
a := &Authenticate{
|
||||||
options: o,
|
|
||||||
state: atomicutil.NewValue(&authenticateState{
|
state: atomicutil.NewValue(&authenticateState{
|
||||||
sessionStore: tt.sessionStore,
|
sessionStore: tt.sessionStore,
|
||||||
encryptedEncoder: signer,
|
encryptedEncoder: signer,
|
||||||
|
@ -711,6 +700,10 @@ func TestAuthenticate_userInfo(t *testing.T) {
|
||||||
},
|
},
|
||||||
directoryClient: new(mockDirectoryServiceClient),
|
directoryClient: new(mockDirectoryServiceClient),
|
||||||
}),
|
}),
|
||||||
|
currentConfig: atomicutil.NewValue(config.New(&config.Options{
|
||||||
|
AuthenticateURLString: "https://authenticate.localhost.pomerium.io",
|
||||||
|
SharedKey: "SHARED KEY",
|
||||||
|
})),
|
||||||
}
|
}
|
||||||
a.webauthn = webauthn.New(a.getWebauthnState)
|
a.webauthn = webauthn.New(a.getWebauthnState)
|
||||||
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
|
r := httptest.NewRequest(tt.method, tt.url.String(), nil)
|
||||||
|
|
|
@ -7,8 +7,8 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func defaultGetIdentityProvider(options *config.Options, idpID string) (identity.Authenticator, error) {
|
func defaultGetIdentityProvider(cfg *config.Config, idpID string) (identity.Authenticator, error) {
|
||||||
authenticateURL, err := options.GetAuthenticateURL()
|
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -17,9 +17,9 @@ func defaultGetIdentityProvider(options *config.Options, idpID string) (identity
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
redirectURL.Path = options.AuthenticateCallbackPath
|
redirectURL.Path = cfg.Options.AuthenticateCallbackPath
|
||||||
|
|
||||||
idp, err := options.GetIdentityProviderForID(idpID)
|
idp, err := cfg.Options.GetIdentityProviderForID(idpID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -34,14 +34,14 @@ func (a *Authenticate) requireValidSignature(next httputil.HandlerFunc) http.Han
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
|
func (a *Authenticate) getExternalRequest(r *http.Request) *http.Request {
|
||||||
options := a.options.Load()
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
externalURL, err := options.GetAuthenticateURL()
|
externalURL, err := cfg.Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
internalURL, err := options.GetInternalAuthenticateURL()
|
internalURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
|
@ -25,11 +25,11 @@ import (
|
||||||
|
|
||||||
// Authorize struct holds
|
// Authorize struct holds
|
||||||
type Authorize struct {
|
type Authorize struct {
|
||||||
state *atomicutil.Value[*authorizeState]
|
state *atomicutil.Value[*authorizeState]
|
||||||
store *store.Store
|
store *store.Store
|
||||||
currentOptions *atomicutil.Value[*config.Options]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
accessTracker *AccessTracker
|
accessTracker *AccessTracker
|
||||||
globalCache storage.Cache
|
globalCache storage.Cache
|
||||||
|
|
||||||
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
// The stateLock prevents updating the evaluator store simultaneously with an evaluation.
|
||||||
// This should provide a consistent view of the data at a given server/record version and
|
// This should provide a consistent view of the data at a given server/record version and
|
||||||
|
@ -40,9 +40,9 @@ type Authorize struct {
|
||||||
// New validates and creates a new Authorize service from a set of config options.
|
// New validates and creates a new Authorize service from a set of config options.
|
||||||
func New(cfg *config.Config) (*Authorize, error) {
|
func New(cfg *config.Config) (*Authorize, error) {
|
||||||
a := &Authorize{
|
a := &Authorize{
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(cfg),
|
||||||
store: store.New(),
|
store: store.New(),
|
||||||
globalCache: storage.NewGlobalCache(time.Minute),
|
globalCache: storage.NewGlobalCache(time.Minute),
|
||||||
}
|
}
|
||||||
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
a.accessTracker = NewAccessTracker(a, accessTrackerMaxSize, accessTrackerDebouncePeriod)
|
||||||
|
|
||||||
|
@ -86,42 +86,42 @@ func validateOptions(o *config.Options) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// newPolicyEvaluator returns an policy evaluator.
|
// newPolicyEvaluator returns an policy evaluator.
|
||||||
func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Evaluator, error) {
|
func newPolicyEvaluator(cfg *config.Config, store *store.Store) (*evaluator.Evaluator, error) {
|
||||||
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-authorize", func() int64 {
|
||||||
return int64(len(opts.GetAllPolicies()))
|
return int64(len(cfg.Options.GetAllPolicies()))
|
||||||
})
|
})
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
_, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
|
_, span := trace.StartSpan(ctx, "authorize.newPolicyEvaluator")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
|
|
||||||
clientCA, err := opts.GetClientCA()
|
clientCA, err := cfg.Options.GetClientCA()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
|
return nil, fmt.Errorf("authorize: invalid client CA: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
authenticateURL, err := opts.GetInternalAuthenticateURL()
|
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err)
|
return nil, fmt.Errorf("authorize: invalid authenticate url: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
signingKey, err := opts.GetSigningKey()
|
signingKey, err := cfg.Options.GetSigningKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: invalid signing key: %w", err)
|
return nil, fmt.Errorf("authorize: invalid signing key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return evaluator.New(ctx, store,
|
return evaluator.New(ctx, store,
|
||||||
evaluator.WithPolicies(opts.GetAllPolicies()),
|
evaluator.WithPolicies(cfg.Options.GetAllPolicies()),
|
||||||
evaluator.WithClientCA(clientCA),
|
evaluator.WithClientCA(clientCA),
|
||||||
evaluator.WithSigningKey(signingKey),
|
evaluator.WithSigningKey(signingKey),
|
||||||
evaluator.WithAuthenticateURL(authenticateURL.String()),
|
evaluator.WithAuthenticateURL(authenticateURL.String()),
|
||||||
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()),
|
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(cfg.Options.GetGoogleCloudServerlessAuthenticationServiceAccount()),
|
||||||
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
|
evaluator.WithJWTClaimsHeaders(cfg.Options.JWTClaimsHeaders),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnConfigChange updates internal structures based on config.Options
|
// OnConfigChange updates internal structures based on config.Options
|
||||||
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
func (a *Authorize) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
a.currentOptions.Store(cfg.Options)
|
a.currentConfig.Store(cfg)
|
||||||
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
|
if state, err := newAuthorizeStateFromConfig(cfg, a.store); err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("authorize: error updating state")
|
log.Error(ctx).Err(err).Msg("authorize: error updating state")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -74,7 +74,7 @@ func TestNew(t *testing.T) {
|
||||||
tt := tt
|
tt := tt
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
_, err := New(&config.Config{Options: &tt.config})
|
_, err := New(config.New(&tt.config))
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -105,12 +105,12 @@ func TestAuthorize_OnConfigChange(t *testing.T) {
|
||||||
SharedKey: tc.SharedKey,
|
SharedKey: tc.SharedKey,
|
||||||
Policies: tc.Policies,
|
Policies: tc.Policies,
|
||||||
}
|
}
|
||||||
a, err := New(&config.Config{Options: o})
|
a, err := New(config.New(o))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, a)
|
require.NotNil(t, a)
|
||||||
|
|
||||||
oldPe := a.state.Load().evaluator
|
oldPe := a.state.Load().evaluator
|
||||||
cfg := &config.Config{Options: o}
|
cfg := config.New(o)
|
||||||
assertFunc := assert.True
|
assertFunc := assert.True
|
||||||
o.SigningKey = "bad-share-key"
|
o.SigningKey = "bad-share-key"
|
||||||
if tc.expectedChange {
|
if tc.expectedChange {
|
||||||
|
|
|
@ -125,7 +125,7 @@ func (a *Authorize) deniedResponse(
|
||||||
respBody := []byte(reason)
|
respBody := []byte(reason)
|
||||||
respHeader := []*envoy_config_core_v3.HeaderValueOption{}
|
respHeader := []*envoy_config_core_v3.HeaderValueOption{}
|
||||||
|
|
||||||
forwardAuthURL, _ := a.currentOptions.Load().GetForwardAuthURL()
|
forwardAuthURL, _ := a.currentConfig.Load().Options.GetForwardAuthURL()
|
||||||
if forwardAuthURL == nil {
|
if forwardAuthURL == nil {
|
||||||
// create a http response writer recorder
|
// create a http response writer recorder
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
@ -140,7 +140,7 @@ func (a *Authorize) deniedResponse(
|
||||||
Err: errors.New(reason),
|
Err: errors.New(reason),
|
||||||
DebugURL: debugEndpoint,
|
DebugURL: debugEndpoint,
|
||||||
RequestID: requestid.FromContext(ctx),
|
RequestID: requestid.FromContext(ctx),
|
||||||
BrandingOptions: a.currentOptions.Load().BrandingOptions,
|
BrandingOptions: a.currentConfig.Load().Options.BrandingOptions,
|
||||||
}
|
}
|
||||||
httpErr.ErrorResponse(ctx, w, r)
|
httpErr.ErrorResponse(ctx, w, r)
|
||||||
|
|
||||||
|
@ -184,7 +184,7 @@ func (a *Authorize) requireLoginResponse(
|
||||||
request *evaluator.Request,
|
request *evaluator.Request,
|
||||||
isForwardAuthVerify bool,
|
isForwardAuthVerify bool,
|
||||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||||
opts := a.currentOptions.Load()
|
opts := a.currentConfig.Load().Options
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
authenticateURL, err := opts.GetAuthenticateURL()
|
authenticateURL, err := opts.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -225,7 +225,7 @@ func (a *Authorize) requireWebAuthnResponse(
|
||||||
result *evaluator.Result,
|
result *evaluator.Result,
|
||||||
isForwardAuthVerify bool,
|
isForwardAuthVerify bool,
|
||||||
) (*envoy_service_auth_v3.CheckResponse, error) {
|
) (*envoy_service_auth_v3.CheckResponse, error) {
|
||||||
opts := a.currentOptions.Load()
|
opts := a.currentConfig.Load().Options
|
||||||
state := a.state.Load()
|
state := a.state.Load()
|
||||||
authenticateURL, err := opts.GetAuthenticateURL()
|
authenticateURL, err := opts.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -295,7 +295,7 @@ func toEnvoyHeaders(headers http.Header) []*envoy_config_core_v3.HeaderValueOpti
|
||||||
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
// userInfoEndpointURL returns the user info endpoint url which can be used to debug the user's
|
||||||
// session that lives on the authenticate service.
|
// session that lives on the authenticate service.
|
||||||
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest) (*url.URL, error) {
|
||||||
opts := a.currentOptions.Load()
|
opts := a.currentConfig.Load().Options
|
||||||
authenticateURL, err := opts.GetAuthenticateURL()
|
authenticateURL, err := opts.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -29,7 +29,7 @@ func TestAuthorize_handleResult(t *testing.T) {
|
||||||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(config.New(opt))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("user-unauthenticated", func(t *testing.T) {
|
t.Run("user-unauthenticated", func(t *testing.T) {
|
||||||
|
@ -67,12 +67,12 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
}},
|
}},
|
||||||
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
JWTClaimsHeaders: config.NewJWTClaimHeaders("email"),
|
||||||
}
|
}
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
cfg := config.New(opt)
|
||||||
|
a := &Authorize{currentConfig: atomicutil.NewValue(cfg), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(opt)
|
|
||||||
a.store = store.New()
|
a.store = store.New()
|
||||||
pe, err := newPolicyEvaluator(opt, a.store)
|
pe, err := newPolicyEvaluator(cfg, a.store)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
a.state.Load().evaluator = pe
|
a.state.Load().evaluator = pe
|
||||||
|
|
||||||
|
@ -123,17 +123,17 @@ func TestAuthorize_okResponse(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAuthorize_deniedResponse(t *testing.T) {
|
func TestAuthorize_deniedResponse(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(&config.Options{
|
a.currentConfig.Store(config.New(&config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
}))
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) {
|
||||||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(config.New(opt))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
t.Run("accept empty", func(t *testing.T) {
|
t.Run("accept empty", func(t *testing.T) {
|
||||||
|
|
|
@ -55,7 +55,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
rawJWT, _ := loadRawSession(hreq, a.currentOptions.Load(), state.encoder)
|
rawJWT, _ := loadRawSession(hreq, a.currentConfig.Load(), state.encoder)
|
||||||
sessionState, _ := loadSession(state.encoder, rawJWT)
|
sessionState, _ := loadSession(state.encoder, rawJWT)
|
||||||
|
|
||||||
var s sessionOrServiceAccount
|
var s sessionOrServiceAccount
|
||||||
|
@ -100,9 +100,9 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
|
||||||
|
|
||||||
// isForwardAuth returns if the current request is a forward auth route.
|
// isForwardAuth returns if the current request is a forward auth route.
|
||||||
func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool {
|
func (a *Authorize) isForwardAuth(req *envoy_service_auth_v3.CheckRequest) bool {
|
||||||
opts := a.currentOptions.Load()
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
forwardAuthURL, err := opts.GetForwardAuthURL()
|
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||||
if err != nil || forwardAuthURL == nil {
|
if err != nil || forwardAuthURL == nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
@ -136,9 +136,9 @@ func (a *Authorize) getEvaluatorRequestFromCheckRequest(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
|
func (a *Authorize) getMatchingPolicy(requestURL url.URL) *config.Policy {
|
||||||
options := a.currentOptions.Load()
|
cfg := a.currentConfig.Load()
|
||||||
|
|
||||||
for _, p := range options.GetAllPolicies() {
|
for _, p := range cfg.Options.GetAllPolicies() {
|
||||||
if p.Matches(requestURL) {
|
if p.Matches(requestURL) {
|
||||||
return &p
|
return &p
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,17 +47,17 @@ yE+vPxsiUkvQHdO2fojCkY8jg70jxM+gu59tPDNbw3Uh/2Ij310FgTHsnGQMyA==
|
||||||
-----END CERTIFICATE-----`
|
-----END CERTIFICATE-----`
|
||||||
|
|
||||||
func Test_getEvaluatorRequest(t *testing.T) {
|
func Test_getEvaluatorRequest(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(&config.Options{
|
a.currentConfig.Store(config.New(&config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
}))
|
||||||
|
|
||||||
actual, err := a.getEvaluatorRequestFromCheckRequest(
|
actual, err := a.getEvaluatorRequestFromCheckRequest(
|
||||||
&envoy_service_auth_v3.CheckRequest{
|
&envoy_service_auth_v3.CheckRequest{
|
||||||
|
@ -87,7 +87,7 @@ func Test_getEvaluatorRequest(t *testing.T) {
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := &evaluator.Request{
|
expect := &evaluator.Request{
|
||||||
Policy: &a.currentOptions.Load().Policies[0],
|
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||||
Session: evaluator.RequestSession{
|
Session: evaluator.RequestSession{
|
||||||
ID: "SESSION_ID",
|
ID: "SESSION_ID",
|
||||||
},
|
},
|
||||||
|
@ -248,8 +248,10 @@ func Test_handleForwardAuth(t *testing.T) {
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
tc := tc
|
tc := tc
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: tc.forwardAuthURL})
|
a.currentConfig.Store(config.New(&config.Options{
|
||||||
|
ForwardAuthURLString: tc.forwardAuthURL,
|
||||||
|
}))
|
||||||
|
|
||||||
got := a.isForwardAuth(tc.checkReq)
|
got := a.isForwardAuth(tc.checkReq)
|
||||||
|
|
||||||
|
@ -261,17 +263,17 @@ func Test_handleForwardAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||||
a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))}
|
a := &Authorize{currentConfig: atomicutil.NewValue(config.New(nil)), state: atomicutil.NewValue(new(authorizeState))}
|
||||||
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
encoder, _ := jws.NewHS256Signer([]byte{0, 0, 0, 0})
|
||||||
a.state.Load().encoder = encoder
|
a.state.Load().encoder = encoder
|
||||||
a.currentOptions.Store(&config.Options{
|
a.currentConfig.Store(config.New(&config.Options{
|
||||||
Policies: []config.Policy{{
|
Policies: []config.Policy{{
|
||||||
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
Source: &config.StringURL{URL: &url.URL{Host: "example.com"}},
|
||||||
SubPolicies: []config.SubPolicy{{
|
SubPolicies: []config.SubPolicy{{
|
||||||
Rego: []string{"allow = true"},
|
Rego: []string{"allow = true"},
|
||||||
}},
|
}},
|
||||||
}},
|
}},
|
||||||
})
|
}))
|
||||||
|
|
||||||
actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{
|
actual, err := a.getEvaluatorRequestFromCheckRequest(&envoy_service_auth_v3.CheckRequest{
|
||||||
Attributes: &envoy_service_auth_v3.AttributeContext{
|
Attributes: &envoy_service_auth_v3.AttributeContext{
|
||||||
|
@ -296,7 +298,7 @@ func Test_getEvaluatorRequestWithPortInHostHeader(t *testing.T) {
|
||||||
}, nil)
|
}, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := &evaluator.Request{
|
expect := &evaluator.Request{
|
||||||
Policy: &a.currentOptions.Load().Policies[0],
|
Policy: &a.currentConfig.Load().Options.Policies[0],
|
||||||
Session: evaluator.RequestSession{},
|
Session: evaluator.RequestSession{},
|
||||||
HTTP: evaluator.NewRequestHTTP(
|
HTTP: evaluator.NewRequestHTTP(
|
||||||
"GET",
|
"GET",
|
||||||
|
@ -332,11 +334,13 @@ func TestAuthorize_Check(t *testing.T) {
|
||||||
opt.AuthenticateURLString = "https://authenticate.example.com"
|
opt.AuthenticateURLString = "https://authenticate.example.com"
|
||||||
opt.DataBrokerURLString = "https://databroker.example.com"
|
opt.DataBrokerURLString = "https://databroker.example.com"
|
||||||
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
opt.SharedKey = "E8wWIMnihUx+AUfRegAQDNs8eRb3UrB5G3zlJW9XJDM="
|
||||||
a, err := New(&config.Config{Options: opt})
|
a, err := New(config.New(opt))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
a.currentOptions.Store(&config.Options{ForwardAuthURLString: "https://forward-auth.example.com"})
|
a.currentConfig.Store(config.New(&config.Options{
|
||||||
|
ForwardAuthURLString: "https://forward-auth.example.com",
|
||||||
|
}))
|
||||||
|
|
||||||
cmpOpts := []cmp.Option{
|
cmpOpts := []cmp.Option{
|
||||||
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
|
cmpopts.IgnoreUnexported(envoy_service_auth_v3.CheckResponse{}),
|
||||||
|
|
|
@ -13,9 +13,13 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func loadRawSession(req *http.Request, options *config.Options, encoder encoding.MarshalUnmarshaler) ([]byte, error) {
|
func loadRawSession(
|
||||||
|
req *http.Request,
|
||||||
|
cfg *config.Config,
|
||||||
|
encoder encoding.MarshalUnmarshaler,
|
||||||
|
) ([]byte, error) {
|
||||||
var loaders []sessions.SessionLoader
|
var loaders []sessions.SessionLoader
|
||||||
cookieStore, err := getCookieStore(options, encoder)
|
cookieStore, err := getCookieStore(cfg, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -37,7 +41,10 @@ func loadRawSession(req *http.Request, options *config.Options, encoder encoding
|
||||||
return nil, sessions.ErrNoSessionFound
|
return nil, sessions.ErrNoSessionFound
|
||||||
}
|
}
|
||||||
|
|
||||||
func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.State, error) {
|
func loadSession(
|
||||||
|
encoder encoding.MarshalUnmarshaler,
|
||||||
|
rawJWT []byte,
|
||||||
|
) (*sessions.State, error) {
|
||||||
var s sessions.State
|
var s sessions.State
|
||||||
err := encoder.Unmarshal(rawJWT, &s)
|
err := encoder.Unmarshal(rawJWT, &s)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -46,14 +53,17 @@ func loadSession(encoder encoding.MarshalUnmarshaler, rawJWT []byte) (*sessions.
|
||||||
return &s, nil
|
return &s, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCookieStore(options *config.Options, encoder encoding.MarshalUnmarshaler) (sessions.SessionStore, error) {
|
func getCookieStore(
|
||||||
|
cfg *config.Config,
|
||||||
|
encoder encoding.MarshalUnmarshaler,
|
||||||
|
) (sessions.SessionStore, error) {
|
||||||
cookieStore, err := cookie.NewStore(func() cookie.Options {
|
cookieStore, err := cookie.NewStore(func() cookie.Options {
|
||||||
return cookie.Options{
|
return cookie.Options{
|
||||||
Name: options.CookieName,
|
Name: cfg.Options.CookieName,
|
||||||
Domain: options.CookieDomain,
|
Domain: cfg.Options.CookieDomain,
|
||||||
Secure: options.CookieSecure,
|
Secure: cfg.Options.CookieSecure,
|
||||||
HTTPOnly: options.CookieHTTPOnly,
|
HTTPOnly: cfg.Options.CookieHTTPOnly,
|
||||||
Expire: options.CookieExpire,
|
Expire: cfg.Options.CookieExpire,
|
||||||
}
|
}
|
||||||
}, encoder)
|
}, encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|
|
@ -32,7 +32,7 @@ func TestLoadSession(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
raw, err := loadRawSession(req, opts, encoder)
|
raw, err := loadRawSession(req, config.New(opts), encoder)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -36,7 +36,7 @@ func newAuthorizeStateFromConfig(cfg *config.Config, store *store.Store) (*autho
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
state.evaluator, err = newPolicyEvaluator(cfg.Options, store)
|
state.evaluator, err = newPolicyEvaluator(cfg, store)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
return nil, fmt.Errorf("authorize: failed to update policy with options: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -31,6 +31,16 @@ type Config struct {
|
||||||
MetricsScrapeEndpoints []MetricsScrapeEndpoint
|
MetricsScrapeEndpoints []MetricsScrapeEndpoint
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// New creates a new Config.
|
||||||
|
func New(options *Options) *Config {
|
||||||
|
if options == nil {
|
||||||
|
options = NewDefaultOptions()
|
||||||
|
}
|
||||||
|
return &Config{
|
||||||
|
Options: options,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Clone creates a clone of the config.
|
// Clone creates a clone of the config.
|
||||||
func (cfg *Config) Clone() *Config {
|
func (cfg *Config) Clone() *Config {
|
||||||
newOptions := new(Options)
|
newOptions := new(Options)
|
||||||
|
|
|
@ -13,11 +13,9 @@ import (
|
||||||
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
adminCfg, err := b.BuildBootstrapAdmin(&config.Config{
|
adminCfg, err := b.BuildBootstrapAdmin(config.New(&config.Options{
|
||||||
Options: &config.Options{
|
EnvoyAdminAddress: "localhost:9901",
|
||||||
EnvoyAdminAddress: "localhost:9901",
|
}))
|
||||||
},
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -31,11 +29,9 @@ func TestBuilder_BuildBootstrapAdmin(t *testing.T) {
|
||||||
`, adminCfg)
|
`, adminCfg)
|
||||||
})
|
})
|
||||||
t.Run("bad address", func(t *testing.T) {
|
t.Run("bad address", func(t *testing.T) {
|
||||||
_, err := b.BuildBootstrapAdmin(&config.Config{
|
_, err := b.BuildBootstrapAdmin(config.New(&config.Options{
|
||||||
Options: &config.Options{
|
EnvoyAdminAddress: "xyz1234:zyx4321",
|
||||||
EnvoyAdminAddress: "xyz1234:zyx4321",
|
}))
|
||||||
},
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -111,11 +107,9 @@ func TestBuilder_BuildBootstrapStaticResources(t *testing.T) {
|
||||||
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
func TestBuilder_BuildBootstrapStatsConfig(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
t.Run("valid", func(t *testing.T) {
|
t.Run("valid", func(t *testing.T) {
|
||||||
statsCfg, err := b.BuildBootstrapStatsConfig(&config.Config{
|
statsCfg, err := b.BuildBootstrapStatsConfig(config.New(&config.Options{
|
||||||
Options: &config.Options{
|
Services: "all",
|
||||||
Services: "all",
|
}))
|
||||||
},
|
|
||||||
})
|
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
|
|
@ -46,22 +46,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
|
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", []*url.URL{grpcURL}, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
|
authorizeCluster.OutlierDetection = grpcAuthorizeOutlierDetection()
|
||||||
}
|
}
|
||||||
|
|
||||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -87,7 +87,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
databrokerCluster,
|
databrokerCluster,
|
||||||
}
|
}
|
||||||
|
|
||||||
tracingCluster, err := buildTracingCluster(cfg.Options)
|
tracingCluster, err := buildTracingCluster(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else if tracingCluster != nil {
|
} else if tracingCluster != nil {
|
||||||
|
@ -101,7 +101,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
||||||
}
|
}
|
||||||
if len(policy.To) > 0 {
|
if len(policy.To) > 0 {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
|
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
@ -119,16 +119,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
|
|
||||||
func (b *Builder) buildInternalCluster(
|
func (b *Builder) buildInternalCluster(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
name string,
|
name string,
|
||||||
dsts []*url.URL,
|
dsts []*url.URL,
|
||||||
upstreamProtocol upstreamProtocolConfig,
|
upstreamProtocol upstreamProtocolConfig,
|
||||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||||
cluster := newDefaultEnvoyClusterConfig()
|
cluster := newDefaultEnvoyClusterConfig()
|
||||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||||
var endpoints []Endpoint
|
var endpoints []Endpoint
|
||||||
for _, dst := range dsts {
|
for _, dst := range dsts {
|
||||||
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
|
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -141,18 +141,22 @@ func (b *Builder) buildInternalCluster(
|
||||||
return cluster, nil
|
return cluster, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
func (b *Builder) buildPolicyCluster(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
policy *config.Policy,
|
||||||
|
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||||
cluster := new(envoy_config_cluster_v3.Cluster)
|
cluster := new(envoy_config_cluster_v3.Cluster)
|
||||||
proto.Merge(cluster, policy.EnvoyOpts)
|
proto.Merge(cluster, policy.EnvoyOpts)
|
||||||
|
|
||||||
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
|
if cfg.Options.EnvoyBindConfigFreebind.IsSet() || cfg.Options.EnvoyBindConfigSourceAddress != "" {
|
||||||
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
||||||
if options.EnvoyBindConfigFreebind.IsSet() {
|
if cfg.Options.EnvoyBindConfigFreebind.IsSet() {
|
||||||
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(options.EnvoyBindConfigFreebind.Bool)
|
cluster.UpstreamBindConfig.Freebind = wrapperspb.Bool(cfg.Options.EnvoyBindConfigFreebind.Bool)
|
||||||
}
|
}
|
||||||
if options.EnvoyBindConfigSourceAddress != "" {
|
if cfg.Options.EnvoyBindConfigSourceAddress != "" {
|
||||||
cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{
|
cluster.UpstreamBindConfig.SourceAddress = &envoy_config_core_v3.SocketAddress{
|
||||||
Address: options.EnvoyBindConfigSourceAddress,
|
Address: cfg.Options.EnvoyBindConfigSourceAddress,
|
||||||
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
|
PortSpecifier: &envoy_config_core_v3.SocketAddress_PortValue{
|
||||||
PortValue: 0,
|
PortValue: 0,
|
||||||
},
|
},
|
||||||
|
@ -171,13 +175,13 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
||||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||||
|
|
||||||
name := getClusterID(policy)
|
name := getClusterID(policy)
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
|
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO {
|
if cluster.DnsLookupFamily == envoy_config_cluster_v3.Cluster_AUTO {
|
||||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||||
}
|
}
|
||||||
|
|
||||||
if policy.EnableGoogleCloudServerlessAuthentication {
|
if policy.EnableGoogleCloudServerlessAuthentication {
|
||||||
|
@ -193,12 +197,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
||||||
|
|
||||||
func (b *Builder) buildPolicyEndpoints(
|
func (b *Builder) buildPolicyEndpoints(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
) ([]Endpoint, error) {
|
) ([]Endpoint, error) {
|
||||||
var endpoints []Endpoint
|
var endpoints []Endpoint
|
||||||
for _, dst := range policy.To {
|
for _, dst := range policy.To {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
|
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -209,7 +213,7 @@ func (b *Builder) buildPolicyEndpoints(
|
||||||
|
|
||||||
func (b *Builder) buildInternalTransportSocket(
|
func (b *Builder) buildInternalTransportSocket(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
endpoint *url.URL,
|
endpoint *url.URL,
|
||||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||||
if endpoint.Scheme != "https" {
|
if endpoint.Scheme != "https" {
|
||||||
|
@ -218,10 +222,10 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
|
|
||||||
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
||||||
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
||||||
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
|
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
@ -234,7 +238,7 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
ValidationContext: validationContext,
|
ValidationContext: validationContext,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
|
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
|
||||||
}
|
}
|
||||||
tlsConfig := marshalAny(tlsContext)
|
tlsConfig := marshalAny(tlsContext)
|
||||||
return &envoy_config_core_v3.TransportSocket{
|
return &envoy_config_core_v3.TransportSocket{
|
||||||
|
@ -247,7 +251,7 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
|
|
||||||
func (b *Builder) buildPolicyTransportSocket(
|
func (b *Builder) buildPolicyTransportSocket(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
dst url.URL,
|
dst url.URL,
|
||||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||||
|
@ -257,7 +261,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
||||||
|
|
||||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||||
|
|
||||||
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
|
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -318,7 +322,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
||||||
|
|
||||||
func (b *Builder) buildPolicyValidationContext(
|
func (b *Builder) buildPolicyValidationContext(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
dst url.URL,
|
dst url.URL,
|
||||||
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
||||||
|
@ -343,7 +347,7 @@ func (b *Builder) buildPolicyValidationContext(
|
||||||
}
|
}
|
||||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||||
} else {
|
} else {
|
||||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
bs, err := getCombinedCertificateAuthority(cfg.Options.CA, cfg.Options.CAFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -37,14 +37,14 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
||||||
|
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||||
}, *mustParseURL(t, "http://example.com"))
|
}, *mustParseURL(t, "http://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, ts)
|
assert.Nil(t, ts)
|
||||||
})
|
})
|
||||||
t.Run("host as sni", func(t *testing.T) {
|
t.Run("host as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_server_name as sni", func(t *testing.T) {
|
t.Run("tls_server_name as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSServerName: "use-this-name.example.com",
|
TLSServerName: "use-this-name.example.com",
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSUpstreamServerName: "use-this-name.example.com",
|
TLSUpstreamServerName: "use-this-name.example.com",
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_skip_verify", func(t *testing.T) {
|
t.Run("tls_skip_verify", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSSkipVerify: true,
|
TLSSkipVerify: true,
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("custom ca", func(t *testing.T) {
|
t.Run("custom ca", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("options custom ca", func(t *testing.T) {
|
t.Run("options custom ca", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o2), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("client certificate", func(t *testing.T) {
|
t.Run("client certificate", func(t *testing.T) {
|
||||||
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
ClientCertificate: clientCert,
|
ClientCertificate: clientCert,
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -438,7 +438,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -495,7 +495,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("secure", func(t *testing.T) {
|
t.Run("secure", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t,
|
To: mustParseWeightedURLs(t,
|
||||||
"https://example.com",
|
"https://example.com",
|
||||||
"https://example.com",
|
"https://example.com",
|
||||||
|
@ -663,7 +663,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("ip addresses", func(t *testing.T) {
|
t.Run("ip addresses", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("weights", func(t *testing.T) {
|
t.Run("weights", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -775,7 +775,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("localhost", func(t *testing.T) {
|
t.Run("localhost", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://localhost"),
|
To: mustParseWeightedURLs(t, "http://localhost"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -821,7 +821,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("outlier", func(t *testing.T) {
|
t.Run("outlier", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, config.New(o1), &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -904,7 +904,7 @@ func Test_bindConfig(t *testing.T) {
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
t.Run("no bind config", func(t *testing.T) {
|
t.Run("no bind config", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
|
cluster, err := b.buildPolicyCluster(ctx, config.New(nil), &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
@ -912,9 +912,9 @@ func Test_bindConfig(t *testing.T) {
|
||||||
assert.Nil(t, cluster.UpstreamBindConfig)
|
assert.Nil(t, cluster.UpstreamBindConfig)
|
||||||
})
|
})
|
||||||
t.Run("freebind", func(t *testing.T) {
|
t.Run("freebind", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
|
||||||
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
||||||
}, &config.Policy{
|
}), &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
@ -930,9 +930,9 @@ func Test_bindConfig(t *testing.T) {
|
||||||
`, cluster.UpstreamBindConfig)
|
`, cluster.UpstreamBindConfig)
|
||||||
})
|
})
|
||||||
t.Run("source address", func(t *testing.T) {
|
t.Run("source address", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
cluster, err := b.buildPolicyCluster(ctx, config.New(&config.Options{
|
||||||
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
||||||
}, &config.Policy{
|
}), &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
|
|
@ -76,10 +76,10 @@ func newDefaultEnvoyClusterConfig() *envoy_config_cluster_v3.Cluster {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAccessLogs(options *config.Options) []*envoy_config_accesslog_v3.AccessLog {
|
func buildAccessLogs(cfg *config.Config) []*envoy_config_accesslog_v3.AccessLog {
|
||||||
lvl := options.ProxyLogLevel
|
lvl := cfg.Options.ProxyLogLevel
|
||||||
if lvl == "" {
|
if lvl == "" {
|
||||||
lvl = options.LogLevel
|
lvl = cfg.Options.LogLevel
|
||||||
}
|
}
|
||||||
if lvl == "" {
|
if lvl == "" {
|
||||||
lvl = "debug"
|
lvl = "debug"
|
||||||
|
|
|
@ -10,7 +10,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *Builder) buildVirtualHost(
|
func (b *Builder) buildVirtualHost(
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
name string,
|
name string,
|
||||||
domain string,
|
domain string,
|
||||||
) (*envoy_config_route_v3.VirtualHost, error) {
|
) (*envoy_config_route_v3.VirtualHost, error) {
|
||||||
|
@ -20,15 +20,15 @@ func (b *Builder) buildVirtualHost(
|
||||||
}
|
}
|
||||||
|
|
||||||
// these routes match /.pomerium/... and similar paths
|
// these routes match /.pomerium/... and similar paths
|
||||||
rs, err := b.buildPomeriumHTTPRoutes(options, domain)
|
rs, err := b.buildPomeriumHTTPRoutes(cfg, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
vh.Routes = append(vh.Routes, rs...)
|
vh.Routes = append(vh.Routes, rs...)
|
||||||
|
|
||||||
// if we're the proxy or authenticate service, add our global headers
|
// if we're the proxy or authenticate service, add our global headers
|
||||||
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
|
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
|
||||||
vh.ResponseHeadersToAdd = toEnvoyHeaders(options.GetSetResponseHeaders())
|
vh.ResponseHeadersToAdd = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
|
||||||
}
|
}
|
||||||
|
|
||||||
return vh, nil
|
return vh, nil
|
||||||
|
@ -37,13 +37,13 @@ func (b *Builder) buildVirtualHost(
|
||||||
// buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies
|
// buildLocalReplyConfig builds the local reply config: the config used to modify "local" replies, that is replies
|
||||||
// coming directly from envoy
|
// coming directly from envoy
|
||||||
func (b *Builder) buildLocalReplyConfig(
|
func (b *Builder) buildLocalReplyConfig(
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
) *envoy_http_connection_manager.LocalReplyConfig {
|
) *envoy_http_connection_manager.LocalReplyConfig {
|
||||||
// add global headers for HSTS headers (#2110)
|
// add global headers for HSTS headers (#2110)
|
||||||
var headers []*envoy_config_core_v3.HeaderValueOption
|
var headers []*envoy_config_core_v3.HeaderValueOption
|
||||||
// if we're the proxy or authenticate service, add our global headers
|
// if we're the proxy or authenticate service, add our global headers
|
||||||
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
|
if config.IsProxy(cfg.Options.Services) || config.IsAuthenticate(cfg.Options.Services) {
|
||||||
headers = toEnvoyHeaders(options.GetSetResponseHeaders())
|
headers = toEnvoyHeaders(cfg.Options.GetSetResponseHeaders())
|
||||||
}
|
}
|
||||||
|
|
||||||
return &envoy_http_connection_manager.LocalReplyConfig{
|
return &envoy_http_connection_manager.LocalReplyConfig{
|
||||||
|
|
|
@ -53,7 +53,10 @@ func init() {
|
||||||
}
|
}
|
||||||
|
|
||||||
// BuildListeners builds envoy listeners from the given config.
|
// BuildListeners builds envoy listeners from the given config.
|
||||||
func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*envoy_config_listener_v3.Listener, error) {
|
func (b *Builder) BuildListeners(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
) ([]*envoy_config_listener_v3.Listener, error) {
|
||||||
var listeners []*envoy_config_listener_v3.Listener
|
var listeners []*envoy_config_listener_v3.Listener
|
||||||
|
|
||||||
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
|
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
|
||||||
|
@ -89,19 +92,22 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
|
||||||
return listeners, nil
|
return listeners, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
func (b *Builder) buildMainListener(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_listener_v3.Listener, error) {
|
||||||
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
|
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
|
||||||
if cfg.Options.UseProxyProtocol {
|
if cfg.Options.UseProxyProtocol {
|
||||||
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
|
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
|
||||||
}
|
}
|
||||||
|
|
||||||
if cfg.Options.InsecureServer {
|
if cfg.Options.InsecureServer {
|
||||||
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
|
allDomains, err := getAllRouteableDomains(cfg, cfg.Options.Addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "")
|
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, allDomains, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -118,9 +124,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
||||||
}
|
}
|
||||||
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
||||||
|
|
||||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.Addr,
|
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
|
||||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, tlsDomain)
|
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg, httpDomains, tlsDomain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -155,7 +161,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
||||||
return li, nil
|
return li, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
func (b *Builder) buildMetricsListener(
|
||||||
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_listener_v3.Listener, error) {
|
||||||
filter, err := b.buildMetricsHTTPConnectionManagerFilter()
|
filter, err := b.buildMetricsHTTPConnectionManagerFilter()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -235,22 +243,23 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildFilterChains(
|
func (b *Builder) buildFilterChains(
|
||||||
options *config.Options, addr string,
|
cfg *config.Config,
|
||||||
|
addr string,
|
||||||
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
||||||
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
||||||
allDomains, err := getAllRouteableDomains(options, addr)
|
allDomains, err := getAllRouteableDomains(cfg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
tlsDomains, err := getAllTLSDomains(options, addr)
|
tlsDomains, err := getAllTLSDomains(cfg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var chains []*envoy_config_listener_v3.FilterChain
|
var chains []*envoy_config_listener_v3.FilterChain
|
||||||
for _, domain := range tlsDomains {
|
for _, domain := range tlsDomains {
|
||||||
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
|
routeableDomains, err := getRouteableDomainsForTLSServerName(cfg, addr, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -273,31 +282,31 @@ func (b *Builder) buildFilterChains(
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
domains []string,
|
domains []string,
|
||||||
tlsDomain string,
|
tlsDomain string,
|
||||||
) (*envoy_config_listener_v3.Filter, error) {
|
) (*envoy_config_listener_v3.Filter, error) {
|
||||||
authorizeURLs, err := options.GetInternalAuthorizeURLs()
|
authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
dataBrokerURLs, err := options.GetInternalDataBrokerURLs()
|
dataBrokerURLs, err := cfg.Options.GetInternalDataBrokerURLs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
||||||
for _, domain := range domains {
|
for _, domain := range domains {
|
||||||
vh, err := b.buildVirtualHost(options, domain, domain)
|
vh, err := b.buildVirtualHost(cfg, domain, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if options.Addr == options.GetGRPCAddr() {
|
if cfg.Options.Addr == cfg.Options.GetGRPCAddr() {
|
||||||
// if this is a gRPC service domain and we're supposed to handle that, add those routes
|
// if this is a gRPC service domain and we're supposed to handle that, add those routes
|
||||||
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
|
if (config.IsAuthorize(cfg.Options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
|
||||||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
|
(config.IsDataBroker(cfg.Options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
|
||||||
rs, err := b.buildGRPCRoutes()
|
rs, err := b.buildGRPCRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -307,8 +316,8 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
}
|
}
|
||||||
|
|
||||||
// if we're the proxy, add all the policy routes
|
// if we're the proxy, add all the policy routes
|
||||||
if config.IsProxy(options.Services) {
|
if config.IsProxy(cfg.Options.Services) {
|
||||||
rs, err := b.buildPolicyRoutes(options, domain)
|
rs, err := b.buildPolicyRoutes(cfg, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -320,15 +329,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
vh, err := b.buildVirtualHost(options, "catch-all", "*")
|
vh, err := b.buildVirtualHost(cfg, "catch-all", "*")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
virtualHosts = append(virtualHosts, vh)
|
virtualHosts = append(virtualHosts, vh)
|
||||||
|
|
||||||
var grpcClientTimeout *durationpb.Duration
|
var grpcClientTimeout *durationpb.Duration
|
||||||
if options.GRPCClientTimeout != 0 {
|
if cfg.Options.GRPCClientTimeout != 0 {
|
||||||
grpcClientTimeout = durationpb.New(options.GRPCClientTimeout)
|
grpcClientTimeout = durationpb.New(cfg.Options.GRPCClientTimeout)
|
||||||
} else {
|
} else {
|
||||||
grpcClientTimeout = durationpb.New(30 * time.Second)
|
grpcClientTimeout = durationpb.New(30 * time.Second)
|
||||||
}
|
}
|
||||||
|
@ -346,15 +355,15 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
filters = append(filters, HTTPRouterFilter())
|
filters = append(filters, HTTPRouterFilter())
|
||||||
|
|
||||||
var maxStreamDuration *durationpb.Duration
|
var maxStreamDuration *durationpb.Duration
|
||||||
if options.WriteTimeout > 0 {
|
if cfg.Options.WriteTimeout > 0 {
|
||||||
maxStreamDuration = durationpb.New(options.WriteTimeout)
|
maxStreamDuration = durationpb.New(cfg.Options.WriteTimeout)
|
||||||
}
|
}
|
||||||
|
|
||||||
rc, err := b.buildRouteConfiguration("main", virtualHosts)
|
rc, err := b.buildRouteConfiguration("main", virtualHosts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
tracingProvider, err := buildTracingHTTP(options)
|
tracingProvider, err := buildTracingHTTP(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -362,27 +371,27 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||||
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
return HTTPConnectionManagerFilter(&envoy_http_connection_manager.HttpConnectionManager{
|
||||||
AlwaysSetRequestIdInResponse: true,
|
AlwaysSetRequestIdInResponse: true,
|
||||||
|
|
||||||
CodecType: options.GetCodecType().ToEnvoy(),
|
CodecType: cfg.Options.GetCodecType().ToEnvoy(),
|
||||||
StatPrefix: "ingress",
|
StatPrefix: "ingress",
|
||||||
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
RouteSpecifier: &envoy_http_connection_manager.HttpConnectionManager_RouteConfig{
|
||||||
RouteConfig: rc,
|
RouteConfig: rc,
|
||||||
},
|
},
|
||||||
HttpFilters: filters,
|
HttpFilters: filters,
|
||||||
AccessLog: buildAccessLogs(options),
|
AccessLog: buildAccessLogs(cfg),
|
||||||
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
|
CommonHttpProtocolOptions: &envoy_config_core_v3.HttpProtocolOptions{
|
||||||
IdleTimeout: durationpb.New(options.IdleTimeout),
|
IdleTimeout: durationpb.New(cfg.Options.IdleTimeout),
|
||||||
MaxStreamDuration: maxStreamDuration,
|
MaxStreamDuration: maxStreamDuration,
|
||||||
},
|
},
|
||||||
RequestTimeout: durationpb.New(options.ReadTimeout),
|
RequestTimeout: durationpb.New(cfg.Options.ReadTimeout),
|
||||||
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
|
Tracing: &envoy_http_connection_manager.HttpConnectionManager_Tracing{
|
||||||
RandomSampling: &envoy_type_v3.Percent{Value: options.TracingSampleRate * 100},
|
RandomSampling: &envoy_type_v3.Percent{Value: cfg.Options.TracingSampleRate * 100},
|
||||||
Provider: tracingProvider,
|
Provider: tracingProvider,
|
||||||
},
|
},
|
||||||
// See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for
|
// See https://www.envoyproxy.io/docs/envoy/latest/configuration/http/http_conn_man/headers#x-forwarded-for
|
||||||
UseRemoteAddress: &wrappers.BoolValue{Value: true},
|
UseRemoteAddress: &wrappers.BoolValue{Value: true},
|
||||||
SkipXffAppend: options.SkipXffAppend,
|
SkipXffAppend: cfg.Options.SkipXffAppend,
|
||||||
XffNumTrustedHops: options.XffNumTrustedHops,
|
XffNumTrustedHops: cfg.Options.XffNumTrustedHops,
|
||||||
LocalReplyConfig: b.buildLocalReplyConfig(options),
|
LocalReplyConfig: b.buildLocalReplyConfig(cfg),
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -420,7 +429,10 @@ func (b *Builder) buildMetricsHTTPConnectionManagerFilter() (*envoy_config_liste
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
func (b *Builder) buildGRPCListener(
|
||||||
|
ctx context.Context,
|
||||||
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_listener_v3.Listener, error) {
|
||||||
filter, err := b.buildGRPCHTTPConnectionManagerFilter()
|
filter, err := b.buildGRPCHTTPConnectionManagerFilter()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -437,7 +449,7 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
|
||||||
return li, nil
|
return li, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
chains, err := b.buildFilterChains(cfg.Options, cfg.Options.GRPCAddr,
|
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
|
||||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||||
|
@ -518,7 +530,10 @@ func (b *Builder) buildGRPCHTTPConnectionManagerFilter() (*envoy_config_listener
|
||||||
}), nil
|
}), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_config_route_v3.VirtualHost) (*envoy_config_route_v3.RouteConfiguration, error) {
|
func (b *Builder) buildRouteConfiguration(
|
||||||
|
name string,
|
||||||
|
virtualHosts []*envoy_config_route_v3.VirtualHost,
|
||||||
|
) (*envoy_config_route_v3.RouteConfiguration, error) {
|
||||||
return &envoy_config_route_v3.RouteConfiguration{
|
return &envoy_config_route_v3.RouteConfiguration{
|
||||||
Name: name,
|
Name: name,
|
||||||
VirtualHosts: virtualHosts,
|
VirtualHosts: virtualHosts,
|
||||||
|
@ -527,7 +542,8 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
func (b *Builder) buildDownstreamTLSContext(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
domain string,
|
domain string,
|
||||||
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||||
|
@ -570,7 +586,8 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
func (b *Builder) buildDownstreamValidationContext(
|
||||||
|
ctx context.Context,
|
||||||
cfg *config.Config,
|
cfg *config.Config,
|
||||||
domain string,
|
domain string,
|
||||||
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
|
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
|
||||||
|
@ -580,7 +597,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
||||||
needsClientCert = true
|
needsClientCert = true
|
||||||
}
|
}
|
||||||
if !needsClientCert {
|
if !needsClientCert {
|
||||||
for _, p := range getPoliciesForDomain(cfg.Options, domain) {
|
for _, p := range getPoliciesForDomain(cfg, domain) {
|
||||||
if p.TLSDownstreamClientCA != "" {
|
if p.TLSDownstreamClientCA != "" {
|
||||||
needsClientCert = true
|
needsClientCert = true
|
||||||
break
|
break
|
||||||
|
@ -613,19 +630,23 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
||||||
return vc
|
return vc
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) {
|
func getRouteableDomainsForTLSServerName(
|
||||||
|
cfg *config.Config,
|
||||||
|
addr string,
|
||||||
|
tlsServerName string,
|
||||||
|
) ([]string, error) {
|
||||||
allDomains := sets.NewSorted[string]()
|
allDomains := sets.NewSorted[string]()
|
||||||
|
|
||||||
if addr == options.Addr {
|
if addr == cfg.Options.Addr {
|
||||||
domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
|
domains, err := cfg.Options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allDomains.Add(domains...)
|
allDomains.Add(domains...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr == options.GetGRPCAddr() {
|
if addr == cfg.Options.GetGRPCAddr() {
|
||||||
domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
|
domains, err := cfg.Options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -635,19 +656,22 @@ func getRouteableDomainsForTLSServerName(options *config.Options, addr string, t
|
||||||
return allDomains.ToSlice(), nil
|
return allDomains.ToSlice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
func getAllRouteableDomains(
|
||||||
|
cfg *config.Config,
|
||||||
|
addr string,
|
||||||
|
) ([]string, error) {
|
||||||
allDomains := sets.NewSorted[string]()
|
allDomains := sets.NewSorted[string]()
|
||||||
|
|
||||||
if addr == options.Addr {
|
if addr == cfg.Options.Addr {
|
||||||
domains, err := options.GetAllRouteableHTTPDomains()
|
domains, err := cfg.Options.GetAllRouteableHTTPDomains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
allDomains.Add(domains...)
|
allDomains.Add(domains...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr == options.GetGRPCAddr() {
|
if addr == cfg.Options.GetGRPCAddr() {
|
||||||
domains, err := options.GetAllRouteableGRPCDomains()
|
domains, err := cfg.Options.GetAllRouteableGRPCDomains()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -657,8 +681,11 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
|
||||||
return allDomains.ToSlice(), nil
|
return allDomains.ToSlice(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
|
func getAllTLSDomains(
|
||||||
allDomains, err := getAllRouteableDomains(options, addr)
|
cfg *config.Config,
|
||||||
|
addr string,
|
||||||
|
) ([]string, error) {
|
||||||
|
allDomains, err := getAllRouteableDomains(cfg, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -711,9 +738,12 @@ func hostMatchesDomain(u *url.URL, host string) bool {
|
||||||
return h1 == h2 && p1 == p2
|
return h1 == h2 && p1 == p2
|
||||||
}
|
}
|
||||||
|
|
||||||
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy {
|
func getPoliciesForDomain(
|
||||||
|
cfg *config.Config,
|
||||||
|
domain string,
|
||||||
|
) []config.Policy {
|
||||||
var policies []config.Policy
|
var policies []config.Policy
|
||||||
for _, p := range options.GetAllPolicies() {
|
for _, p := range cfg.Options.GetAllPolicies() {
|
||||||
if p.Source != nil && p.Source.URL.Hostname() == domain {
|
if p.Source != nil && p.Source.URL.Hostname() == domain {
|
||||||
policies = append(policies, p)
|
policies = append(policies, p)
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,13 +26,11 @@ func Test_buildMetricsHTTPConnectionManagerFilter(t *testing.T) {
|
||||||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
li, err := b.buildMetricsListener(&config.Config{
|
li, err := b.buildMetricsListener(config.New(&config.Options{
|
||||||
Options: &config.Options{
|
MetricsAddr: "127.0.0.1:9902",
|
||||||
MetricsAddr: "127.0.0.1:9902",
|
MetricsCertificate: aExampleComCert,
|
||||||
MetricsCertificate: aExampleComCert,
|
MetricsCertificateKey: aExampleComKey,
|
||||||
MetricsCertificateKey: aExampleComKey,
|
}))
|
||||||
},
|
|
||||||
})
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -115,7 +113,7 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
||||||
options := config.NewDefaultOptions()
|
options := config.NewDefaultOptions()
|
||||||
options.SkipXffAppend = true
|
options.SkipXffAppend = true
|
||||||
options.XffNumTrustedHops = 1
|
options.XffNumTrustedHops = 1
|
||||||
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, "*")
|
filter, err := b.buildMainHTTPConnectionManagerFilter(config.New(options), []string{"example.com"}, "*")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"name": "envoy.filters.network.http_connection_manager",
|
"name": "envoy.filters.network.http_connection_manager",
|
||||||
|
@ -544,10 +542,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem")
|
||||||
|
|
||||||
t.Run("no-validation", func(t *testing.T) {
|
t.Run("no-validation", func(t *testing.T) {
|
||||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||||
Cert: aExampleComCert,
|
Cert: aExampleComCert,
|
||||||
Key: aExampleComKey,
|
Key: aExampleComKey,
|
||||||
}}, "a.example.com")
|
}), "a.example.com")
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"commonTlsContext": {
|
"commonTlsContext": {
|
||||||
|
@ -577,11 +575,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
}`, downstreamTLSContext)
|
}`, downstreamTLSContext)
|
||||||
})
|
})
|
||||||
t.Run("client-ca", func(t *testing.T) {
|
t.Run("client-ca", func(t *testing.T) {
|
||||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||||
Cert: aExampleComCert,
|
Cert: aExampleComCert,
|
||||||
Key: aExampleComKey,
|
Key: aExampleComKey,
|
||||||
ClientCA: "TEST",
|
ClientCA: "TEST",
|
||||||
}}, "a.example.com")
|
}), "a.example.com")
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"commonTlsContext": {
|
"commonTlsContext": {
|
||||||
|
@ -614,7 +612,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
}`, downstreamTLSContext)
|
}`, downstreamTLSContext)
|
||||||
})
|
})
|
||||||
t.Run("policy-client-ca", func(t *testing.T) {
|
t.Run("policy-client-ca", func(t *testing.T) {
|
||||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||||
Cert: aExampleComCert,
|
Cert: aExampleComCert,
|
||||||
Key: aExampleComKey,
|
Key: aExampleComKey,
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
|
@ -623,7 +621,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
TLSDownstreamClientCA: "TEST",
|
TLSDownstreamClientCA: "TEST",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}}, "a.example.com")
|
}), "a.example.com")
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"commonTlsContext": {
|
"commonTlsContext": {
|
||||||
|
@ -656,11 +654,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
}`, downstreamTLSContext)
|
}`, downstreamTLSContext)
|
||||||
})
|
})
|
||||||
t.Run("http1", func(t *testing.T) {
|
t.Run("http1", func(t *testing.T) {
|
||||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||||
Cert: aExampleComCert,
|
Cert: aExampleComCert,
|
||||||
Key: aExampleComKey,
|
Key: aExampleComKey,
|
||||||
CodecType: config.CodecTypeHTTP1,
|
CodecType: config.CodecTypeHTTP1,
|
||||||
}}, "a.example.com")
|
}), "a.example.com")
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"commonTlsContext": {
|
"commonTlsContext": {
|
||||||
|
@ -690,11 +688,11 @@ func Test_buildDownstreamTLSContext(t *testing.T) {
|
||||||
}`, downstreamTLSContext)
|
}`, downstreamTLSContext)
|
||||||
})
|
})
|
||||||
t.Run("http2", func(t *testing.T) {
|
t.Run("http2", func(t *testing.T) {
|
||||||
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{
|
downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), config.New(&config.Options{
|
||||||
Cert: aExampleComCert,
|
Cert: aExampleComCert,
|
||||||
Key: aExampleComKey,
|
Key: aExampleComKey,
|
||||||
CodecType: config.CodecTypeHTTP2,
|
CodecType: config.CodecTypeHTTP2,
|
||||||
}}, "a.example.com")
|
}), "a.example.com")
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `{
|
testutil.AssertProtoJSONEqual(t, `{
|
||||||
"commonTlsContext": {
|
"commonTlsContext": {
|
||||||
|
@ -739,9 +737,10 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
{Source: &config.StringURL{URL: mustParseURL(t, "https://c.example.com")}},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
cfg := config.New(options)
|
||||||
t.Run("routable", func(t *testing.T) {
|
t.Run("routable", func(t *testing.T) {
|
||||||
t.Run("http", func(t *testing.T) {
|
t.Run("http", func(t *testing.T) {
|
||||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
|
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9000")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
"a.example.com",
|
"a.example.com",
|
||||||
|
@ -756,7 +755,7 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
})
|
})
|
||||||
t.Run("grpc", func(t *testing.T) {
|
t.Run("grpc", func(t *testing.T) {
|
||||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
|
actual, err := getAllRouteableDomains(cfg, "127.0.0.1:9001")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
"authorize.example.com:9001",
|
"authorize.example.com:9001",
|
||||||
|
@ -765,9 +764,9 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
})
|
})
|
||||||
t.Run("both", func(t *testing.T) {
|
t.Run("both", func(t *testing.T) {
|
||||||
newOptions := *options
|
newCfg := cfg.Clone()
|
||||||
newOptions.GRPCAddr = newOptions.Addr
|
newCfg.Options.GRPCAddr = cfg.Options.Addr
|
||||||
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000")
|
actual, err := getAllRouteableDomains(newCfg, "127.0.0.1:9000")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
"a.example.com",
|
"a.example.com",
|
||||||
|
@ -786,7 +785,7 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("tls", func(t *testing.T) {
|
t.Run("tls", func(t *testing.T) {
|
||||||
t.Run("http", func(t *testing.T) {
|
t.Run("http", func(t *testing.T) {
|
||||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
|
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9000")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
"a.example.com",
|
"a.example.com",
|
||||||
|
@ -797,7 +796,7 @@ func Test_getAllDomains(t *testing.T) {
|
||||||
assert.Equal(t, expect, actual)
|
assert.Equal(t, expect, actual)
|
||||||
})
|
})
|
||||||
t.Run("grpc", func(t *testing.T) {
|
t.Run("grpc", func(t *testing.T) {
|
||||||
actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
|
actual, err := getAllTLSDomains(cfg, "127.0.0.1:9001")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
expect := []string{
|
expect := []string{
|
||||||
"authorize.example.com",
|
"authorize.example.com",
|
||||||
|
@ -831,10 +830,10 @@ func Test_buildRouteConfiguration(t *testing.T) {
|
||||||
func Test_requireProxyProtocol(t *testing.T) {
|
func Test_requireProxyProtocol(t *testing.T) {
|
||||||
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
b := New("local-grpc", "local-http", "local-metrics", nil, nil)
|
||||||
t.Run("required", func(t *testing.T) {
|
t.Run("required", func(t *testing.T) {
|
||||||
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
|
||||||
UseProxyProtocol: true,
|
UseProxyProtocol: true,
|
||||||
InsecureServer: true,
|
InsecureServer: true,
|
||||||
}})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `[
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
{
|
{
|
||||||
|
@ -846,10 +845,10 @@ func Test_requireProxyProtocol(t *testing.T) {
|
||||||
]`, li.GetListenerFilters())
|
]`, li.GetListenerFilters())
|
||||||
})
|
})
|
||||||
t.Run("not required", func(t *testing.T) {
|
t.Run("not required", func(t *testing.T) {
|
||||||
li, err := b.buildMainListener(context.Background(), &config.Config{Options: &config.Options{
|
li, err := b.buildMainListener(context.Background(), config.New(&config.Options{
|
||||||
UseProxyProtocol: false,
|
UseProxyProtocol: false,
|
||||||
InsecureServer: true,
|
InsecureServer: true,
|
||||||
}})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Len(t, li.GetListenerFilters(), 0)
|
assert.Len(t, li.GetListenerFilters(), 0)
|
||||||
})
|
})
|
||||||
|
|
|
@ -14,7 +14,9 @@ import (
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (b *Builder) buildOutboundListener(cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
func (b *Builder) buildOutboundListener(
|
||||||
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_listener_v3.Listener, error) {
|
||||||
outboundPort, err := strconv.Atoi(cfg.OutboundPort)
|
outboundPort, err := strconv.Atoi(cfg.OutboundPort)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err)
|
return nil, fmt.Errorf("invalid outbound port %v: %w", cfg.OutboundPort, err)
|
||||||
|
|
|
@ -47,12 +47,15 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
|
||||||
}}, nil
|
}}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
func (b *Builder) buildPomeriumHTTPRoutes(
|
||||||
|
cfg *config.Config,
|
||||||
|
domain string,
|
||||||
|
) ([]*envoy_config_route_v3.Route, error) {
|
||||||
var routes []*envoy_config_route_v3.Route
|
var routes []*envoy_config_route_v3.Route
|
||||||
|
|
||||||
// if this is the pomerium proxy in front of the the authenticate service, don't add
|
// if this is the pomerium proxy in front of the the authenticate service, don't add
|
||||||
// these routes since they will be handled by authenticate
|
// these routes since they will be handled by authenticate
|
||||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -69,27 +72,27 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
|
||||||
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
|
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
|
||||||
)
|
)
|
||||||
// per #837, only add robots.txt if there are no unauthenticated routes
|
// per #837, only add robots.txt if there are no unauthenticated routes
|
||||||
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
|
if !hasPublicPolicyMatchingURL(cfg, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
|
||||||
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
|
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// if we're handling authentication, add the oauth2 callback url
|
// if we're handling authentication, add the oauth2 callback url
|
||||||
authenticateURL, err := options.GetInternalAuthenticateURL()
|
authenticateURL, err := cfg.Options.GetInternalAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
if config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||||
routes = append(routes,
|
routes = append(routes,
|
||||||
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
|
b.buildControlPlanePathRoute(cfg.Options.AuthenticateCallbackPath, false),
|
||||||
b.buildControlPlanePathRoute("/", false),
|
b.buildControlPlanePathRoute("/", false),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
// if we're the proxy and this is the forward-auth url
|
// if we're the proxy and this is the forward-auth url
|
||||||
forwardAuthURL, err := options.GetForwardAuthURL()
|
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if config.IsProxy(options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
|
if config.IsProxy(cfg.Options.Services) && hostMatchesDomain(forwardAuthURL, domain) {
|
||||||
// disable ext_authz and pass request to proxy handlers that enable authN flow
|
// disable ext_authz and pass request to proxy handlers that enable authN flow
|
||||||
r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
|
r, err := b.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -227,10 +230,13 @@ func getClusterStatsName(policy *config.Policy) string {
|
||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
func (b *Builder) buildPolicyRoutes(
|
||||||
|
cfg *config.Config,
|
||||||
|
domain string,
|
||||||
|
) ([]*envoy_config_route_v3.Route, error) {
|
||||||
var routes []*envoy_config_route_v3.Route
|
var routes []*envoy_config_route_v3.Route
|
||||||
|
|
||||||
for i, p := range options.GetAllPolicies() {
|
for i, p := range cfg.Options.GetAllPolicies() {
|
||||||
policy := p
|
policy := p
|
||||||
if !hostMatchesDomain(policy.Source.URL, domain) {
|
if !hostMatchesDomain(policy.Source.URL, domain) {
|
||||||
continue
|
continue
|
||||||
|
@ -242,7 +248,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
||||||
Match: match,
|
Match: match,
|
||||||
Metadata: &envoy_config_core_v3.Metadata{},
|
Metadata: &envoy_config_core_v3.Metadata{},
|
||||||
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
|
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
|
||||||
RequestHeadersToRemove: getRequestHeadersToRemove(options, &policy),
|
RequestHeadersToRemove: getRequestHeadersToRemove(cfg, &policy),
|
||||||
ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders),
|
ResponseHeadersToAdd: toEnvoyHeaders(policy.SetResponseHeaders),
|
||||||
}
|
}
|
||||||
if policy.Redirect != nil {
|
if policy.Redirect != nil {
|
||||||
|
@ -252,7 +258,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
||||||
}
|
}
|
||||||
envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action}
|
envoyRoute.Action = &envoy_config_route_v3.Route_Redirect{Redirect: action}
|
||||||
} else {
|
} else {
|
||||||
action, err := b.buildPolicyRouteRouteAction(options, &policy)
|
action, err := b.buildPolicyRouteRouteAction(cfg, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -264,7 +270,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
||||||
}
|
}
|
||||||
|
|
||||||
// disable authentication entirely when the proxy is fronting authenticate
|
// disable authentication entirely when the proxy is fronting authenticate
|
||||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(cfg, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -275,7 +281,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
||||||
} else {
|
} else {
|
||||||
luaMetadata["remove_pomerium_cookie"] = &structpb.Value{
|
luaMetadata["remove_pomerium_cookie"] = &structpb.Value{
|
||||||
Kind: &structpb.Value_StringValue{
|
Kind: &structpb.Value_StringValue{
|
||||||
StringValue: options.CookieName,
|
StringValue: cfg.Options.CookieName,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
luaMetadata["remove_pomerium_authorization"] = &structpb.Value{
|
luaMetadata["remove_pomerium_authorization"] = &structpb.Value{
|
||||||
|
@ -350,13 +356,16 @@ func (b *Builder) buildPolicyRouteRedirectAction(r *config.PolicyRedirect) (*env
|
||||||
return action, nil
|
return action, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildPolicyRouteRouteAction(options *config.Options, policy *config.Policy) (*envoy_config_route_v3.RouteAction, error) {
|
func (b *Builder) buildPolicyRouteRouteAction(
|
||||||
|
cfg *config.Config,
|
||||||
|
policy *config.Policy,
|
||||||
|
) (*envoy_config_route_v3.RouteAction, error) {
|
||||||
clusterName := getClusterID(policy)
|
clusterName := getClusterID(policy)
|
||||||
// kubernetes requests are sent to the http control plane to be reproxied
|
// kubernetes requests are sent to the http control plane to be reproxied
|
||||||
if policy.IsForKubernetes() {
|
if policy.IsForKubernetes() {
|
||||||
clusterName = httpCluster
|
clusterName = httpCluster
|
||||||
}
|
}
|
||||||
routeTimeout := getRouteTimeout(options, policy)
|
routeTimeout := getRouteTimeout(cfg, policy)
|
||||||
idleTimeout := getRouteIdleTimeout(policy)
|
idleTimeout := getRouteIdleTimeout(policy)
|
||||||
prefixRewrite, regexRewrite := getRewriteOptions(policy)
|
prefixRewrite, regexRewrite := getRewriteOptions(policy)
|
||||||
upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{
|
upgradeConfigs := []*envoy_config_route_v3.RouteAction_UpgradeConfig{
|
||||||
|
@ -464,13 +473,16 @@ func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch {
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string {
|
func getRequestHeadersToRemove(
|
||||||
|
cfg *config.Config,
|
||||||
|
policy *config.Policy,
|
||||||
|
) []string {
|
||||||
requestHeadersToRemove := policy.RemoveRequestHeaders
|
requestHeadersToRemove := policy.RemoveRequestHeaders
|
||||||
if !policy.PassIdentityHeaders {
|
if !policy.PassIdentityHeaders {
|
||||||
requestHeadersToRemove = append(requestHeadersToRemove,
|
requestHeadersToRemove = append(requestHeadersToRemove,
|
||||||
httputil.HeaderPomeriumJWTAssertion,
|
httputil.HeaderPomeriumJWTAssertion,
|
||||||
httputil.HeaderPomeriumJWTAssertionFor)
|
httputil.HeaderPomeriumJWTAssertionFor)
|
||||||
for headerName := range options.JWTClaimsHeaders {
|
for headerName := range cfg.Options.JWTClaimsHeaders {
|
||||||
requestHeadersToRemove = append(requestHeadersToRemove, headerName)
|
requestHeadersToRemove = append(requestHeadersToRemove, headerName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -482,7 +494,10 @@ func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) [
|
||||||
return requestHeadersToRemove
|
return requestHeadersToRemove
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration {
|
func getRouteTimeout(
|
||||||
|
cfg *config.Config,
|
||||||
|
policy *config.Policy,
|
||||||
|
) *durationpb.Duration {
|
||||||
var routeTimeout *durationpb.Duration
|
var routeTimeout *durationpb.Duration
|
||||||
if policy.UpstreamTimeout != nil {
|
if policy.UpstreamTimeout != nil {
|
||||||
routeTimeout = durationpb.New(*policy.UpstreamTimeout)
|
routeTimeout = durationpb.New(*policy.UpstreamTimeout)
|
||||||
|
@ -490,7 +505,7 @@ func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb
|
||||||
// a non-zero value would conflict with idleTimeout and/or websocket / tcp calls
|
// a non-zero value would conflict with idleTimeout and/or websocket / tcp calls
|
||||||
routeTimeout = durationpb.New(0)
|
routeTimeout = durationpb.New(0)
|
||||||
} else {
|
} else {
|
||||||
routeTimeout = durationpb.New(options.DefaultUpstreamTimeout)
|
routeTimeout = durationpb.New(cfg.Options.DefaultUpstreamTimeout)
|
||||||
}
|
}
|
||||||
return routeTimeout
|
return routeTimeout
|
||||||
}
|
}
|
||||||
|
@ -564,8 +579,11 @@ func setHostRewriteOptions(policy *config.Policy, action *envoy_config_route_v3.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) bool {
|
func hasPublicPolicyMatchingURL(
|
||||||
for _, policy := range options.GetAllPolicies() {
|
cfg *config.Config,
|
||||||
|
requestURL url.URL,
|
||||||
|
) bool {
|
||||||
|
for _, policy := range cfg.Options.GetAllPolicies() {
|
||||||
if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) {
|
if policy.AllowPublicUnauthenticatedAccess && policy.Matches(requestURL) {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
@ -573,13 +591,16 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) {
|
func isProxyFrontingAuthenticate(
|
||||||
authenticateURL, err := options.GetAuthenticateURL()
|
cfg *config.Config,
|
||||||
|
domain string,
|
||||||
|
) (bool, error) {
|
||||||
|
authenticateURL, err := cfg.Options.GetAuthenticateURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
if !config.IsAuthenticate(cfg.Options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -84,7 +84,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
||||||
AuthenticateCallbackPath: "/oauth2/callback",
|
AuthenticateCallbackPath: "/oauth2/callback",
|
||||||
ForwardAuthURLString: "https://forward-auth.example.com",
|
ForwardAuthURLString: "https://forward-auth.example.com",
|
||||||
}
|
}
|
||||||
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
|
cfg := config.New(options)
|
||||||
|
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `[
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
|
@ -106,7 +107,8 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
||||||
AuthenticateURLString: "https://authenticate.example.com",
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
AuthenticateCallbackPath: "/oauth2/callback",
|
AuthenticateCallbackPath: "/oauth2/callback",
|
||||||
}
|
}
|
||||||
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
|
cfg := config.New(options)
|
||||||
|
routes, err := b.buildPomeriumHTTPRoutes(cfg, "authenticate.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, "null", routes)
|
testutil.AssertProtoJSONEqual(t, "null", routes)
|
||||||
})
|
})
|
||||||
|
@ -122,8 +124,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
cfg := config.New(options)
|
||||||
_ = options.Policies[0].Validate()
|
_ = options.Policies[0].Validate()
|
||||||
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
|
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `[
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
|
@ -150,8 +153,9 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
|
||||||
AllowPublicUnauthenticatedAccess: true,
|
AllowPublicUnauthenticatedAccess: true,
|
||||||
}},
|
}},
|
||||||
}
|
}
|
||||||
|
cfg := config.New(options)
|
||||||
_ = options.Policies[0].Validate()
|
_ = options.Policies[0].Validate()
|
||||||
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
|
routes, err := b.buildPomeriumHTTPRoutes(cfg, "from.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `[
|
testutil.AssertProtoJSONEqual(t, `[
|
||||||
|
@ -242,7 +246,7 @@ func TestTimeouts(t *testing.T) {
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
b := &Builder{filemgr: filemgr.NewManager()}
|
b := &Builder{filemgr: filemgr.NewManager()}
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
DefaultUpstreamTimeout: time.Second * 3,
|
DefaultUpstreamTimeout: time.Second * 3,
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
|
@ -253,7 +257,7 @@ func TestTimeouts(t *testing.T) {
|
||||||
IdleTimeout: getDuration(tc.idle),
|
IdleTimeout: getDuration(tc.idle),
|
||||||
AllowWebsockets: tc.allowWebsockets,
|
AllowWebsockets: tc.allowWebsockets,
|
||||||
}},
|
}},
|
||||||
}, "example.com")
|
}), "example.com")
|
||||||
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
|
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
@ -295,7 +299,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
ten := time.Second * 10
|
ten := time.Second * 10
|
||||||
|
|
||||||
b := &Builder{filemgr: filemgr.NewManager()}
|
b := &Builder{filemgr: filemgr.NewManager()}
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
DefaultUpstreamTimeout: time.Second * 3,
|
DefaultUpstreamTimeout: time.Second * 3,
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
|
@ -357,7 +361,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
UpstreamTimeout: &ten,
|
UpstreamTimeout: &ten,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "example.com")
|
}), "example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -724,7 +728,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
`, routes)
|
`, routes)
|
||||||
|
|
||||||
t.Run("fronting-authenticate", func(t *testing.T) {
|
t.Run("fronting-authenticate", func(t *testing.T) {
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
AuthenticateURLString: "https://authenticate.example.com",
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
Services: "proxy",
|
Services: "proxy",
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
|
@ -735,7 +739,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
PassIdentityHeaders: true,
|
PassIdentityHeaders: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "authenticate.example.com")
|
}), "authenticate.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -791,7 +795,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("tcp", func(t *testing.T) {
|
t.Run("tcp", func(t *testing.T) {
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
DefaultUpstreamTimeout: time.Second * 3,
|
DefaultUpstreamTimeout: time.Second * 3,
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
|
@ -805,7 +809,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
UpstreamTimeout: &ten,
|
UpstreamTimeout: &ten,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "example.com:22")
|
}), "example.com:22")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -905,7 +909,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("remove-pomerium-headers", func(t *testing.T) {
|
t.Run("remove-pomerium-headers", func(t *testing.T) {
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
AuthenticateURLString: "https://authenticate.example.com",
|
AuthenticateURLString: "https://authenticate.example.com",
|
||||||
Services: "proxy",
|
Services: "proxy",
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
|
@ -918,7 +922,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
|
||||||
Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")},
|
Source: &config.StringURL{URL: mustParseURL(t, "https://from.example.com")},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "from.example.com")
|
}), "from.example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
@ -980,7 +984,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
|
||||||
}(getClusterID)
|
}(getClusterID)
|
||||||
getClusterID = policyNameFunc()
|
getClusterID = policyNameFunc()
|
||||||
b := &Builder{filemgr: filemgr.NewManager()}
|
b := &Builder{filemgr: filemgr.NewManager()}
|
||||||
routes, err := b.buildPolicyRoutes(&config.Options{
|
routes, err := b.buildPolicyRoutes(config.New(&config.Options{
|
||||||
CookieName: "pomerium",
|
CookieName: "pomerium",
|
||||||
DefaultUpstreamTimeout: time.Second * 3,
|
DefaultUpstreamTimeout: time.Second * 3,
|
||||||
Policies: []config.Policy{
|
Policies: []config.Policy{
|
||||||
|
@ -1022,7 +1026,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
|
||||||
HostPathRegexRewriteSubstitution: "\\1",
|
HostPathRegexRewriteSubstitution: "\\1",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}, "example.com")
|
}), "example.com")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
|
|
|
@ -14,8 +14,10 @@ import (
|
||||||
"github.com/pomerium/pomerium/pkg/protoutil"
|
"github.com/pomerium/pomerium/pkg/protoutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Cluster, error) {
|
func buildTracingCluster(
|
||||||
tracingOptions, err := config.NewTracingOptions(options)
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||||
|
tracingOptions, err := config.NewTracingOptions(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err)
|
return nil, fmt.Errorf("envoyconfig: invalid tracing config: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -24,8 +26,8 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
|
||||||
case trace.DatadogTracingProviderName:
|
case trace.DatadogTracingProviderName:
|
||||||
addr, _ := parseAddress("127.0.0.1:8126")
|
addr, _ := parseAddress("127.0.0.1:8126")
|
||||||
|
|
||||||
if options.TracingDatadogAddress != "" {
|
if cfg.Options.TracingDatadogAddress != "" {
|
||||||
addr, err = parseAddress(options.TracingDatadogAddress)
|
addr, err = parseAddress(cfg.Options.TracingDatadogAddress)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err)
|
return nil, fmt.Errorf("envoyconfig: invalid tracing datadog address: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -94,8 +96,10 @@ func buildTracingCluster(options *config.Options) (*envoy_config_cluster_v3.Clus
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildTracingHTTP(options *config.Options) (*envoy_config_trace_v3.Tracing_Http, error) {
|
func buildTracingHTTP(
|
||||||
tracingOptions, err := config.NewTracingOptions(options)
|
cfg *config.Config,
|
||||||
|
) (*envoy_config_trace_v3.Tracing_Http, error) {
|
||||||
|
tracingOptions, err := config.NewTracingOptions(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("invalid tracing config: %w", err)
|
return nil, fmt.Errorf("invalid tracing config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -11,9 +11,9 @@ import (
|
||||||
|
|
||||||
func TestBuildTracingCluster(t *testing.T) {
|
func TestBuildTracingCluster(t *testing.T) {
|
||||||
t.Run("datadog", func(t *testing.T) {
|
t.Run("datadog", func(t *testing.T) {
|
||||||
c, err := buildTracingCluster(&config.Options{
|
c, err := buildTracingCluster(config.New(&config.Options{
|
||||||
TracingProvider: "datadog",
|
TracingProvider: "datadog",
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -38,10 +38,10 @@ func TestBuildTracingCluster(t *testing.T) {
|
||||||
}
|
}
|
||||||
`, c)
|
`, c)
|
||||||
|
|
||||||
c, err = buildTracingCluster(&config.Options{
|
c, err = buildTracingCluster(config.New(&config.Options{
|
||||||
TracingProvider: "datadog",
|
TracingProvider: "datadog",
|
||||||
TracingDatadogAddress: "example.com:8126",
|
TracingDatadogAddress: "example.com:8126",
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -67,10 +67,10 @@ func TestBuildTracingCluster(t *testing.T) {
|
||||||
`, c)
|
`, c)
|
||||||
})
|
})
|
||||||
t.Run("zipkin", func(t *testing.T) {
|
t.Run("zipkin", func(t *testing.T) {
|
||||||
c, err := buildTracingCluster(&config.Options{
|
c, err := buildTracingCluster(config.New(&config.Options{
|
||||||
TracingProvider: "zipkin",
|
TracingProvider: "zipkin",
|
||||||
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -99,9 +99,9 @@ func TestBuildTracingCluster(t *testing.T) {
|
||||||
|
|
||||||
func TestBuildTracingHTTP(t *testing.T) {
|
func TestBuildTracingHTTP(t *testing.T) {
|
||||||
t.Run("datadog", func(t *testing.T) {
|
t.Run("datadog", func(t *testing.T) {
|
||||||
h, err := buildTracingHTTP(&config.Options{
|
h, err := buildTracingHTTP(config.New(&config.Options{
|
||||||
TracingProvider: "datadog",
|
TracingProvider: "datadog",
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
@ -115,10 +115,10 @@ func TestBuildTracingHTTP(t *testing.T) {
|
||||||
`, h)
|
`, h)
|
||||||
})
|
})
|
||||||
t.Run("zipkin", func(t *testing.T) {
|
t.Run("zipkin", func(t *testing.T) {
|
||||||
h, err := buildTracingHTTP(&config.Options{
|
h, err := buildTracingHTTP(config.New(&config.Options{
|
||||||
TracingProvider: "zipkin",
|
TracingProvider: "zipkin",
|
||||||
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
ZipkinEndpoint: "https://example.com/api/v2/spans",
|
||||||
})
|
}))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
testutil.AssertProtoJSONEqual(t, `
|
testutil.AssertProtoJSONEqual(t, `
|
||||||
{
|
{
|
||||||
|
|
|
@ -18,28 +18,28 @@ import (
|
||||||
type TracingOptions = trace.TracingOptions
|
type TracingOptions = trace.TracingOptions
|
||||||
|
|
||||||
// NewTracingOptions builds a new TracingOptions from core Options
|
// NewTracingOptions builds a new TracingOptions from core Options
|
||||||
func NewTracingOptions(o *Options) (*TracingOptions, error) {
|
func NewTracingOptions(cfg *Config) (*TracingOptions, error) {
|
||||||
tracingOpts := TracingOptions{
|
tracingOpts := TracingOptions{
|
||||||
Provider: o.TracingProvider,
|
Provider: cfg.Options.TracingProvider,
|
||||||
Service: telemetry.ServiceName(o.Services),
|
Service: telemetry.ServiceName(cfg.Options.Services),
|
||||||
JaegerAgentEndpoint: o.TracingJaegerAgentEndpoint,
|
JaegerAgentEndpoint: cfg.Options.TracingJaegerAgentEndpoint,
|
||||||
SampleRate: o.TracingSampleRate,
|
SampleRate: cfg.Options.TracingSampleRate,
|
||||||
}
|
}
|
||||||
|
|
||||||
switch o.TracingProvider {
|
switch cfg.Options.TracingProvider {
|
||||||
case trace.DatadogTracingProviderName:
|
case trace.DatadogTracingProviderName:
|
||||||
tracingOpts.DatadogAddress = o.TracingDatadogAddress
|
tracingOpts.DatadogAddress = cfg.Options.TracingDatadogAddress
|
||||||
case trace.JaegerTracingProviderName:
|
case trace.JaegerTracingProviderName:
|
||||||
if o.TracingJaegerCollectorEndpoint != "" {
|
if cfg.Options.TracingJaegerCollectorEndpoint != "" {
|
||||||
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(o.TracingJaegerCollectorEndpoint)
|
jaegerCollectorEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.TracingJaegerCollectorEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err)
|
return nil, fmt.Errorf("config: invalid jaeger endpoint url: %w", err)
|
||||||
}
|
}
|
||||||
tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint
|
tracingOpts.JaegerCollectorEndpoint = jaegerCollectorEndpoint
|
||||||
tracingOpts.JaegerAgentEndpoint = o.TracingJaegerAgentEndpoint
|
tracingOpts.JaegerAgentEndpoint = cfg.Options.TracingJaegerAgentEndpoint
|
||||||
}
|
}
|
||||||
case trace.ZipkinTracingProviderName:
|
case trace.ZipkinTracingProviderName:
|
||||||
zipkinEndpoint, err := urlutil.ParseAndValidateURL(o.ZipkinEndpoint)
|
zipkinEndpoint, err := urlutil.ParseAndValidateURL(cfg.Options.ZipkinEndpoint)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err)
|
return nil, fmt.Errorf("config: invalid zipkin endpoint url: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -47,7 +47,7 @@ func NewTracingOptions(o *Options) (*TracingOptions, error) {
|
||||||
case "":
|
case "":
|
||||||
return &TracingOptions{}, nil
|
return &TracingOptions{}, nil
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("config: provider %s unknown", o.TracingProvider)
|
return nil, fmt.Errorf("config: provider %s unknown", cfg.Options.TracingProvider)
|
||||||
}
|
}
|
||||||
|
|
||||||
return &tracingOpts, nil
|
return &tracingOpts, nil
|
||||||
|
@ -88,7 +88,7 @@ func (mgr *TraceManager) OnConfigChange(ctx context.Context, cfg *Config) {
|
||||||
mgr.mu.Lock()
|
mgr.mu.Lock()
|
||||||
defer mgr.mu.Unlock()
|
defer mgr.mu.Unlock()
|
||||||
|
|
||||||
traceOpts, err := NewTracingOptions(cfg.Options)
|
traceOpts, err := NewTracingOptions(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
|
log.Error(ctx).Err(err).Msg("trace: failed to build tracing options")
|
||||||
return
|
return
|
||||||
|
|
|
@ -68,7 +68,7 @@ func Test_NewTracingOptions(t *testing.T) {
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := NewTracingOptions(tt.opts)
|
got, err := NewTracingOptions(&Config{Options: tt.opts})
|
||||||
assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value")
|
assert.NotEqual(t, err == nil, tt.wantErr, "unexpected error value")
|
||||||
assert.Empty(t, cmp.Diff(tt.want, got))
|
assert.Empty(t, cmp.Diff(tt.want, got))
|
||||||
})
|
})
|
||||||
|
|
|
@ -28,7 +28,7 @@ func TestNew(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tt.opts.Provider = "google"
|
tt.opts.Provider = "google"
|
||||||
_, err := New(&config.Config{Options: &tt.opts}, events.New())
|
_, err := New(config.New(&tt.opts), events.New())
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
|
|
@ -236,8 +236,8 @@ func TestConfig(t *testing.T) {
|
||||||
}
|
}
|
||||||
_ = p1.Validate()
|
_ = p1.Validate()
|
||||||
|
|
||||||
mgr, err := newManager(ctx, config.NewStaticSource(&config.Config{
|
mgr, err := newManager(ctx,
|
||||||
Options: &config.Options{
|
config.NewStaticSource(config.New(&config.Options{
|
||||||
AutocertOptions: config.AutocertOptions{
|
AutocertOptions: config.AutocertOptions{
|
||||||
Enable: true,
|
Enable: true,
|
||||||
UseStaging: true,
|
UseStaging: true,
|
||||||
|
@ -247,11 +247,11 @@ func TestConfig(t *testing.T) {
|
||||||
},
|
},
|
||||||
HTTPRedirectAddr: addr,
|
HTTPRedirectAddr: addr,
|
||||||
Policies: []config.Policy{p1},
|
Policies: []config.Policy{p1},
|
||||||
},
|
})),
|
||||||
}), certmagic.ACMEIssuer{
|
certmagic.ACMEIssuer{
|
||||||
CA: srv.URL + "/acme/directory",
|
CA: srv.URL + "/acme/directory",
|
||||||
TestCA: srv.URL + "/acme/directory",
|
TestCA: srv.URL + "/acme/directory",
|
||||||
}, time.Millisecond*100)
|
}, time.Millisecond*100)
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -304,16 +304,14 @@ func TestRedirect(t *testing.T) {
|
||||||
addr := li.Addr().String()
|
addr := li.Addr().String()
|
||||||
_ = li.Close()
|
_ = li.Close()
|
||||||
|
|
||||||
src := config.NewStaticSource(&config.Config{
|
src := config.NewStaticSource(config.New(&config.Options{
|
||||||
Options: &config.Options{
|
HTTPRedirectAddr: addr,
|
||||||
HTTPRedirectAddr: addr,
|
SetResponseHeaders: map[string]string{
|
||||||
SetResponseHeaders: map[string]string{
|
"X-Frame-Options": "SAMEORIGIN",
|
||||||
"X-Frame-Options": "SAMEORIGIN",
|
"X-XSS-Protection": "1; mode=block",
|
||||||
"X-XSS-Protection": "1; mode=block",
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
})
|
}))
|
||||||
_, err = New(src)
|
_, err = New(src)
|
||||||
if !assert.NoError(t, err) {
|
if !assert.NoError(t, err) {
|
||||||
return
|
return
|
||||||
|
|
|
@ -28,7 +28,7 @@ import (
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
key []byte
|
key []byte
|
||||||
options *config.Options
|
cfg *config.Config
|
||||||
policies map[uint64]config.Policy
|
policies map[uint64]config.Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -83,7 +83,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
|
||||||
}
|
}
|
||||||
|
|
||||||
h.mu.RLock()
|
h.mu.RLock()
|
||||||
options := h.options
|
options := h.cfg.Options
|
||||||
policy, ok := h.policies[policyID]
|
policy, ok := h.policies[policyID]
|
||||||
h.mu.RUnlock()
|
h.mu.RUnlock()
|
||||||
|
|
||||||
|
@ -132,7 +132,7 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
|
||||||
defer h.mu.Unlock()
|
defer h.mu.Unlock()
|
||||||
|
|
||||||
h.key, _ = cfg.Options.GetSharedKey()
|
h.key, _ = cfg.Options.GetSharedKey()
|
||||||
h.options = cfg.Options
|
h.cfg = cfg
|
||||||
h.policies = make(map[uint64]config.Policy)
|
h.policies = make(map[uint64]config.Policy)
|
||||||
for _, p := range cfg.Options.GetAllPolicies() {
|
for _, p := range cfg.Options.GetAllPolicies() {
|
||||||
id, err := p.RouteID()
|
id, err := p.RouteID()
|
||||||
|
|
|
@ -49,15 +49,13 @@ func TestMiddleware(t *testing.T) {
|
||||||
srv2 := httptest.NewServer(h.Middleware(next))
|
srv2 := httptest.NewServer(h.Middleware(next))
|
||||||
defer srv2.Close()
|
defer srv2.Close()
|
||||||
|
|
||||||
cfg := &config.Config{
|
cfg := config.New(&config.Options{
|
||||||
Options: &config.Options{
|
SharedKey: cryptutil.NewBase64Key(),
|
||||||
SharedKey: cryptutil.NewBase64Key(),
|
Policies: []config.Policy{{
|
||||||
Policies: []config.Policy{{
|
To: config.WeightedURLs{{URL: *u}},
|
||||||
To: config.WeightedURLs{{URL: *u}},
|
KubernetesServiceAccountToken: "ABCD",
|
||||||
KubernetesServiceAccountToken: "ABCD",
|
}},
|
||||||
}},
|
})
|
||||||
},
|
|
||||||
}
|
|
||||||
h.Update(context.Background(), cfg)
|
h.Update(context.Background(), cfg)
|
||||||
|
|
||||||
policyID, _ := cfg.Options.Policies[0].RouteID()
|
policyID, _ := cfg.Options.Policies[0].RouteID()
|
||||||
|
|
|
@ -80,11 +80,11 @@ func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.options})
|
p, err := New(config.New(tt.options))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.OnConfigChange(ctx, &config.Config{Options: tt.options})
|
p.OnConfigChange(ctx, config.New(tt.options))
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
state.sessionStore = tt.sessionStore
|
state.sessionStore = tt.sessionStore
|
||||||
signer, err := jws.NewHS256Signer(nil)
|
signer, err := jws.NewHS256Signer(nil)
|
||||||
|
|
|
@ -58,7 +58,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) error {
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
|
|
||||||
var redirectURL *url.URL
|
var redirectURL *url.URL
|
||||||
signOutURL, err := p.currentOptions.Load().GetSignOutRedirectURL()
|
signOutURL, err := p.currentConfig.Load().Options.GetSignOutRedirectURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return httputil.NewError(http.StatusInternalServerError, err)
|
return httputil.NewError(http.StatusInternalServerError, err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -47,7 +47,7 @@ func TestProxy_Signout(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(config.New(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -70,7 +70,7 @@ func TestProxy_userInfo(t *testing.T) {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
proxy, err := New(&config.Config{Options: opts})
|
proxy, err := New(config.New(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -102,7 +102,7 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
p, err := New(&config.Config{Options: opts})
|
p, err := New(config.New(opts))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -236,11 +236,11 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.options})
|
p, err := New(config.New(tt.options))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
p.OnConfigChange(context.Background(), config.New(tt.options))
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
state.encoder = tt.cipher
|
state.encoder = tt.cipher
|
||||||
state.sessionStore = tt.sessionStore
|
state.sessionStore = tt.sessionStore
|
||||||
|
@ -350,7 +350,7 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.options})
|
p, err := New(config.New(tt.options))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -482,11 +482,11 @@ func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.options})
|
p, err := New(config.New(tt.options))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.options})
|
p.OnConfigChange(context.Background(), config.New(tt.options))
|
||||||
state := p.state.Load()
|
state := p.state.Load()
|
||||||
state.encoder = tt.cipher
|
state.encoder = tt.cipher
|
||||||
state.sessionStore = tt.sessionStore
|
state.sessionStore = tt.sessionStore
|
||||||
|
|
|
@ -51,9 +51,9 @@ func ValidateOptions(o *config.Options) error {
|
||||||
|
|
||||||
// Proxy stores all the information associated with proxying a request.
|
// Proxy stores all the information associated with proxying a request.
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
state *atomicutil.Value[*proxyState]
|
state *atomicutil.Value[*proxyState]
|
||||||
currentOptions *atomicutil.Value[*config.Options]
|
currentConfig *atomicutil.Value[*config.Config]
|
||||||
currentRouter *atomicutil.Value[*mux.Router]
|
currentRouter *atomicutil.Value[*mux.Router]
|
||||||
}
|
}
|
||||||
|
|
||||||
// New takes a Proxy service from options and a validation function.
|
// New takes a Proxy service from options and a validation function.
|
||||||
|
@ -65,13 +65,13 @@ func New(cfg *config.Config) (*Proxy, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
p := &Proxy{
|
p := &Proxy{
|
||||||
state: atomicutil.NewValue(state),
|
state: atomicutil.NewValue(state),
|
||||||
currentOptions: config.NewAtomicOptions(),
|
currentConfig: atomicutil.NewValue(cfg),
|
||||||
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
currentRouter: atomicutil.NewValue(httputil.NewRouter()),
|
||||||
}
|
}
|
||||||
|
|
||||||
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
metrics.AddPolicyCountCallback("pomerium-proxy", func() int64 {
|
||||||
return int64(len(p.currentOptions.Load().GetAllPolicies()))
|
return int64(len(p.currentConfig.Load().Options.GetAllPolicies()))
|
||||||
})
|
})
|
||||||
|
|
||||||
return p, nil
|
return p, nil
|
||||||
|
@ -88,8 +88,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
p.currentOptions.Store(cfg.Options)
|
p.currentConfig.Store(cfg)
|
||||||
if err := p.setHandlers(cfg.Options); err != nil {
|
if err := p.setHandlers(cfg); err != nil {
|
||||||
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
log.Error(context.TODO()).Err(err).Msg("proxy: failed to update proxy handlers from configuration settings")
|
||||||
}
|
}
|
||||||
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
if state, err := newProxyStateFromConfig(cfg); err != nil {
|
||||||
|
@ -99,8 +99,8 @@ func (p *Proxy) OnConfigChange(ctx context.Context, cfg *config.Config) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) setHandlers(opts *config.Options) error {
|
func (p *Proxy) setHandlers(cfg *config.Config) error {
|
||||||
if len(opts.GetAllPolicies()) == 0 {
|
if len(cfg.Options.GetAllPolicies()) == 0 {
|
||||||
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
|
log.Warn(context.TODO()).Msg("proxy: configuration has no policies")
|
||||||
}
|
}
|
||||||
r := httputil.NewRouter()
|
r := httputil.NewRouter()
|
||||||
|
@ -113,7 +113,7 @@ func (p *Proxy) setHandlers(opts *config.Options) error {
|
||||||
// dashboard handlers are registered to all routes
|
// dashboard handlers are registered to all routes
|
||||||
r = p.registerDashboardHandlers(r)
|
r = p.registerDashboardHandlers(r)
|
||||||
|
|
||||||
forwardAuthURL, err := opts.GetForwardAuthURL()
|
forwardAuthURL, err := cfg.Options.GetForwardAuthURL()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -99,7 +99,7 @@ func TestNew(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
got, err := New(&config.Config{Options: tt.opts})
|
got, err := New(config.New(tt.opts))
|
||||||
if (err != nil) != tt.wantErr {
|
if (err != nil) != tt.wantErr {
|
||||||
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
t.Errorf("New() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
return
|
return
|
||||||
|
@ -194,12 +194,12 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
p, err := New(&config.Config{Options: tt.originalOptions})
|
p, err := New(config.New(tt.originalOptions))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
p.OnConfigChange(context.Background(), &config.Config{Options: tt.updatedOptions})
|
p.OnConfigChange(context.Background(), config.New(tt.updatedOptions))
|
||||||
r := httptest.NewRequest("GET", tt.host, nil)
|
r := httptest.NewRequest("GET", tt.host, nil)
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
p.ServeHTTP(w, r)
|
p.ServeHTTP(w, r)
|
||||||
|
@ -212,5 +212,5 @@ func Test_UpdateOptions(t *testing.T) {
|
||||||
|
|
||||||
// Test nil
|
// Test nil
|
||||||
var p *Proxy
|
var p *Proxy
|
||||||
p.OnConfigChange(context.Background(), &config.Config{})
|
p.OnConfigChange(context.Background(), config.New(nil))
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue