diff --git a/.golangci.yml b/.golangci.yml index eca87506f..3e3a3e0b3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -71,7 +71,6 @@ linters: - stylecheck - typecheck - unconvert - - unparam - unused - varcheck # - asciicheck diff --git a/authorize/check_response.go b/authorize/check_response.go index 64a38ef92..142279edb 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -49,7 +49,7 @@ func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v2.C func (a *Authorize) deniedResponse( in *envoy_service_auth_v2.CheckRequest, code int32, reason string, headers map[string]string, -) *envoy_service_auth_v2.CheckResponse { +) (*envoy_service_auth_v2.CheckResponse, error) { returnHTMLError := true inHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders() if inHeaders != nil { @@ -59,15 +59,19 @@ func (a *Authorize) deniedResponse( if returnHTMLError { return a.htmlDeniedResponse(in, code, reason, headers) } - return a.plainTextDeniedResponse(code, reason, headers) + return a.plainTextDeniedResponse(code, reason, headers), nil } func (a *Authorize) htmlDeniedResponse( in *envoy_service_auth_v2.CheckRequest, code int32, reason string, headers map[string]string, -) *envoy_service_auth_v2.CheckResponse { +) (*envoy_service_auth_v2.CheckResponse, error) { opts := a.currentOptions.Load() - debugEndpoint := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/"}) + authenticateURL, err := opts.GetAuthenticateURL() + if err != nil { + return nil, err + } + debugEndpoint := authenticateURL.ResolveReference(&url.URL{Path: "/.pomerium/"}) // create go-style http request r := getHTTPRequestFromCheckRequest(in) @@ -97,7 +101,7 @@ func (a *Authorize) htmlDeniedResponse( } var buf bytes.Buffer - err := a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{ + err = a.templates.ExecuteTemplate(&buf, "error.html", map[string]interface{}{ "Status": code, "StatusText": reason, "CanDebug": code/100 == 4, @@ -127,7 +131,7 @@ func (a *Authorize) htmlDeniedResponse( Body: buf.String(), }, }, - } + }, nil } func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers map[string]string) *envoy_service_auth_v2.CheckResponse { @@ -152,10 +156,16 @@ func (a *Authorize) plainTextDeniedResponse(code int32, reason string, headers m } } -func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) *envoy_service_auth_v2.CheckResponse { +func (a *Authorize) redirectResponse(in *envoy_service_auth_v2.CheckRequest) (*envoy_service_auth_v2.CheckResponse, error) { opts := a.currentOptions.Load() + authenticateURL, err := opts.GetAuthenticateURL() + if err != nil { + return nil, err + } - signinURL := opts.GetAuthenticateURL().ResolveReference(&url.URL{Path: "/.pomerium/sign_in"}) + signinURL := authenticateURL.ResolveReference(&url.URL{ + Path: "/.pomerium/sign_in", + }) q := signinURL.Query() // always assume https scheme diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index be74e741b..541f0f5be 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -280,7 +280,8 @@ func TestAuthorize_deniedResponse(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - got := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers) + got, err := a.deniedResponse(tc.in, tc.code, tc.reason, tc.headers) + require.NoError(t, err) assert.Equal(t, tc.want.Status.Code, got.Status.Code) assert.Equal(t, tc.want.Status.Message, got.Status.Message) assert.Equal(t, tc.want.GetDeniedResponse().GetHeaders(), got.GetDeniedResponse().GetHeaders()) diff --git a/authorize/grpc.go b/authorize/grpc.go index fcd033ac5..f97642f56 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -78,11 +78,11 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v2.CheckRe return a.okResponse(reply), nil case reply.Status == http.StatusUnauthorized: if isForwardAuth && hreq.URL.Path == "/verify" { - return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil), nil + return a.deniedResponse(in, http.StatusUnauthorized, "Unauthenticated", nil) } - return a.redirectResponse(in), nil + return a.redirectResponse(in) } - return a.deniedResponse(in, int32(reply.Status), reply.Message, nil), nil + return a.deniedResponse(in, int32(reply.Status), reply.Message, nil) } func (a *Authorize) forceSync(ctx context.Context, ss *sessions.State) error { @@ -212,9 +212,14 @@ func (a *Authorize) isForwardAuth(req *envoy_service_auth_v2.CheckRequest) bool return false } + forwardAuthURL, err := opts.GetForwardAuthURL() + if err != nil { + return false + } + checkURL := getCheckRequestURL(req) - return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(opts.GetForwardAuthURL().Host) + return urlutil.StripPort(checkURL.Host) == urlutil.StripPort(forwardAuthURL.Host) } func (a *Authorize) getEvaluatorRequestFromCheckRequest(in *envoy_service_auth_v2.CheckRequest, sessionState *sessions.State) *evaluator.Request { diff --git a/config/options.go b/config/options.go index e8e341e91..5bde3c3b7 100644 --- a/config/options.go +++ b/config/options.go @@ -710,45 +710,46 @@ func (o *Options) Validate() error { } // GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1. -func (o *Options) GetAuthenticateURL() *url.URL { +func (o *Options) GetAuthenticateURL() (*url.URL, error) { if o != nil && o.AuthenticateURL != nil { - return o.AuthenticateURL + return o.AuthenticateURL, nil } - u, _ := url.Parse("https://127.0.0.1") - return u + return url.Parse("https://127.0.0.1") } // GetAuthorizeURL returns the AuthorizeURL in the options or 127.0.0.1:5443. -func (o *Options) GetAuthorizeURL() *url.URL { +func (o *Options) GetAuthorizeURL() (*url.URL, error) { if o != nil && o.AuthorizeURL != nil { - return o.AuthorizeURL + return o.AuthorizeURL, nil } - u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) - return u + return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) } // GetDataBrokerURL returns the DataBrokerURL in the options or 127.0.0.1:5443. -func (o *Options) GetDataBrokerURL() *url.URL { +func (o *Options) GetDataBrokerURL() (*url.URL, error) { if o != nil && o.DataBrokerURL != nil { - return o.DataBrokerURL + return o.DataBrokerURL, nil } - u, _ := url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) - return u + return url.Parse("http://127.0.0.1" + DefaultAlternativeAddr) } // GetForwardAuthURL returns the ForwardAuthURL in the options or 127.0.0.1. -func (o *Options) GetForwardAuthURL() *url.URL { +func (o *Options) GetForwardAuthURL() (*url.URL, error) { if o != nil && o.ForwardAuthURL != nil { - return o.ForwardAuthURL + return o.ForwardAuthURL, nil } - u, _ := url.Parse("https://127.0.0.1") - return u + return url.Parse("https://127.0.0.1") } // GetOauthOptions gets the oauth.Options for the given config options. -func (o *Options) GetOauthOptions() oauth.Options { - redirectURL := o.GetAuthenticateURL() - redirectURL.Path = o.AuthenticateCallbackPath +func (o *Options) GetOauthOptions() (oauth.Options, error) { + redirectURL, err := o.GetAuthenticateURL() + if err != nil { + return oauth.Options{}, err + } + redirectURL = redirectURL.ResolveReference(&url.URL{ + Path: o.AuthenticateCallbackPath, + }) return oauth.Options{ RedirectURL: redirectURL, ProviderName: o.Provider, @@ -757,7 +758,7 @@ func (o *Options) GetOauthOptions() oauth.Options { ClientSecret: o.ClientSecret, Scopes: o.Scopes, ServiceAccount: o.ServiceAccount, - } + }, nil } // GetAllPolicies gets all the policies in the options. diff --git a/config/options_test.go b/config/options_test.go index e3dd0d7d1..b932de6b9 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -14,6 +14,7 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/spf13/viper" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}) @@ -498,7 +499,7 @@ func TestOptions_DefaultURL(t *testing.T) { } tests := []struct { name string - f func() *url.URL + f func() (*url.URL, error) expectedURLStr string }{ {"default authenticate url", defaultOptions.GetAuthenticateURL, "https://127.0.0.1"}, @@ -515,7 +516,9 @@ func TestOptions_DefaultURL(t *testing.T) { tc := tc t.Run(tc.name, func(t *testing.T) { t.Parallel() - assert.Equal(t, tc.expectedURLStr, tc.f().String()) + u, err := tc.f() + require.NoError(t, err) + assert.Equal(t, tc.expectedURLStr, u.String()) }) } } @@ -530,7 +533,9 @@ func mustParseURL(str string) *url.URL { func TestOptions_GetOauthOptions(t *testing.T) { opts := &Options{AuthenticateURL: mustParseURL("https://authenticate.example.com")} + oauthOptions, err := opts.GetOauthOptions() + require.NoError(t, err) // Test that oauth redirect url hostname must point to authenticate url hostname. - assert.Equal(t, opts.AuthenticateURL.Hostname(), opts.GetOauthOptions().RedirectURL.Hostname()) + assert.Equal(t, opts.AuthenticateURL.Hostname(), oauthOptions.RedirectURL.Hostname()) } diff --git a/databroker/cache.go b/databroker/cache.go index 68aa36599..3283141a2 100644 --- a/databroker/cache.go +++ b/databroker/cache.go @@ -81,13 +81,17 @@ func New(cfg *config.Config) (*DataBroker, error) { } dataBrokerServer := newDataBrokerServer(cfg) + dataBrokerURL, err := cfg.Options.GetDataBrokerURL() + if err != nil { + return nil, err + } c := &DataBroker{ dataBrokerServer: dataBrokerServer, localListener: localListener, localGRPCServer: localGRPCServer, localGRPCConnection: localGRPCConnection, - deprecatedCacheClusterDomain: cfg.Options.GetDataBrokerURL().Hostname(), + deprecatedCacheClusterDomain: dataBrokerURL.Hostname(), dataBrokerStorageType: cfg.Options.DataBrokerStorageType, } c.Register(c.localGRPCServer) @@ -138,7 +142,12 @@ func (c *DataBroker) update(cfg *config.Config) error { return fmt.Errorf("databroker: bad option: %w", err) } - authenticator, err := identity.NewAuthenticator(cfg.Options.GetOauthOptions()) + oauthOptions, err := cfg.Options.GetOauthOptions() + if err != nil { + return fmt.Errorf("databroker: invalid oauth options: %w", err) + } + + authenticator, err := identity.NewAuthenticator(oauthOptions) if err != nil { return fmt.Errorf("databroker: failed to create authenticator: %w", err) } diff --git a/internal/cmd/pomerium/pomerium.go b/internal/cmd/pomerium/pomerium.go index 228b2a8ac..b6ae40e68 100644 --- a/internal/cmd/pomerium/pomerium.go +++ b/internal/cmd/pomerium/pomerium.go @@ -158,9 +158,15 @@ func setupAuthenticate(src config.Source, controlPlane *controlplane.Server) err if err != nil { return fmt.Errorf("error creating authenticate service: %w", err) } + + authenticateURL, err := src.GetConfig().Options.GetAuthenticateURL() + if err != nil { + return fmt.Errorf("error getting authenticate URL: %w", err) + } + src.OnConfigChange(svc.OnConfigChange) svc.OnConfigChange(src.GetConfig()) - host := urlutil.StripPort(src.GetConfig().Options.GetAuthenticateURL().Host) + host := urlutil.StripPort(authenticateURL.Host) sr := controlPlane.HTTPRouter.Host(host).Subrouter() svc.Mount(sr) log.Info().Str("host", host).Msg("enabled authenticate service") diff --git a/internal/controlplane/xds_clusters.go b/internal/controlplane/xds_clusters.go index d13477789..b4fd76564 100644 --- a/internal/controlplane/xds_clusters.go +++ b/internal/controlplane/xds_clusters.go @@ -50,9 +50,9 @@ func (srv *Server) buildClusters(options *config.Options) ([]*envoy_config_clust Scheme: "http", Host: srv.HTTPListener.Addr().String(), } - authzURL := &url.URL{ - Scheme: options.GetAuthorizeURL().Scheme, - Host: options.GetAuthorizeURL().Host, + authzURL, err := options.GetAuthorizeURL() + if err != nil { + return nil, err } controlGRPC, err := srv.buildInternalCluster(options, "pomerium-control-plane-grpc", grpcURL, true) @@ -132,22 +132,22 @@ func (srv *Server) buildPolicyCluster(options *config.Options, policy *config.Po func (srv *Server) buildInternalEndpoints(options *config.Options, dst *url.URL) ([]Endpoint, error) { var endpoints []Endpoint - if ts, err := srv.buildInternalTransportSocket(options, dst); err != nil { + ts, err := srv.buildInternalTransportSocket(options, dst) + if err != nil { return nil, err - } else { - endpoints = append(endpoints, NewEndpoint(dst, ts)) } + endpoints = append(endpoints, NewEndpoint(dst, ts)) return endpoints, nil } func (srv *Server) buildPolicyEndpoints(policy *config.Policy) ([]Endpoint, error) { var endpoints []Endpoint for _, dst := range policy.Destinations { - if ts, err := srv.buildPolicyTransportSocket(policy, dst); err != nil { + ts, err := srv.buildPolicyTransportSocket(policy, dst) + if err != nil { return nil, err - } else { - endpoints = append(endpoints, NewEndpoint(dst, ts)) } + endpoints = append(endpoints, NewEndpoint(dst, ts)) } return endpoints, nil } @@ -246,7 +246,9 @@ func (srv *Server) buildPolicyTransportSocket(policy *config.Policy, dst *url.UR }, nil } -func (srv *Server) buildPolicyValidationContext(policy *config.Policy, dst *url.URL) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { +func (srv *Server) buildPolicyValidationContext( + policy *config.Policy, dst *url.URL, +) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { if dst == nil { return nil, nil } diff --git a/internal/controlplane/xds_listeners.go b/internal/controlplane/xds_listeners.go index 02bbd3378..c1306c48d 100644 --- a/internal/controlplane/xds_listeners.go +++ b/internal/controlplane/xds_listeners.go @@ -43,19 +43,19 @@ func (srv *Server) buildListeners(cfg *config.Config) ([]*envoy_config_listener_ var listeners []*envoy_config_listener_v3.Listener if config.IsAuthenticate(cfg.Options.Services) || config.IsProxy(cfg.Options.Services) { - if li, err := srv.buildMainListener(cfg); err != nil { + li, err := srv.buildMainListener(cfg) + if err != nil { return nil, err - } else { - listeners = append(listeners, li) } + listeners = append(listeners, li) } if config.IsAuthorize(cfg.Options.Services) || config.IsDataBroker(cfg.Options.Services) { - if li, err := srv.buildGRPCListener(cfg); err != nil { + li, err := srv.buildGRPCListener(cfg) + if err != nil { return nil, err - } else { - listeners = append(listeners, li) } + listeners = append(listeners, li) } return listeners, nil @@ -74,8 +74,12 @@ func (srv *Server) buildMainListener(cfg *config.Config) (*envoy_config_listener } if cfg.Options.InsecureServer { - filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, - getAllRouteableDomains(cfg.Options, cfg.Options.Addr), "") + allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr) + if err != nil { + return nil, err + } + + filter, err := srv.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, "") if err != nil { return nil, err } @@ -143,24 +147,37 @@ func (srv *Server) buildFilterChains( options *config.Options, addr string, callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error), ) ([]*envoy_config_listener_v3.FilterChain, error) { - allDomains := getAllRouteableDomains(options, addr) - tlsDomains := getAllTLSDomains(options, addr) + allDomains, err := getAllRouteableDomains(options, addr) + if err != nil { + return nil, err + } + + tlsDomains, err := getAllTLSDomains(options, addr) + if err != nil { + return nil, err + } + var chains []*envoy_config_listener_v3.FilterChain for _, domain := range tlsDomains { - // first we match on SNI - if chain, err := callback(domain, getRouteableDomainsForTLSDomain(options, addr, domain)); err != nil { + routeableDomains, err := getRouteableDomainsForTLSDomain(options, addr, domain) + if err != nil { return nil, err - } else { - chains = append(chains, chain) } + + // first we match on SNI + chain, err := callback(domain, routeableDomains) + if err != nil { + return nil, err + } + chains = append(chains, chain) } // if there are no SNI matches we match on HTTP host - if chain, err := callback("*", allDomains); err != nil { + chain, err := callback("*", allDomains) + if err != nil { return nil, err - } else { - chains = append(chains, chain) } + chains = append(chains, chain) return chains, nil } @@ -169,6 +186,16 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter( domains []string, tlsDomain string, ) (*envoy_config_listener_v3.Filter, error) { + authorizeURL, err := options.GetAuthorizeURL() + if err != nil { + return nil, err + } + + dataBrokerURL, err := options.GetDataBrokerURL() + if err != nil { + return nil, err + } + var virtualHosts []*envoy_config_route_v3.VirtualHost for _, domain := range domains { vh := &envoy_config_route_v3.VirtualHost{ @@ -178,30 +205,30 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter( if options.Addr == options.GRPCAddr { // if this is a gRPC service domain and we're supposed to handle that, add those routes - if (config.IsAuthorize(options.Services) && hostMatchesDomain(options.GetAuthorizeURL(), domain)) || - (config.IsDataBroker(options.Services) && hostMatchesDomain(options.GetDataBrokerURL(), domain)) { - if rs, err := srv.buildGRPCRoutes(); err != nil { + if (config.IsAuthorize(options.Services) && hostMatchesDomain(authorizeURL, domain)) || + (config.IsDataBroker(options.Services) && hostMatchesDomain(dataBrokerURL, domain)) { + rs, err := srv.buildGRPCRoutes() + if err != nil { return nil, err - } else { - vh.Routes = append(vh.Routes, rs...) } + vh.Routes = append(vh.Routes, rs...) } } // these routes match /.pomerium/... and similar paths - if rs, err := srv.buildPomeriumHTTPRoutes(options, domain); err != nil { + rs, err := srv.buildPomeriumHTTPRoutes(options, domain) + if err != nil { return nil, err - } else { - vh.Routes = append(vh.Routes, rs...) } + vh.Routes = append(vh.Routes, rs...) // if we're the proxy, add all the policy routes if config.IsProxy(options.Services) { - if rs, err := srv.buildPolicyRoutes(options, domain); err != nil { + rs, err := srv.buildPolicyRoutes(options, domain) + if err != nil { return nil, err - } else { - vh.Routes = append(vh.Routes, rs...) } + vh.Routes = append(vh.Routes, rs...) } if len(vh.Routes) > 0 { @@ -209,15 +236,15 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter( } } - if rs, err := srv.buildPomeriumHTTPRoutes(options, "*"); err != nil { + rs, err := srv.buildPomeriumHTTPRoutes(options, "*") + if err != nil { return nil, err - } else { - virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{ - Name: "catch-all", - Domains: []string{"*"}, - Routes: rs, - }) } + virtualHosts = append(virtualHosts, &envoy_config_route_v3.VirtualHost{ + Name: "catch-all", + Domains: []string{"*"}, + Routes: rs, + }) var grpcClientTimeout *durationpb.Duration if options.GRPCClientTimeout != 0 { @@ -235,7 +262,7 @@ func (srv *Server) buildMainHTTPConnectionManagerFilter( Timeout: grpcClientTimeout, TargetSpecifier: &envoy_config_core_v3.GrpcService_EnvoyGrpc_{ EnvoyGrpc: &envoy_config_core_v3.GrpcService_EnvoyGrpc{ - ClusterName: options.GetAuthorizeURL().Host, + ClusterName: authorizeURL.Host, }, }, }, @@ -501,31 +528,55 @@ func (srv *Server) buildDownstreamTLSContext(cfg *config.Config, domain string) } } -func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDomain string) []string { - allDomains := getAllRouteableDomains(options, addr) +func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDomain string) ([]string, error) { + allDomains, err := getAllRouteableDomains(options, addr) + if err != nil { + return nil, err + } + var filtered []string for _, domain := range allDomains { if urlutil.StripPort(domain) == tlsDomain { filtered = append(filtered, domain) } } - return filtered + return filtered, nil } -func getAllRouteableDomains(options *config.Options, addr string) []string { +func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) { + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + + authorizeURL, err := options.GetAuthorizeURL() + if err != nil { + return nil, err + } + + dataBrokerURL, err := options.GetDataBrokerURL() + if err != nil { + return nil, err + } + + forwardAuthURL, err := options.GetForwardAuthURL() + if err != nil { + return nil, err + } + lookup := map[string]struct{}{} if config.IsAuthenticate(options.Services) && addr == options.Addr { - for _, h := range urlutil.GetDomainsForURL(options.GetAuthenticateURL()) { + for _, h := range urlutil.GetDomainsForURL(authenticateURL) { lookup[h] = struct{}{} } } if config.IsAuthorize(options.Services) && addr == options.GRPCAddr { - for _, h := range urlutil.GetDomainsForURL(options.GetAuthorizeURL()) { + for _, h := range urlutil.GetDomainsForURL(authorizeURL) { lookup[h] = struct{}{} } } if config.IsDataBroker(options.Services) && addr == options.GRPCAddr { - for _, h := range urlutil.GetDomainsForURL(options.GetDataBrokerURL()) { + for _, h := range urlutil.GetDomainsForURL(dataBrokerURL) { lookup[h] = struct{}{} } } @@ -536,7 +587,7 @@ func getAllRouteableDomains(options *config.Options, addr string) []string { } } if options.ForwardAuthURL != nil { - for _, h := range urlutil.GetDomainsForURL(options.GetForwardAuthURL()) { + for _, h := range urlutil.GetDomainsForURL(forwardAuthURL) { lookup[h] = struct{}{} } } @@ -548,12 +599,17 @@ func getAllRouteableDomains(options *config.Options, addr string) []string { } sort.Strings(domains) - return domains + return domains, nil } -func getAllTLSDomains(options *config.Options, addr string) []string { +func getAllTLSDomains(options *config.Options, addr string) ([]string, error) { + allDomains, err := getAllRouteableDomains(options, addr) + if err != nil { + return nil, err + } + lookup := map[string]struct{}{} - for _, hp := range getAllRouteableDomains(options, addr) { + for _, hp := range allDomains { if d, _, err := net.SplitHostPort(hp); err == nil { lookup[d] = struct{}{} } else { @@ -567,7 +623,7 @@ func getAllTLSDomains(options *config.Options, addr string) []string { } sort.Strings(domains) - return domains + return domains, nil } func hostMatchesDomain(u *url.URL, host string) bool { diff --git a/internal/controlplane/xds_listeners_test.go b/internal/controlplane/xds_listeners_test.go index 5131c79a5..b6c86eeeb 100644 --- a/internal/controlplane/xds_listeners_test.go +++ b/internal/controlplane/xds_listeners_test.go @@ -446,7 +446,8 @@ func Test_getAllDomains(t *testing.T) { } t.Run("routable", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual := getAllRouteableDomains(options, "127.0.0.1:9000") + actual, err := getAllRouteableDomains(options, "127.0.0.1:9000") + require.NoError(t, err) expect := []string{ "a.example.com", "a.example.com:80", @@ -460,7 +461,8 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual := getAllRouteableDomains(options, "127.0.0.1:9001") + actual, err := getAllRouteableDomains(options, "127.0.0.1:9001") + require.NoError(t, err) expect := []string{ "authorize.example.com:9001", "cache.example.com:9001", @@ -470,7 +472,8 @@ func Test_getAllDomains(t *testing.T) { }) t.Run("tls", func(t *testing.T) { t.Run("http", func(t *testing.T) { - actual := getAllTLSDomains(options, "127.0.0.1:9000") + actual, err := getAllTLSDomains(options, "127.0.0.1:9000") + require.NoError(t, err) expect := []string{ "a.example.com", "authenticate.example.com", @@ -480,7 +483,8 @@ func Test_getAllDomains(t *testing.T) { assert.Equal(t, expect, actual) }) t.Run("grpc", func(t *testing.T) { - actual := getAllTLSDomains(options, "127.0.0.1:9001") + actual, err := getAllTLSDomains(options, "127.0.0.1:9001") + require.NoError(t, err) expect := []string{ "authorize.example.com", "cache.example.com", diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index e24842ece..a3dcb4e84 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -50,96 +50,105 @@ func (srv *Server) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) { func (srv *Server) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) { var routes []*envoy_config_route_v3.Route // enable ext_authz - if r, err := srv.buildControlPlanePathRoute("/.pomerium/jwt", true); err != nil { + r, err := srv.buildControlPlanePathRoute("/.pomerium/jwt", true) + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) // disable ext_authz and passthrough to proxy handlers - if r, err := srv.buildControlPlanePathRoute("/ping", false); err != nil { + r, err = srv.buildControlPlanePathRoute("/ping", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathRoute("/healthz", false); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathRoute("/healthz", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathRoute("/.pomerium/admin", true); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathRoute("/.pomerium/admin", true) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/admin/", true); err != nil { + routes = append(routes, r) + + r, err = srv.buildControlPlanePrefixRoute("/.pomerium/admin/", true) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathRoute("/.pomerium", false); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathRoute("/.pomerium", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePrefixRoute("/.pomerium/", false); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePrefixRoute("/.pomerium/", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathRoute("/.well-known/pomerium", false); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathRoute("/.well-known/pomerium", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePrefixRoute("/.well-known/pomerium/", false); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePrefixRoute("/.well-known/pomerium/", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) // per #837, only add robots.txt if there are no unauthenticated routes if !hasPublicPolicyMatchingURL(options, mustParseURL("https://"+domain+"/robots.txt")) { - if r, err := srv.buildControlPlanePathRoute("/robots.txt", false); err != nil { + r, err := srv.buildControlPlanePathRoute("/robots.txt", false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) } // if we're handling authentication, add the oauth2 callback url - if config.IsAuthenticate(options.Services) && hostMatchesDomain(options.GetAuthenticateURL(), domain) { - if r, err := srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false); err != nil { + authenticateURL, err := options.GetAuthenticateURL() + if err != nil { + return nil, err + } + if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) { + r, err := srv.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false) + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) } // if we're the proxy and this is the forward-auth url - if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(options.GetForwardAuthURL(), domain) { + forwardAuthURL, err := options.GetForwardAuthURL() + if err != nil { + return nil, err + } + if config.IsProxy(options.Services) && options.ForwardAuthURL != nil && hostMatchesDomain(forwardAuthURL, domain) { // disable ext_authz and pass request to proxy handlers that enable authN flow - if r, err := srv.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { + r, err := srv.buildControlPlanePathAndQueryRoute("/verify", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI, urlutil.QuerySessionEncrypted, urlutil.QueryRedirectURI}) + if err != nil { return nil, err - } else { - routes = append(routes, r) } - if r, err := srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI}); err != nil { + routes = append(routes, r) + r, err = srv.buildControlPlanePathAndQueryRoute("/", []string{urlutil.QueryForwardAuthURI}) + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) // otherwise, enforce ext_authz; pass all other requests through to an upstream // handler that will simply respond with http status 200 / OK indicating that // the fronting forward-auth proxy can continue. - if r, err := srv.buildControlPlaneProtectedPrefixRoute("/"); err != nil { + r, err = srv.buildControlPlaneProtectedPrefixRoute("/") + if err != nil { return nil, err - } else { - routes = append(routes, r) } + routes = append(routes, r) } return routes, nil }