config: simplify default set response headers (#4196)

This commit is contained in:
Caleb Doxsey 2023-05-30 17:44:06 -06:00 committed by GitHub
parent d315e68335
commit a741cce50e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 60 additions and 80 deletions

View file

@ -13,7 +13,6 @@ func (b *Builder) buildVirtualHost(
options *config.Options, options *config.Options,
name string, name string,
host string, host string,
requireStrictTransportSecurity bool,
) (*envoy_config_route_v3.VirtualHost, error) { ) (*envoy_config_route_v3.VirtualHost, error) {
vh := &envoy_config_route_v3.VirtualHost{ vh := &envoy_config_route_v3.VirtualHost{
Name: name, Name: name,
@ -21,7 +20,7 @@ func (b *Builder) buildVirtualHost(
} }
// these routes match /.pomerium/... and similar paths // these routes match /.pomerium/... and similar paths
rs, err := b.buildPomeriumHTTPRoutes(options, host, requireStrictTransportSecurity) rs, err := b.buildPomeriumHTTPRoutes(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -34,13 +33,12 @@ func (b *Builder) buildVirtualHost(
// coming directly from envoy // coming directly from envoy
func (b *Builder) buildLocalReplyConfig( func (b *Builder) buildLocalReplyConfig(
options *config.Options, options *config.Options,
requireStrictTransportSecurity bool,
) *envoy_http_connection_manager.LocalReplyConfig { ) *envoy_http_connection_manager.LocalReplyConfig {
// add global headers for HSTS headers (#2110) // add global headers for HSTS headers (#2110)
var headers []*envoy_config_core_v3.HeaderValueOption var headers []*envoy_config_core_v3.HeaderValueOption
// if we're the proxy or authenticate service, add our global headers // if we're the proxy or authenticate service, add our global headers
if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) { if config.IsProxy(options.Services) || config.IsAuthenticate(options.Services) {
headers = toEnvoyHeaders(options.GetSetResponseHeaders(requireStrictTransportSecurity)) headers = toEnvoyHeaders(options.GetSetResponseHeaders())
} }
return &envoy_http_connection_manager.LocalReplyConfig{ return &envoy_http_connection_manager.LocalReplyConfig{

View file

@ -298,7 +298,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
UseRemoteAddress: &wrappers.BoolValue{Value: true}, UseRemoteAddress: &wrappers.BoolValue{Value: true},
SkipXffAppend: cfg.Options.SkipXffAppend, SkipXffAppend: cfg.Options.SkipXffAppend,
XffNumTrustedHops: cfg.Options.XffNumTrustedHops, XffNumTrustedHops: cfg.Options.XffNumTrustedHops,
LocalReplyConfig: b.buildLocalReplyConfig(cfg.Options, false), LocalReplyConfig: b.buildLocalReplyConfig(cfg.Options),
NormalizePath: wrapperspb.Bool(true), NormalizePath: wrapperspb.Bool(true),
} }

View file

@ -2,12 +2,10 @@ package envoyconfig
import ( import (
"context" "context"
"crypto/tls"
envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" envoy_config_route_v3 "github.com/envoyproxy/go-control-plane/envoy/config/route/v3"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
) )
// BuildRouteConfigurations builds the route configurations for the RDS service. // BuildRouteConfigurations builds the route configurations for the RDS service.
@ -32,15 +30,6 @@ func (b *Builder) buildMainRouteConfiguration(
_ context.Context, _ context.Context,
cfg *config.Config, cfg *config.Config,
) (*envoy_config_route_v3.RouteConfiguration, error) { ) (*envoy_config_route_v3.RouteConfiguration, error) {
var certs []tls.Certificate
if !cfg.Options.InsecureServer {
var err error
certs, err = getAllCertificates(cfg)
if err != nil {
return nil, err
}
}
authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs() authorizeURLs, err := cfg.Options.GetInternalAuthorizeURLs()
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,8 +47,7 @@ func (b *Builder) buildMainRouteConfiguration(
var virtualHosts []*envoy_config_route_v3.VirtualHost var virtualHosts []*envoy_config_route_v3.VirtualHost
for _, host := range allHosts { for _, host := range allHosts {
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(certs, host) vh, err := b.buildVirtualHost(cfg.Options, host, host)
vh, err := b.buildVirtualHost(cfg.Options, host, host, requireStrictTransportSecurity)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -78,7 +66,7 @@ func (b *Builder) buildMainRouteConfiguration(
// if we're the proxy, add all the policy routes // if we're the proxy, add all the policy routes
if config.IsProxy(cfg.Options.Services) { if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildRoutesForPoliciesWithHost(cfg, certs, host) rs, err := b.buildRoutesForPoliciesWithHost(cfg, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,12 +78,12 @@ func (b *Builder) buildMainRouteConfiguration(
} }
} }
vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*", false) vh, err := b.buildVirtualHost(cfg.Options, "catch-all", "*")
if err != nil { if err != nil {
return nil, err return nil, err
} }
if config.IsProxy(cfg.Options.Services) { if config.IsProxy(cfg.Options.Services) {
rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg, certs) rs, err := b.buildRoutesForPoliciesWithCatchAll(cfg)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -41,13 +41,13 @@ func TestBuilder_buildMainRouteConfiguration(t *testing.T) {
"name": "catch-all", "name": "catch-all",
"domains": ["*"], "domains": ["*"],
"routes": [ "routes": [
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/ping", false))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/ping"))+`,
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/healthz", false))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/healthz"))+`,
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/.pomerium", false))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/.pomerium"))+`,
`+protojson.Format(b.buildControlPlanePrefixRoute(cfg.Options, "/.pomerium/", false))+`, `+protojson.Format(b.buildControlPlanePrefixRoute(cfg.Options, "/.pomerium/"))+`,
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/.well-known/pomerium", false))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/.well-known/pomerium"))+`,
`+protojson.Format(b.buildControlPlanePrefixRoute(cfg.Options, "/.well-known/pomerium/", false))+`, `+protojson.Format(b.buildControlPlanePrefixRoute(cfg.Options, "/.well-known/pomerium/"))+`,
`+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/robots.txt", false))+`, `+protojson.Format(b.buildControlPlanePathRoute(cfg.Options, "/robots.txt"))+`,
{ {
"name": "policy-0", "name": "policy-0",
"match": { "match": {

View file

@ -1,7 +1,6 @@
package envoyconfig package envoyconfig
import ( import (
"crypto/tls"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/url" "net/url"
@ -20,7 +19,6 @@ import (
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
) )
const ( const (
@ -53,7 +51,6 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
func (b *Builder) buildPomeriumHTTPRoutes( func (b *Builder) buildPomeriumHTTPRoutes(
options *config.Options, options *config.Options,
host string, host string,
requireStrictTransportSecurity bool,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
@ -65,20 +62,20 @@ func (b *Builder) buildPomeriumHTTPRoutes(
} }
if !isFrontingAuthenticate { if !isFrontingAuthenticate {
routes = append(routes, routes = append(routes,
b.buildControlPlanePathRoute(options, "/ping", requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, "/ping"),
b.buildControlPlanePathRoute(options, "/healthz", requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, "/healthz"),
b.buildControlPlanePathRoute(options, "/.pomerium", requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, "/.pomerium"),
b.buildControlPlanePrefixRoute(options, "/.pomerium/", requireStrictTransportSecurity), b.buildControlPlanePrefixRoute(options, "/.pomerium/"),
b.buildControlPlanePathRoute(options, "/.well-known/pomerium", requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, "/.well-known/pomerium"),
b.buildControlPlanePrefixRoute(options, "/.well-known/pomerium/", requireStrictTransportSecurity), b.buildControlPlanePrefixRoute(options, "/.well-known/pomerium/"),
) )
// per #837, only add robots.txt if there are no unauthenticated routes // per #837, only add robots.txt if there are no unauthenticated routes
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) { if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) {
routes = append(routes, b.buildControlPlanePathRoute(options, "/robots.txt", requireStrictTransportSecurity)) routes = append(routes, b.buildControlPlanePathRoute(options, "/robots.txt"))
} }
} }
authRoutes, err := b.buildPomeriumAuthenticateHTTPRoutes(options, host, requireStrictTransportSecurity) authRoutes, err := b.buildPomeriumAuthenticateHTTPRoutes(options, host)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -89,7 +86,6 @@ func (b *Builder) buildPomeriumHTTPRoutes(
func (b *Builder) buildPomeriumAuthenticateHTTPRoutes( func (b *Builder) buildPomeriumAuthenticateHTTPRoutes(
options *config.Options, options *config.Options,
host string, host string,
requireStrictTransportSecurity bool,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
if !config.IsAuthenticate(options.Services) { if !config.IsAuthenticate(options.Services) {
return nil, nil return nil, nil
@ -105,8 +101,8 @@ func (b *Builder) buildPomeriumAuthenticateHTTPRoutes(
} }
if urlMatchesHost(u, host) { if urlMatchesHost(u, host) {
return []*envoy_config_route_v3.Route{ return []*envoy_config_route_v3.Route{
b.buildControlPlanePathRoute(options, options.AuthenticateCallbackPath, requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, options.AuthenticateCallbackPath),
b.buildControlPlanePathRoute(options, "/", requireStrictTransportSecurity), b.buildControlPlanePathRoute(options, "/"),
}, nil }, nil
} }
} }
@ -116,7 +112,6 @@ func (b *Builder) buildPomeriumAuthenticateHTTPRoutes(
func (b *Builder) buildControlPlanePathRoute( func (b *Builder) buildControlPlanePathRoute(
options *config.Options, options *config.Options,
path string, path string,
requireStrictTransportSecurity bool,
) *envoy_config_route_v3.Route { ) *envoy_config_route_v3.Route {
r := &envoy_config_route_v3.Route{ r := &envoy_config_route_v3.Route{
Name: "pomerium-path-" + path, Name: "pomerium-path-" + path,
@ -130,7 +125,7 @@ func (b *Builder) buildControlPlanePathRoute(
}, },
}, },
}, },
ResponseHeadersToAdd: toEnvoyHeaders(options.GetSetResponseHeaders(requireStrictTransportSecurity)), ResponseHeadersToAdd: toEnvoyHeaders(options.GetSetResponseHeaders()),
TypedPerFilterConfig: map[string]*any.Any{ TypedPerFilterConfig: map[string]*any.Any{
PerFilterConfigExtAuthzName: PerFilterConfigExtAuthzContextExtensions(MakeExtAuthzContextExtensions(true, 0)), PerFilterConfigExtAuthzName: PerFilterConfigExtAuthzContextExtensions(MakeExtAuthzContextExtensions(true, 0)),
}, },
@ -141,7 +136,6 @@ func (b *Builder) buildControlPlanePathRoute(
func (b *Builder) buildControlPlanePrefixRoute( func (b *Builder) buildControlPlanePrefixRoute(
options *config.Options, options *config.Options,
prefix string, prefix string,
requireStrictTransportSecurity bool,
) *envoy_config_route_v3.Route { ) *envoy_config_route_v3.Route {
r := &envoy_config_route_v3.Route{ r := &envoy_config_route_v3.Route{
Name: "pomerium-prefix-" + prefix, Name: "pomerium-prefix-" + prefix,
@ -155,7 +149,7 @@ func (b *Builder) buildControlPlanePrefixRoute(
}, },
}, },
}, },
ResponseHeadersToAdd: toEnvoyHeaders(options.GetSetResponseHeaders(requireStrictTransportSecurity)), ResponseHeadersToAdd: toEnvoyHeaders(options.GetSetResponseHeaders()),
TypedPerFilterConfig: map[string]*any.Any{ TypedPerFilterConfig: map[string]*any.Any{
PerFilterConfigExtAuthzName: PerFilterConfigExtAuthzContextExtensions(MakeExtAuthzContextExtensions(true, 0)), PerFilterConfigExtAuthzName: PerFilterConfigExtAuthzContextExtensions(MakeExtAuthzContextExtensions(true, 0)),
}, },
@ -184,7 +178,6 @@ func getClusterStatsName(policy *config.Policy) string {
func (b *Builder) buildRoutesForPoliciesWithHost( func (b *Builder) buildRoutesForPoliciesWithHost(
cfg *config.Config, cfg *config.Config,
certs []tls.Certificate,
host string, host string,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
@ -199,7 +192,7 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
continue continue
} }
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i)) policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -211,7 +204,6 @@ func (b *Builder) buildRoutesForPoliciesWithHost(
func (b *Builder) buildRoutesForPoliciesWithCatchAll( func (b *Builder) buildRoutesForPoliciesWithCatchAll(
cfg *config.Config, cfg *config.Config,
certs []tls.Certificate,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
var routes []*envoy_config_route_v3.Route var routes []*envoy_config_route_v3.Route
for i, p := range cfg.Options.GetAllPolicies() { for i, p := range cfg.Options.GetAllPolicies() {
@ -225,7 +217,7 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
continue continue
} }
policyRoutes, err := b.buildRoutesForPolicy(cfg, certs, &policy, fmt.Sprintf("policy-%d", i)) policyRoutes, err := b.buildRoutesForPolicy(cfg, &policy, fmt.Sprintf("policy-%d", i))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -237,7 +229,6 @@ func (b *Builder) buildRoutesForPoliciesWithCatchAll(
func (b *Builder) buildRoutesForPolicy( func (b *Builder) buildRoutesForPolicy(
cfg *config.Config, cfg *config.Config,
certs []tls.Certificate,
policy *config.Policy, policy *config.Policy,
name string, name string,
) ([]*envoy_config_route_v3.Route, error) { ) ([]*envoy_config_route_v3.Route, error) {
@ -250,14 +241,14 @@ func (b *Builder) buildRoutesForPolicy(
if strings.Contains(fromURL.Host, "*") { if strings.Contains(fromURL.Host, "*") {
// we have to match '*.example.com' and '*.example.com:443', so there are two routes // we have to match '*.example.com' and '*.example.com:443', so there are two routes
for _, host := range urlutil.GetDomainsForURL(fromURL) { for _, host := range urlutil.GetDomainsForURL(fromURL) {
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatchForHost(policy, host)) route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatchForHost(policy, host))
if err != nil { if err != nil {
return nil, err return nil, err
} }
routes = append(routes, route) routes = append(routes, route)
} }
} else { } else {
route, err := b.buildRouteForPolicyAndMatch(cfg, certs, policy, name, mkRouteMatch(policy)) route, err := b.buildRouteForPolicyAndMatch(cfg, policy, name, mkRouteMatch(policy))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -268,7 +259,6 @@ func (b *Builder) buildRoutesForPolicy(
func (b *Builder) buildRouteForPolicyAndMatch( func (b *Builder) buildRouteForPolicyAndMatch(
cfg *config.Config, cfg *config.Config,
certs []tls.Certificate,
policy *config.Policy, policy *config.Policy,
name string, name string,
match *envoy_config_route_v3.RouteMatch, match *envoy_config_route_v3.RouteMatch,
@ -283,15 +273,13 @@ func (b *Builder) buildRouteForPolicyAndMatch(
return nil, err return nil, err
} }
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(certs, fromURL.Hostname())
route := &envoy_config_route_v3.Route{ route := &envoy_config_route_v3.Route{
Name: name, Name: name,
Match: match, Match: match,
Metadata: &envoy_config_core_v3.Metadata{}, Metadata: &envoy_config_core_v3.Metadata{},
RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders), RequestHeadersToAdd: toEnvoyHeaders(policy.SetRequestHeaders),
RequestHeadersToRemove: getRequestHeadersToRemove(cfg.Options, policy), RequestHeadersToRemove: getRequestHeadersToRemove(cfg.Options, policy),
ResponseHeadersToAdd: toEnvoyHeaders(cfg.Options.GetSetResponseHeadersForPolicy(policy, requireStrictTransportSecurity)), ResponseHeadersToAdd: toEnvoyHeaders(cfg.Options.GetSetResponseHeadersForPolicy(policy)),
} }
if policy.Redirect != nil { if policy.Redirect != nil {
action, err := b.buildPolicyRouteRedirectAction(policy.Redirect) action, err := b.buildPolicyRouteRedirectAction(policy.Redirect)

View file

@ -100,7 +100,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false) routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -121,7 +121,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
AuthenticateURLString: "https://authenticate.example.com", AuthenticateURLString: "https://authenticate.example.com",
AuthenticateCallbackPath: "/oauth2/callback", AuthenticateCallbackPath: "/oauth2/callback",
} }
routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com", false) routes, err := b.buildPomeriumHTTPRoutes(options, "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, "null", routes) testutil.AssertProtoJSONEqual(t, "null", routes)
}) })
@ -137,7 +137,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
}}, }},
} }
_ = options.Policies[0].Validate() _ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com", false) routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -163,7 +163,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
}}, }},
} }
_ = options.Policies[0].Validate() _ = options.Policies[0].Validate()
routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com", false) routes, err := b.buildPomeriumHTTPRoutes(options, "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `[ testutil.AssertProtoJSONEqual(t, `[
@ -180,7 +180,7 @@ func Test_buildPomeriumHTTPRoutes(t *testing.T) {
func Test_buildControlPlanePathRoute(t *testing.T) { func Test_buildControlPlanePathRoute(t *testing.T) {
options := config.NewDefaultOptions() options := config.NewDefaultOptions()
b := &Builder{filemgr: filemgr.NewManager()} b := &Builder{filemgr: filemgr.NewManager()}
route := b.buildControlPlanePathRoute(options, "/hello/world", false) route := b.buildControlPlanePathRoute(options, "/hello/world")
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
"name": "pomerium-path-/hello/world", "name": "pomerium-path-/hello/world",
@ -224,7 +224,7 @@ func Test_buildControlPlanePathRoute(t *testing.T) {
func Test_buildControlPlanePrefixRoute(t *testing.T) { func Test_buildControlPlanePrefixRoute(t *testing.T) {
options := config.NewDefaultOptions() options := config.NewDefaultOptions()
b := &Builder{filemgr: filemgr.NewManager()} b := &Builder{filemgr: filemgr.NewManager()}
route := b.buildControlPlanePrefixRoute(options, "/hello/world/", false) route := b.buildControlPlanePrefixRoute(options, "/hello/world/")
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
{ {
"name": "pomerium-prefix-/hello/world/", "name": "pomerium-prefix-/hello/world/",
@ -311,7 +311,7 @@ func TestTimeouts(t *testing.T) {
AllowWebsockets: tc.allowWebsockets, AllowWebsockets: tc.allowWebsockets,
}, },
}, },
}}, nil, "example.com") }}, "example.com")
if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) { if !assert.NoError(t, err, "%v", tc) || !assert.Len(t, routes, 1, tc) || !assert.NotNil(t, routes[0].GetRoute(), "%v", tc) {
continue continue
} }
@ -425,7 +425,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten, UpstreamTimeout: &ten,
}, },
}, },
}}, nil, "example.com") }}, "example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -1020,7 +1020,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
PassIdentityHeaders: true, PassIdentityHeaders: true,
}, },
}, },
}}, nil, "authenticate.example.com") }}, "authenticate.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -1109,7 +1109,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
UpstreamTimeout: &ten, UpstreamTimeout: &ten,
}, },
}, },
}}, nil, "example.com:22") }}, "example.com:22")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -1278,7 +1278,7 @@ func Test_buildPolicyRoutes(t *testing.T) {
To: mustParseWeightedURLs(t, "https://to.example.com"), To: mustParseWeightedURLs(t, "https://to.example.com"),
}, },
}, },
}}, nil, "from.example.com") }}, "from.example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `
@ -1410,7 +1410,7 @@ func Test_buildPolicyRoutesRewrite(t *testing.T) {
HostPathRegexRewriteSubstitution: "\\1", HostPathRegexRewriteSubstitution: "\\1",
}, },
}, },
}}, nil, "example.com") }}, "example.com")
require.NoError(t, err) require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, ` testutil.AssertProtoJSONEqual(t, `

