config: return errors on invalid URLs, fix linting (#1829)

This commit is contained in:
Caleb Doxsey 2021-01-27 07:58:30 -07:00 committed by GitHub
parent a8a703218f
commit bec98051ae
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 256 additions and 149 deletions

View file

@ -71,7 +71,6 @@ linters:
- stylecheck - stylecheck
- typecheck - typecheck
- unconvert - unconvert
- unparam
- unused - unused
- varcheck - varcheck
# - asciicheck # - asciicheck

View file

@ -49,7 +49,7 @@ func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.C
func (a *Authorize) deniedResponse( func (a *Authorize) deniedResponse(
in *envoy_service_auth_v2.CheckRequest, in *envoy_service_auth_v2.CheckRequest,
code int32, reason string, headers map[string]string, code int32, reason string, headers map[string]string,
) *envoy_service_auth_v2.CheckResponse { ) (*envoy_service_auth_v2.CheckResponse, error) {
returnHTMLError := true returnHTMLError := true
inHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders() inHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders()
if inHeaders != nil { if inHeaders != nil {
@ -59,15 +59,19 @@ func (a *Authorize) deniedResponse(
if returnHTMLError { if returnHTMLError {
return a.htmlDeniedResponse(in, code, reason, headers) return a.htmlDeniedResponse(in, code, reason, headers)
} }
return a.plainTextDeniedResponse(code, reason, headers) return a.plainTextDeniedResponse(code, reason, headers), nil
} }
func (a *Authorize) htmlDeniedResponse( func (a *Authorize) htmlDeniedResponse(
in *envoy_service_auth_v2.CheckRequest, in *envoy_service_auth_v2.CheckRequest,
code int32, reason string, headers map[string]string, code int32, reason string, headers map[string]string,
) *envoy_service_auth_v2.CheckResponse { ) (*envoy_service_auth_v2.CheckResponse, error) {
opts := a.currentOptions.Load() opts := a.currentOptions.Load()
debugEndpoint := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/"}) authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err
}
debugEndpoint := authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/"})
// create go-style http request // create go-style http request
r := getHTTPRequestFromCheckRequest(in) r := getHTTPRequestFromCheckRequest(in)
@ -97,7 +101,7 @@ func (a *Authorize) htmlDeniedResponse(
} }
var buf bytes.Buffer var buf bytes.Buffer
err := a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{ err = a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{
"Status": code, "Status": code,
"StatusText": reason, "StatusText": reason,
"CanDebug": code/100 == 4, "CanDebug": code/100 == 4,
@ -127,7 +131,7 @@ func (a *Authorize) htmlDeniedResponse(
Body: buf.String(), Body: buf.String(),
}, },
}, },
} }, nil
} }
func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers map[string]string) *envoy_service_auth_v2.CheckResponse { func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers map[string]string) *envoy_service_auth_v2.CheckResponse {
@ -152,10 +156,16 @@ func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers m
} }
} }
func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) *envoy_service_auth_v2.CheckResponse { func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) (*envoy_service_auth_v2.CheckResponse, error) {
opts := a.currentOptions.Load() opts := a.currentOptions.Load()
authenticateURL, err := opts.GetAuthenticateURL()
if err != nil {
return nil, err
}
signinURL := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/sign_in"}) signinURL := authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/sign_in",
})
q := signinURL.Query() q := signinURL.Query()
// always assume https scheme // always assume https scheme

View file

@ -280,7 +280,8 @@ func TestAuthorize_deniedResponse(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
got := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers) got, err := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers)
require.NoError(t, err)
assert.Equal(t, tc.want.Status.Code, got.Status.Code) assert.Equal(t, tc.want.Status.Code, got.Status.Code)
assert.Equal(t, tc.want.Status.Message, got.Status.Message) assert.Equal(t, tc.want.Status.Message, got.Status.Message)
assert.Equal(t, tc.want.GetDeniedResponse().GetHeaders(), got.GetDeniedResponse().GetHeaders()) assert.Equal(t, tc.want.GetDeniedResponse().GetHeaders(), got.GetDeniedResponse().GetHeaders())

View file

