diff --git a/internal/controlplane/xds_routes.go b/internal/controlplane/xds_routes.go index 99429217c..8d06c823c 100644 --- a/internal/controlplane/xds_routes.go +++ b/internal/controlplane/xds_routes.go @@ -100,60 +100,19 @@ func buildControlPlanePrefixRoute(prefix string) *envoy_config_route_v3.Route { func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_route_v3.Route { var routes []*envoy_config_route_v3.Route - responseHeadersToAdd := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(options.Headers)) - for k, v := range options.Headers { - responseHeadersToAdd = append(responseHeadersToAdd, mkEnvoyHeader(k, v)) - } + responseHeadersToAdd := toEnvoyHeaders(options.Headers) for i, policy := range options.Policies { if policy.Source.Host != domain { continue } - match := &envoy_config_route_v3.RouteMatch{} - switch { - case policy.Regex != "": - match.PathSpecifier = &envoy_config_route_v3.RouteMatch_SafeRegex{ - SafeRegex: &envoy_type_matcher_v3.RegexMatcher{ - EngineType: &envoy_type_matcher_v3.RegexMatcher_GoogleRe2{ - GoogleRe2: &envoy_type_matcher_v3.RegexMatcher_GoogleRE2{}, - }, - Regex: policy.Regex, - }, - } - case policy.Path != "": - match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Path{Path: policy.Path} - case policy.Prefix != "": - match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Prefix{Prefix: policy.Prefix} - default: - match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Prefix{Prefix: "/"} - } + match := mkRouteMatch(&policy) clusterName := getPolicyName(&policy) - - requestHeadersToAdd := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(policy.SetRequestHeaders)) - for k, v := range policy.SetRequestHeaders { - requestHeadersToAdd = append(requestHeadersToAdd, mkEnvoyHeader(k, v)) - } - - requestHeadersToRemove := policy.RemoveRequestHeaders - if !policy.PassIdentityHeaders { - requestHeadersToRemove = append(requestHeadersToRemove, httputil.HeaderPomeriumJWTAssertion) - for _, claim := range options.JWTClaimsHeaders { - requestHeadersToRemove = append(requestHeadersToRemove, httputil.PomeriumJWTHeaderName(claim)) - } - } - - var routeTimeout *durationpb.Duration - if policy.AllowWebsockets { - // disable the route timeout for websocket support - routeTimeout = ptypes.DurationProto(0) - } else { - if policy.UpstreamTimeout != 0 { - routeTimeout = ptypes.DurationProto(policy.UpstreamTimeout) - } else { - routeTimeout = ptypes.DurationProto(options.DefaultUpstreamTimeout) - } - } + requestHeadersToAdd := toEnvoyHeaders(policy.SetRequestHeaders) + requestHeadersToRemove := getRequestHeadersToRemove(options, &policy) + routeTimeout := getRouteTimeout(options, &policy) + prefixRewrite := getPrefixRewrite(&policy) routes = append(routes, &envoy_config_route_v3.Route{ Name: fmt.Sprintf("policy-%d", i), @@ -188,7 +147,8 @@ func buildPolicyRoutes(options *config.Options, domain string) []*envoy_config_r HostRewriteSpecifier: &envoy_config_route_v3.RouteAction_AutoHostRewrite{ AutoHostRewrite: &wrappers.BoolValue{Value: !policy.PreserveHostHeader}, }, - Timeout: routeTimeout, + Timeout: routeTimeout, + PrefixRewrite: prefixRewrite, }, }, RequestHeadersToAdd: requestHeadersToAdd, @@ -208,3 +168,67 @@ func mkEnvoyHeader(k, v string) *envoy_config_core_v3.HeaderValueOption { Append: &wrappers.BoolValue{Value: false}, } } + +func toEnvoyHeaders(headers map[string]string) []*envoy_config_core_v3.HeaderValueOption { + envoyHeaders := make([]*envoy_config_core_v3.HeaderValueOption, 0, len(headers)) + for k, v := range headers { + envoyHeaders = append(envoyHeaders, mkEnvoyHeader(k, v)) + } + return envoyHeaders +} + +func mkRouteMatch(policy *config.Policy) *envoy_config_route_v3.RouteMatch { + match := &envoy_config_route_v3.RouteMatch{} + switch { + case policy.Regex != "": + match.PathSpecifier = &envoy_config_route_v3.RouteMatch_SafeRegex{ + SafeRegex: &envoy_type_matcher_v3.RegexMatcher{ + EngineType: &envoy_type_matcher_v3.RegexMatcher_GoogleRe2{ + GoogleRe2: &envoy_type_matcher_v3.RegexMatcher_GoogleRE2{}, + }, + Regex: policy.Regex, + }, + } + case policy.Path != "": + match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Path{Path: policy.Path} + case policy.Prefix != "": + match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Prefix{Prefix: policy.Prefix} + default: + match.PathSpecifier = &envoy_config_route_v3.RouteMatch_Prefix{Prefix: "/"} + } + return match +} + +func getRequestHeadersToRemove(options *config.Options, policy *config.Policy) []string { + requestHeadersToRemove := policy.RemoveRequestHeaders + if !policy.PassIdentityHeaders { + requestHeadersToRemove = append(requestHeadersToRemove, httputil.HeaderPomeriumJWTAssertion) + for _, claim := range options.JWTClaimsHeaders { + requestHeadersToRemove = append(requestHeadersToRemove, httputil.PomeriumJWTHeaderName(claim)) + } + } + return requestHeadersToRemove +} + +func getRouteTimeout(options *config.Options, policy *config.Policy) *durationpb.Duration { + var routeTimeout *durationpb.Duration + if policy.AllowWebsockets { + // disable the route timeout for websocket support + routeTimeout = ptypes.DurationProto(0) + } else { + if policy.UpstreamTimeout != 0 { + routeTimeout = ptypes.DurationProto(policy.UpstreamTimeout) + } else { + routeTimeout = ptypes.DurationProto(options.DefaultUpstreamTimeout) + } + } + return routeTimeout +} + +func getPrefixRewrite(policy *config.Policy) string { + prefixRewrite := "" + if policy.Destination != nil && policy.Destination.Path != "" { + prefixRewrite = policy.Destination.Path + } + return prefixRewrite +} diff --git a/internal/controlplane/xds_routes_test.go b/internal/controlplane/xds_routes_test.go index 015ec6f81..ee42664b7 100644 --- a/internal/controlplane/xds_routes_test.go +++ b/internal/controlplane/xds_routes_test.go @@ -434,6 +434,49 @@ func TestAddOptionsHeadersToResponse(t *testing.T) { `, routes) } +func Test_buildPolicyRoutesWithDestinationPath(t *testing.T) { + routes := buildPolicyRoutes(&config.Options{ + CookieName: "pomerium", + DefaultUpstreamTimeout: time.Second * 3, + Policies: []config.Policy{ + { + Source: &config.StringURL{URL: mustParseURL("https://example.com")}, + Destination: mustParseURL("https://foo.example.com/bar"), + PassIdentityHeaders: true, + }, + }, + }, "example.com") + + testutil.AssertProtoJSONEqual(t, ` + [ + { + "name": "policy-0", + "match": { + "prefix": "/" + }, + "metadata": { + "filterMetadata": { + "envoy.filters.http.lua": { + "remove_pomerium_authorization": true, + "remove_pomerium_cookie": "pomerium" + } + } + }, + "route": { + "autoHostRewrite": true, + "prefixRewrite": "/bar", + "cluster": "policy-605b7be39724cb4f", + "timeout": "3s", + "upgradeConfigs": [{ + "enabled": false, + "upgradeType": "websocket" + }] + } + } + ] + `, routes) +} + func mustParseURL(str string) *url.URL { u, err := url.Parse(str) if err != nil {