View file

@ -978,6 +978,11 @@ func (o *Options) GetCertificates() ([]tls.Certificate, error) {
return certs, nil return certs, nil
} }
// HasCertificates returns true if options has any certificates.
func (o *Options) HasCertificates() bool {
return o.Cert != "" || o.Key != "" || len(o.CertificateFiles) > 0 || o.CertFile != "" || o.KeyFile != ""
}
// GetSharedKey gets the decoded shared key. // GetSharedKey gets the decoded shared key.
func (o *Options) GetSharedKey() ([]byte, error) { func (o *Options) GetSharedKey() ([]byte, error) {
sharedKey := o.SharedKey sharedKey := o.SharedKey
@ -1017,18 +1022,22 @@ func (o *Options) GetGoogleCloudServerlessAuthenticationServiceAccount() string
} }
// GetSetResponseHeaders gets the SetResponseHeaders. // GetSetResponseHeaders gets the SetResponseHeaders.
func (o *Options) GetSetResponseHeaders(requireStrictTransportSecurity bool) map[string]string { func (o *Options) GetSetResponseHeaders() map[string]string {
return o.GetSetResponseHeadersForPolicy(nil, requireStrictTransportSecurity) return o.GetSetResponseHeadersForPolicy(nil)
} }
// GetSetResponseHeadersForPolicy gets the SetResponseHeaders for a policy. // GetSetResponseHeadersForPolicy gets the SetResponseHeaders for a policy.
func (o *Options) GetSetResponseHeadersForPolicy(policy *Policy, requireStrictTransportSecurity bool) map[string]string { func (o *Options) GetSetResponseHeadersForPolicy(policy *Policy) map[string]string {
hdrs := o.SetResponseHeaders hdrs := o.SetResponseHeaders
if hdrs == nil { if hdrs == nil {
hdrs = make(map[string]string) hdrs = make(map[string]string)
for k, v := range defaultSetResponseHeaders { for k, v := range defaultSetResponseHeaders {
hdrs[k] = v hdrs[k] = v
} }
if !o.HasCertificates() {
delete(hdrs, "Strict-Transport-Security")
}
} }
if _, ok := hdrs[DisableHeaderKey]; ok { if _, ok := hdrs[DisableHeaderKey]; ok {
hdrs = make(map[string]string) hdrs = make(map[string]string)
@ -1043,10 +1052,6 @@ func (o *Options) GetSetResponseHeadersForPolicy(policy *Policy, requireStrictTr
hdrs = make(map[string]string) hdrs = make(map[string]string)
} }
if !requireStrictTransportSecurity {
delete(hdrs, "Strict-Transport-Security")
}
return hdrs return hdrs
} }

View file

@ -752,20 +752,21 @@ func TestOptions_GetSetResponseHeaders(t *testing.T) {
assert.Equal(t, map[string]string{ assert.Equal(t, map[string]string{
"X-Frame-Options": "SAMEORIGIN", "X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block", "X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeaders(false)) }, options.GetSetResponseHeaders())
}) })
t.Run("strict", func(t *testing.T) { t.Run("strict", func(t *testing.T) {
options := NewDefaultOptions() options := NewDefaultOptions()
options.Cert = "CERT"
assert.Equal(t, map[string]string{ assert.Equal(t, map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Frame-Options": "SAMEORIGIN", "X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block", "X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeaders(true)) }, options.GetSetResponseHeaders())
}) })
t.Run("disable", func(t *testing.T) { t.Run("disable", func(t *testing.T) {
options := NewDefaultOptions() options := NewDefaultOptions()
options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1", "x-other": "xyz"} options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1", "x-other": "xyz"}
assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders(true)) assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders())
}) })
} }
@ -776,7 +777,7 @@ func TestOptions_GetSetResponseHeadersForPolicy(t *testing.T) {
policy := &Policy{ policy := &Policy{
SetResponseHeaders: map[string]string{"x": "y"}, SetResponseHeaders: map[string]string{"x": "y"},
} }
assert.Equal(t, map[string]string{"x": "y"}, options.GetSetResponseHeadersForPolicy(policy, true)) assert.Equal(t, map[string]string{"x": "y"}, options.GetSetResponseHeadersForPolicy(policy))
}) })
} }