@ -78,11 +78,11 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe
return a.okResponse(reply), nil return a.okResponse(reply), nil
case reply.Status == http.StatusUnauthorized: case reply.Status == http.StatusUnauthorized:
if isForwardAuth && hreq.URL.Path == "/verify" { if isForwardAuth && hreq.URL.Path == "/verify" {
return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil)
} }
return a.redirectResponse(in), nil return a.redirectResponse(in)
} }
return a.deniedResponse(in, int32(reply.Status), reply.Message, nil), nil return a.deniedResponse(in, int32(reply.Status), reply.Message, nil)
} }
func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) error { func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) error {
@ -212,9 +212,14 @@ func (a *Authorize) isForwardAuth(req *envoy_service_auth_v2.CheckRequest) bool
return false return false
} }
forwardAuthURL, err := opts.GetForwardAuthURL()
if err != nil {
return false
}
checkURL := getCheckRequestURL(req) checkURL := getCheckRequestURL(req)
return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(opts.GetForwardAuthURL().Host) return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(forwardAuthURL.Host)
} }
func (a *Authorize) getEvaluatorRequestFromCheckRequest(in *envoy_service_auth_v2.CheckRequest, sessionState *sessions.State) *evaluator.Request { func (a *Authorize) getEvaluatorRequestFromCheckRequest(in *envoy_service_auth_v2.CheckRequest, sessionState *sessions.State) *evaluator.Request {

View file

@ -710,45 +710,46 @@ func (o *Options) Validate() error {
} }
// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1. // GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
func (o *Options) GetAuthenticateURL() *url.URL { func (o *Options) GetAuthenticateURL() (*url.URL, error) {
if o != nil && o.AuthenticateURL != nil { if o != nil && o.AuthenticateURL != nil {
return o.AuthenticateURL return o.AuthenticateURL, nil
} }
u, _ := url.Parse("https://127.0.0.1") return url.Parse("https://127.0.0.1")
return u
} }
// GetAuthorizeURL returns the AuthorizeURL in the options or 127.0.0.1:5443. // GetAuthorizeURL returns the AuthorizeURL in the options or 127.0.0.1:5443.
func (o *Options) GetAuthorizeURL() *url.URL { func (o *Options) GetAuthorizeURL() (*url.URL, error) {
if o != nil && o.AuthorizeURL != nil { if o != nil && o.AuthorizeURL != nil {
return o.AuthorizeURL return o.AuthorizeURL, nil
} }
u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
return u
} }
// GetDataBrokerURL returns the DataBrokerURL in the options or 127.0.0.1:5443. // GetDataBrokerURL returns the DataBrokerURL in the options or 127.0.0.1:5443.
func (o *Options) GetDataBrokerURL() *url.URL { func (o *Options) GetDataBrokerURL() (*url.URL, error) {
if o != nil && o.DataBrokerURL != nil { if o != nil && o.DataBrokerURL != nil {
return o.DataBrokerURL return o.DataBrokerURL, nil
} }
u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr)
return u
} }
// GetForwardAuthURL returns the ForwardAuthURL in the options or 127.0.0.1. // GetForwardAuthURL returns the ForwardAuthURL in the options or 127.0.0.1.
func (o *Options) GetForwardAuthURL() *url.URL { func (o *Options) GetForwardAuthURL() (*url.URL, error) {
if o != nil && o.ForwardAuthURL != nil { if o != nil && o.ForwardAuthURL != nil {
return o.ForwardAuthURL return o.ForwardAuthURL, nil
} }
u, _ := url.Parse("https://127.0.0.1") return url.Parse("https://127.0.0.1")
return u
} }
// GetOauthOptions gets the oauth.Options for the given config options. // GetOauthOptions gets the oauth.Options for the given config options.
func (o *Options) GetOauthOptions() oauth.Options { func (o *Options) GetOauthOptions() (oauth.Options, error) {
redirectURL := o.GetAuthenticateURL() redirectURL, err := o.GetAuthenticateURL()
redirectURL.Path = o.AuthenticateCallbackPath if err != nil {
return oauth.Options{}, err
}
redirectURL = redirectURL.ResolveReference(&url.URL{
Path: o.AuthenticateCallbackPath,
})
return oauth.Options{ return oauth.Options{
RedirectURL: redirectURL, RedirectURL: redirectURL,
ProviderName: o.Provider, ProviderName: o.Provider,
@ -757,7 +758,7 @@ func (o *Options) GetOauthOptions() oauth.Options {
ClientSecret: o.ClientSecret, ClientSecret: o.ClientSecret,
Scopes: o.Scopes, Scopes: o.Scopes,
ServiceAccount: o.ServiceAccount, ServiceAccount: o.ServiceAccount,
} }, nil
} }
// GetAllPolicies gets all the policies in the options. // GetAllPolicies gets all the policies in the options.

View file

@ -14,6 +14,7 @@ import (
"github.com/google/go-cmp/cmp/cmpopts" "github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/viper" "github.com/spf13/viper"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}) var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{})
@ -498,7 +499,7 @@ func TestOptions_DefaultURL(t *testing.T) {
} }
tests := []struct { tests := []struct {
name string name string
f func() *url.URL f func() (*url.URL, error)
expectedURLStr string expectedURLStr string
}{ }{
{"default authenticate url", defaultOptions.GetAuthenticateURL, "https://127.0.0.1"}, {"default authenticate url", defaultOptions.GetAuthenticateURL, "https://127.0.0.1"},
@ -515,7 +516,9 @@ func TestOptions_DefaultURL(t *testing.T) {
tc := tc tc := tc
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
t.Parallel() t.Parallel()
assert.Equal(t, tc.expectedURLStr, tc.f().String()) u, err := tc.f()
require.NoError(t, err)
assert.Equal(t, tc.expectedURLStr, u.String())
}) })
} }
} }
@ -530,7 +533,9 @@ func mustParseURL(str string) *url.URL {
func TestOptions_GetOauthOptions(t *testing.T) { func TestOptions_GetOauthOptions(t *testing.T) {
opts := &Options{AuthenticateURL: mustParseURL("https://authenticate.example.com")} opts := &Options{AuthenticateURL: mustParseURL("https://authenticate.example.com")}
oauthOptions, err := opts.GetOauthOptions()
require.NoError(t, err)
// Test that oauth redirect url hostname must point to authenticate url hostname. // Test that oauth redirect url hostname must point to authenticate url hostname.
assert.Equal(t, opts.AuthenticateURL.Hostname(), opts.GetOauthOptions().RedirectURL.Hostname()) assert.Equal(t, opts.AuthenticateURL.Hostname(), oauthOptions.RedirectURL.Hostname())
} }

View file

@ -81,13 +81,17 @@ func New(cfg *config.Config) (*DataBroker, error) {
} }
dataBrokerServer := newDataBrokerServer(cfg) dataBrokerServer := newDataBrokerServer(cfg)
dataBrokerURL, err := cfg.Options.GetDataBrokerURL()
if err != nil {
return nil, err
}
c := &DataBroker{ c := &DataBroker{
dataBrokerServer: dataBrokerServer, dataBrokerServer: dataBrokerServer,
localListener: localListener, localListener: localListener,
localGRPCServer: localGRPCServer, localGRPCServer: localGRPCServer,
localGRPCConnection: localGRPCConnection, localGRPCConnection: localGRPCConnection,
deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(), deprecatedCacheClusterDomain: dataBrokerURL.Hostname(),
dataBrokerStorageType: cfg.Options.DataBrokerStorageType, dataBrokerStorageType: cfg.Options.DataBrokerStorageType,
} }
c.Register(c.localGRPCServer) c.Register(c.localGRPCServer)
@ -138,7 +142,12 @@ func (c *DataBroker) update(cfg *config.Config) error {
return fmt.Errorf("databroker: bad option: %w", err) return fmt.Errorf("databroker: bad option: %w", err)
} }
authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions()) oauthOptions, err := cfg.Options.GetOauthOptions()
if err != nil {
return fmt.Errorf("databroker: invalid oauth options: %w", err)
}
authenticator, err := identity.NewAuthenticator(oauthOptions)
if err != nil { if err != nil {
return fmt.Errorf("databroker: failed to create authenticator: %w", err) return fmt.Errorf("databroker: failed to create authenticator: %w", err)
} }

View file

@ -158,9 +158,15 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err
if err != nil { if err != nil {
return fmt.Errorf("error creating authenticate service: %w", err) return fmt.Errorf("error creating authenticate service: %w", err)
} }
authenticateURL, err := src.GetConfig().Options.GetAuthenticateURL()
if err != nil {
return fmt.Errorf("error getting authenticate URL: %w", err)
}
src.OnConfigChange(svc.OnConfigChange) src.OnConfigChange(svc.OnConfigChange)
svc.OnConfigChange(src.GetConfig()) svc.OnConfigChange(src.GetConfig())
host := urlutil.StripPort(src.GetConfig().Options.GetAuthenticateURL().Host) host := urlutil.StripPort(authenticateURL.Host)
sr := controlPlane.HTTPRouter.Host(host).Subrouter() sr := controlPlane.HTTPRouter.Host(host).Subrouter()
svc.Mount(sr) svc.Mount(sr)
log.Info().Str("host", host).Msg("enabled authenticate service") log.Info().Str("host", host).Msg("enabled authenticate service")

View file

@ -50,9 +50,9 @@ func (srv *Server) buildClusters(options *config.Options) ([]*envoy_config_clust
Scheme: "http", Scheme: "http",
Host: srv.HTTPListener.Addr().String(), Host: srv.HTTPListener.Addr().String(),
} }
authzURL := &url.URL{ authzURL, err := options.GetAuthorizeURL()
Scheme: options.GetAuthorizeURL().Scheme, if err != nil {
Host: options.GetAuthorizeURL().Host, return nil, err
} }
controlGRPC, err := srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true) controlGRPC, err := srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true)
@ -132,22 +132,22 @@ func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Po
func (srv *Server) buildInternalEndpoints(options *config.Options, dst *url.URL) ([]Endpoint, error) { func (srv *Server) buildInternalEndpoints(options *config.Options, dst *url.URL) ([]Endpoint, error) {
var endpoints []Endpoint var endpoints []Endpoint
if ts, err := srv.buildInternalTransportSocket(options, dst); err != nil { ts, err := srv.buildInternalTransportSocket(options, dst)
if err != nil {
return nil, err return nil, err
} else {
endpoints = append(endpoints, NewEndpoint(dst, ts))
} }
endpoints = append(endpoints, NewEndpoint(dst, ts))
return endpoints, nil return endpoints, nil
} }
func (srv *Server) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) { func (srv *Server) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) {
var endpoints []Endpoint var endpoints []Endpoint
for _, dst := range policy.Destinations { for _, dst := range policy.Destinations {
if ts, err := srv.buildPolicyTransportSocket(policy, dst); err != nil { ts, err := srv.buildPolicyTransportSocket(policy, dst)
if err != nil {
return nil, err return nil, err
} else {
endpoints = append(endpoints, NewEndpoint(dst, ts))
} }
endpoints = append(endpoints, NewEndpoint(dst, ts))
} }
return endpoints, nil return endpoints, nil
} }
@ -246,7 +246,9 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.UR
}, nil }, nil
} }
func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url.URL) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { func (srv *Server) buildPolicyValidationContext(
policy *config.Policy, dst *url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
if dst == nil { if dst == nil {
return nil, nil return nil, nil
} }

View file

@ -43,19 +43,19 @@ func (srv *Server) buildListeners(cfg *config.Config) ([]*envoy_config_listener_
var listeners []*envoy_config_listener_v3.Listener var listeners []*envoy_config_listener_v3.Listener
if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) { if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) {
if li, err := srv.buildMainListener(cfg); err != nil { li, err := srv.buildMainListener(cfg)
if err != nil {
return nil, err return nil, err
} else {
listeners = append(listeners, li)
} }
listeners = append(listeners, li)
} }
if config.IsAuthorize(cfg.Options.Services) || config.IsDataBroker(cfg.Options.Services) { if config.IsAuthorize(cfg.Options.Services) || config.IsDataBroker(cfg.Options.Services) {
if li, err := srv.buildGRPCListener(cfg); err != nil { li, err := srv.buildGRPCListener(cfg)
if err != nil {
return nil, err return nil, err
} else {
listeners = append(listeners, li)
} }
listeners = append(listeners, li)
} }
return listeners, nil return listeners, nil
@ -74,8 +74,12 @@ func (srv *Server) buildMainListener(cfg *config.Config) (*envoy_config_listener
} }
if cfg.Options.InsecureServer { if cfg.Options.InsecureServer {
filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
getAllRouteableDomains(cfg.Options, cfg.Options.Addr), "") if err != nil {
return nil, err
}
filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "")
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -143,24 +147,37 @@ func (srv *Server) buildFilterChains(
options *config.Options, addr string, options *config.Options, addr string,
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
) ([]*envoy_config_listener_v3.FilterChain, error) { ) ([]*envoy_config_listener_v3.FilterChain, error) {
allDomains := getAllRouteableDomains(options, addr) allDomains, err := getAllRouteableDomains(options, addr)
tlsDomains := getAllTLSDomains(options, addr) if err != nil {
return nil, err
}
tlsDomains, err := getAllTLSDomains(options, addr)
if err != nil {
return nil, err
}
var chains []*envoy_config_listener_v3.FilterChain var chains []*envoy_config_listener_v3.FilterChain
for _, domain := range tlsDomains { for _, domain := range tlsDomains {
// first we match on SNI routeableDomains, err := getRouteableDomainsForTLSDomain(options, addr, domain)
if chain, err := callback(domain, getRouteableDomainsForTLSDomain(options, addr, domain)); err != nil { if err != nil {
return nil, err return nil, err
} else {
chains = append(chains, chain)
} }
// first we match on SNI
chain, err := callback(domain, routeableDomains)
if err != nil {
return nil, err
}
chains = append(chains, chain)
} }
// if there are no SNI matches we match on HTTP host // if there are no SNI matches we match on HTTP host
if chain, err := callback("*", allDomains); err != nil { chain, err := callback("*", allDomains)
if err != nil {
return nil, err return nil, err
} else {
chains = append(chains, chain)
} }
chains = append(chains, chain)
return chains, nil return chains, nil
} }
@ -169,6 +186,16 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(
domains []string, domains []string,
tlsDomain string, tlsDomain string,
) (*envoy_config_listener_v3.Filter, error) { ) (*envoy_config_listener_v3.Filter, error) {
authorizeURL, err := options.GetAuthorizeURL()
if err != nil {
return nil, err
}
dataBrokerURL, err := options.GetDataBrokerURL()
if err != nil {
return nil, err
}
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, domain := range domains { for _, domain := range domains {
vh := &envoy_config_route_v3.VirtualHost{ vh := &envoy_config_route_v3.VirtualHost{
@ -178,30 +205,30 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(
if options.Addr == options.GRPCAddr { if options.Addr == options.GRPCAddr {
// if this is a gRPC service domain and we're supposed to handle that, add those routes // if this is a gRPC service domain and we're supposed to handle that, add those routes
if (config.IsAuthorize(options.Services) && hostMatchesDomain(options.GetAuthorizeURL(), domain)) || if (config.IsAuthorize(options.Services) && hostMatchesDomain(authorizeURL, domain)) ||
(config.IsDataBroker(options.Services) && hostMatchesDomain(options.GetDataBrokerURL(), domain)) { (config.IsDataBroker(options.Services) && hostMatchesDomain(dataBrokerURL, domain)) {
if rs, err := srv.buildGRPCRoutes(); err != nil { rs, err := srv.buildGRPCRoutes()
if err != nil {
return nil, err return nil, err
} else {
vh.Routes = append(vh.Routes, rs...)
} }
vh.Routes = append(vh.Routes, rs...)
} }
} }
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
if rs, err := srv.buildPomeriumHTTPRoutes(options, domain); err != nil { rs, err := srv.buildPomeriumHTTPRoutes(options, domain)
if err != nil {
return nil, err return nil, err
} else {
vh.Routes = append(vh.Routes, rs...)
} }
vh.Routes = append(vh.Routes, rs...)
// if we're the proxy, add all the policy routes // if we're the proxy, add all the policy routes
if config.IsProxy(options.Services) { if config.IsProxy(options.Services) {
if rs, err := srv.buildPolicyRoutes(options, domain); err != nil { rs, err := srv.buildPolicyRoutes(options, domain)
if err != nil {
return nil, err return nil, err
} else {
vh.Routes = append(vh.Routes, rs...)
} }
vh.Routes = append(vh.Routes, rs...)
} }
if len(vh.Routes) > 0 { if len(vh.Routes) > 0 {
@ -209,15 +236,15 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(
} }
} }
if rs, err := srv.buildPomeriumHTTPRoutes(options, "*"); err != nil { rs, err := srv.buildPomeriumHTTPRoutes(options, "*")
if err != nil {
return nil, err return nil, err
} else { }
virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{ virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{
Name: "catch-all", Name: "catch-all",
Domains: []string{"*"}, Domains: []string{"*"},
Routes: rs, Routes: rs,
}) })
}
var grpcClientTimeout *durationpb.Duration var grpcClientTimeout *durationpb.Duration
if options.GRPCClientTimeout != 0 { if options.GRPCClientTimeout != 0 {
@ -235,7 +262,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter(
Timeout: grpcClientTimeout, Timeout: grpcClientTimeout,
TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{
EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{
ClusterName: options.GetAuthorizeURL().Host, ClusterName: authorizeURL.Host,
}, },
}, },
}, },
@ -501,31 +528,55 @@ func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string)
} }
} }
func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDomain string) []string { func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDomain string) ([]string, error) {
allDomains := getAllRouteableDomains(options, addr) allDomains, err := getAllRouteableDomains(options, addr)
if err != nil {
return nil, err
}
var filtered []string var filtered []string
for _, domain := range allDomains { for _, domain := range allDomains {
if urlutil.StripPort(domain) == tlsDomain { if urlutil.StripPort(domain) == tlsDomain {
filtered = append(filtered, domain) filtered = append(filtered, domain)
} }
} }
return filtered return filtered, nil
} }
func getAllRouteableDomains(options *config.Options, addr string) []string { func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
authenticateURL, err := options.GetAuthenticateURL()
if err != nil {
return nil, err
}
authorizeURL, err := options.GetAuthorizeURL()
if err != nil {
return nil, err
}
dataBrokerURL, err := options.GetDataBrokerURL()
if err != nil {
return nil, err
}
forwardAuthURL, err := options.GetForwardAuthURL()
if err != nil {
return nil, err
}
lookup := map[string]struct{}{} lookup := map[string]struct{}{}
if config.IsAuthenticate(options.Services) && addr == options.Addr { if config.IsAuthenticate(options.Services) && addr == options.Addr {
for _, h := range urlutil.GetDomainsForURL(options.GetAuthenticateURL()) { for _, h := range urlutil.GetDomainsForURL(authenticateURL) {
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
if config.IsAuthorize(options.Services) && addr == options.GRPCAddr { if config.IsAuthorize(options.Services) && addr == options.GRPCAddr {
for _, h := range urlutil.GetDomainsForURL(options.GetAuthorizeURL()) { for _, h := range urlutil.GetDomainsForURL(authorizeURL) {
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
if config.IsDataBroker(options.Services) && addr == options.GRPCAddr { if config.IsDataBroker(options.Services) && addr == options.GRPCAddr {
for _, h := range urlutil.GetDomainsForURL(options.GetDataBrokerURL()) { for _, h := range urlutil.GetDomainsForURL(dataBrokerURL) {
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
@ -536,7 +587,7 @@ func getAllRouteableDomains(options *config.Options, addr string) []string {
} }
} }
if options.ForwardAuthURL != nil { if options.ForwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(options.GetForwardAuthURL()) { for _, h := range urlutil.GetDomainsForURL(forwardAuthURL) {
lookup[h] = struct{}{} lookup[h] = struct{}{}
} }
} }
@ -548,12 +599,17 @@ func getAllRouteableDomains(options *config.Options, addr string) []string {
} }
sort.Strings(domains) sort.Strings(domains)
return domains return domains, nil
} }
func getAllTLSDomains(options *config.Options, addr string) []string { func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
allDomains, err := getAllRouteableDomains(options, addr)
if err != nil {
return nil, err
}
lookup := map[string]struct{}{} lookup := map[string]struct{}{}
for _, hp := range getAllRouteableDomains(options, addr) { for _, hp := range allDomains {
if d, _, err := net.SplitHostPort(hp); err == nil { if d, _, err := net.SplitHostPort(hp); err == nil {
lookup[d] = struct{}{} lookup[d] = struct{}{}
} else { } else {
@ -567,7 +623,7 @@ func getAllTLSDomains(options *config.Options, addr string) []string {
} }
sort.Strings(domains) sort.Strings(domains)
return domains return domains, nil
} }
func hostMatchesDomain(u *url.URL, host string) bool { func hostMatchesDomain(u *url.URL, host string) bool {

View file

@ -446,7 +446,8 @@ func Test_getAllDomains(t *testing.T) {
} }
t.Run("routable", func(t *testing.T) { t.Run("routable", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual := getAllRouteableDomains(options, "127.0.0.1:9000") actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
"a.example.com:80", "a.example.com:80",
@ -460,7 +461,8 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual := getAllRouteableDomains(options, "127.0.0.1:9001") actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com:9001", "authorize.example.com:9001",
"cache.example.com:9001", "cache.example.com:9001",
@ -470,7 +472,8 @@ func Test_getAllDomains(t *testing.T) {
}) })
t.Run("tls", func(t *testing.T) { t.Run("tls", func(t *testing.T) {
t.Run("http", func(t *testing.T) { t.Run("http", func(t *testing.T) {
actual := getAllTLSDomains(options, "127.0.0.1:9000") actual, err := getAllTLSDomains(options, "127.0.0.1:9000")
require.NoError(t, err)
expect := []string{ expect := []string{
"a.example.com", "a.example.com",
"authenticate.example.com", "authenticate.example.com",
@ -480,7 +483,8 @@ func Test_getAllDomains(t *testing.T) {
assert.Equal(t, expect, actual) assert.Equal(t, expect, actual)
}) })
t.Run("grpc", func(t *testing.T) { t.Run("grpc", func(t *testing.T) {
actual := getAllTLSDomains(options, "127.0.0.1:9001") actual, err := getAllTLSDomains(options, "127.0.0.1:9001")
require.NoError(t, err)
expect := []string{ expect := []string{
"authorize.example.com", "authorize.example.com",
"cache.example.com", "cache.example.com",

View file

@ -50,96 +50,105 @@ func (srv *Server) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
// enable ext_authz // enable ext_authz
if r, err := srv.buildControlPlanePathRoute("/.pomerium/jwt", true); err != nil { r, err := srv.buildControlPlanePathRoute("/.pomerium/jwt", true)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
routes = append(routes, r)
// disable ext_authz and passthrough to proxy handlers // disable ext_authz and passthrough to proxy handlers
if r, err := srv.buildControlPlanePathRoute("/ping", false); err != nil { r, err = srv.buildControlPlanePathRoute("/ping", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathRoute("/healthz", false); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathRoute("/healthz", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathRoute("/.pomerium/admin", true); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathRoute("/.pomerium/admin", true)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/admin/", true); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePrefixRoute("/.pomerium/admin/", true)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathRoute("/.pomerium", false); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathRoute("/.pomerium", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/", false); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePrefixRoute("/.pomerium/", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathRoute("/.well-known/pomerium", false); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathRoute("/.well-known/pomerium", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePrefixRoute("/.well-known/pomerium/", false); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePrefixRoute("/.well-known/pomerium/", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
routes = append(routes, r)
// per #837, only add robots.txt if there are no unauthenticated routes // per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, mustParseURL("https://"+domain+"/robots.txt")) { if !hasPublicPolicyMatchingURL(options, mustParseURL("https://"+domain+"/robots.txt")) {
if r, err := srv.buildControlPlanePathRoute("/robots.txt", false); err != nil { r, err := srv.buildControlPlanePathRoute("/robots.txt", false)
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
routes = append(routes, r)
} }
// if we're handling authentication, add the oauth2 callback url // if we're handling authentication, add the oauth2 callback url
if config.IsAuthenticate(options.Services) && hostMatchesDomain(options.GetAuthenticateURL(), domain) { authenticateURL, err := options.GetAuthenticateURL()
if r, err := srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false); err != nil { if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
r, err := srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false)
if err != nil {
return nil, err
}
routes = append(routes, r)
} }
// if we're the proxy and this is the forward-auth url // if we're the proxy and this is the forward-auth url
if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(options.GetForwardAuthURL(), domain) { forwardAuthURL, err := options.GetForwardAuthURL()
if err != nil {
return nil, err
}
if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(forwardAuthURL, domain) {
// disable ext_authz and pass request to proxy handlers that enable authN flow // disable ext_authz and pass request to proxy handlers that enable authN flow
if r, err := srv.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { r, err := srv.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI})
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI}); err != nil { routes = append(routes, r)
r, err = srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI})
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
routes = append(routes, r)
// otherwise, enforce ext_authz; pass all other requests through to an upstream // otherwise, enforce ext_authz; pass all other requests through to an upstream
// handler that will simply respond with http status 200 / OK indicating that // handler that will simply respond with http status 200 / OK indicating that
// the fronting forward-auth proxy can continue. // the fronting forward-auth proxy can continue.
if r, err := srv.buildControlPlaneProtectedPrefixRoute("/"); err != nil { r, err = srv.buildControlPlaneProtectedPrefixRoute("/")
if err != nil {
return nil, err return nil, err
} else {
routes = append(routes, r)
} }
routes = append(routes, r)
} }
return routes, nil return routes, nil
} }