mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-20 20:47:16 +02:00
envoyconfig: clean up filter chain construction (#3844)
* cleanup filter chain construction * rename domains to server names * rename to hosts * fix tests * update function name * improved domaain matching
This commit is contained in:
parent
a49f86d023
commit
67e12101fa
10 changed files with 405 additions and 246 deletions
|
@ -12,16 +12,16 @@ import (
|
|||
func (b *Builder) buildVirtualHost(
|
||||
options *config.Options,
|
||||
name string,
|
||||
domain string,
|
||||
host string,
|
||||
requireStrictTransportSecurity bool,
|
||||
) (*envoy_config_route_v3.VirtualHost, error) {
|
||||
vh := &envoy_config_route_v3.VirtualHost{
|
||||
Name: name,
|
||||
Domains: []string{domain},
|
||||
Domains: []string{host},
|
||||
}
|
||||
|
||||
// these routes match /.pomerium/... and similar paths
|
||||
rs, err := b.buildPomeriumHTTPRoutes(options, domain)
|
||||
rs, err := b.buildPomeriumHTTPRoutes(options, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -25,6 +25,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/sets"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
||||
|
@ -99,51 +100,50 @@ func (b *Builder) BuildListeners(ctx context.Context, cfg *config.Config) ([]*en
|
|||
}
|
||||
|
||||
func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*envoy_config_listener_v3.Listener, error) {
|
||||
listenerFilters := []*envoy_config_listener_v3.ListenerFilter{}
|
||||
li := newEnvoyListener("http-ingress")
|
||||
if cfg.Options.UseProxyProtocol {
|
||||
listenerFilters = append(listenerFilters, ProxyProtocolFilter())
|
||||
li.ListenerFilters = append(li.ListenerFilters, ProxyProtocolFilter())
|
||||
}
|
||||
|
||||
if cfg.Options.InsecureServer {
|
||||
allDomains, err := getAllRouteableDomains(cfg.Options, cfg.Options.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, allDomains, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
li := newEnvoyListener("http-ingress")
|
||||
li.Address = buildAddress(cfg.Options.Addr, 80)
|
||||
li.ListenerFilters = listenerFilters
|
||||
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{
|
||||
filter,
|
||||
},
|
||||
}}
|
||||
return li, nil
|
||||
}
|
||||
listenerFilters = append(listenerFilters, TLSInspectorFilter())
|
||||
} else {
|
||||
li.Address = buildAddress(cfg.Options.Addr, 443)
|
||||
li.ListenerFilters = append(li.ListenerFilters, TLSInspectorFilter())
|
||||
|
||||
chains, err := b.buildFilterChains(cfg, cfg.Options.Addr,
|
||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||
allCertificates, _ := cfg.AllCertificates()
|
||||
requireStrictTransportSecurity := cryptutil.HasCertificateForDomain(allCertificates, tlsDomain)
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, httpDomains, requireStrictTransportSecurity)
|
||||
|
||||
serverNames, err := getAllServerNames(cfg, cfg.Options.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, serverName := range serverNames {
|
||||
requireStrictTransportSecurity := cryptutil.HasCertificateForServerName(allCertificates, serverName)
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(cfg.Options, requireStrictTransportSecurity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||
}
|
||||
if tlsDomain != "*" {
|
||||
if serverName != "*" {
|
||||
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
|
||||
ServerNames: []string{tlsDomain},
|
||||
ServerNames: []string{serverName},
|
||||
}
|
||||
}
|
||||
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain)
|
||||
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
|
||||
if tlsContext != nil {
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -153,16 +153,9 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e
|
|||
},
|
||||
}
|
||||
}
|
||||
return filterChain, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
li.FilterChains = append(li.FilterChains, filterChain)
|
||||
}
|
||||
}
|
||||
|
||||
li := newEnvoyListener("https-ingress")
|
||||
li.Address = buildAddress(cfg.Options.Addr, 443)
|
||||
li.ListenerFilters = listenerFilters
|
||||
li.FilterChains = chains
|
||||
return li, nil
|
||||
}
|
||||
|
||||
|
@ -245,42 +238,8 @@ func (b *Builder) buildMetricsListener(cfg *config.Config) (*envoy_config_listen
|
|||
return li, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildFilterChains(
|
||||
cfg *config.Config, addr string,
|
||||
callback func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error),
|
||||
) ([]*envoy_config_listener_v3.FilterChain, error) {
|
||||
allDomains, err := getAllRouteableDomains(cfg.Options, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsDomains, err := getAllTLSDomains(cfg, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chains []*envoy_config_listener_v3.FilterChain
|
||||
chains = append(chains, b.buildACMETLSALPNFilterChain())
|
||||
for _, domain := range tlsDomains {
|
||||
chain, err := callback(domain, allDomains)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chains = append(chains, chain)
|
||||
}
|
||||
|
||||
// if there are no SNI matches we match on HTTP host
|
||||
chain, err := callback("*", allDomains)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
chains = append(chains, chain)
|
||||
return chains, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
||||
options *config.Options,
|
||||
domains []string,
|
||||
requireStrictTransportSecurity bool,
|
||||
) (*envoy_config_listener_v3.Filter, error) {
|
||||
authorizeURLs, err := options.GetInternalAuthorizeURLs()
|
||||
|
@ -293,17 +252,22 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
return nil, err
|
||||
}
|
||||
|
||||
allHosts, err := getAllRouteableHosts(options, options.Addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var virtualHosts []*envoy_config_route_v3.VirtualHost
|
||||
for _, domain := range domains {
|
||||
vh, err := b.buildVirtualHost(options, domain, domain, requireStrictTransportSecurity)
|
||||
for _, host := range allHosts {
|
||||
vh, err := b.buildVirtualHost(options, host, host, requireStrictTransportSecurity)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if options.Addr == options.GetGRPCAddr() {
|
||||
// if this is a gRPC service domain and we're supposed to handle that, add those routes
|
||||
if (config.IsAuthorize(options.Services) && hostsMatchDomain(authorizeURLs, domain)) ||
|
||||
(config.IsDataBroker(options.Services) && hostsMatchDomain(dataBrokerURLs, domain)) {
|
||||
if (config.IsAuthorize(options.Services) && urlsMatchHost(authorizeURLs, host)) ||
|
||||
(config.IsDataBroker(options.Services) && urlsMatchHost(dataBrokerURLs, host)) {
|
||||
rs, err := b.buildGRPCRoutes()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -314,7 +278,7 @@ func (b *Builder) buildMainHTTPConnectionManagerFilter(
|
|||
|
||||
// if we're the proxy, add all the policy routes
|
||||
if config.IsProxy(options.Services) {
|
||||
rs, err := b.buildPolicyRoutes(options, domain)
|
||||
rs, err := b.buildPolicyRoutes(options, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -445,28 +409,35 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
|
|||
return nil, err
|
||||
}
|
||||
|
||||
if cfg.Options.GetGRPCInsecure() {
|
||||
li := newEnvoyListener("grpc-ingress")
|
||||
if cfg.Options.GetGRPCInsecure() {
|
||||
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 80)
|
||||
li.FilterChains = []*envoy_config_listener_v3.FilterChain{{
|
||||
Filters: []*envoy_config_listener_v3.Filter{
|
||||
filter,
|
||||
},
|
||||
}}
|
||||
return li, nil
|
||||
} else {
|
||||
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
|
||||
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
|
||||
TLSInspectorFilter(),
|
||||
}
|
||||
|
||||
chains, err := b.buildFilterChains(cfg, cfg.Options.GRPCAddr,
|
||||
func(tlsDomain string, httpDomains []string) (*envoy_config_listener_v3.FilterChain, error) {
|
||||
serverNames, err := getAllServerNames(cfg, cfg.Options.GRPCAddr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, serverName := range serverNames {
|
||||
filterChain := &envoy_config_listener_v3.FilterChain{
|
||||
Filters: []*envoy_config_listener_v3.Filter{filter},
|
||||
}
|
||||
if tlsDomain != "*" {
|
||||
if serverName != "*" {
|
||||
filterChain.FilterChainMatch = &envoy_config_listener_v3.FilterChainMatch{
|
||||
ServerNames: []string{tlsDomain},
|
||||
ServerNames: []string{serverName},
|
||||
}
|
||||
}
|
||||
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain)
|
||||
tlsContext := b.buildDownstreamTLSContext(ctx, cfg, serverName)
|
||||
if tlsContext != nil {
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -476,18 +447,9 @@ func (b *Builder) buildGRPCListener(ctx context.Context, cfg *config.Config) (*e
|
|||
},
|
||||
}
|
||||
}
|
||||
return filterChain, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
li.FilterChains = append(li.FilterChains, filterChain)
|
||||
}
|
||||
|
||||
li := newEnvoyListener("grpc-ingress")
|
||||
li.Address = buildAddress(cfg.Options.GetGRPCAddr(), 443)
|
||||
li.ListenerFilters = []*envoy_config_listener_v3.ListenerFilter{
|
||||
TLSInspectorFilter(),
|
||||
}
|
||||
li.FilterChains = chains
|
||||
return li, nil
|
||||
}
|
||||
|
||||
|
@ -548,23 +510,23 @@ func (b *Builder) buildRouteConfiguration(name string, virtualHosts []*envoy_con
|
|||
|
||||
func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
serverName string,
|
||||
) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext {
|
||||
certs, err := cfg.AllCertificates()
|
||||
if err != nil {
|
||||
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get all certificates from config")
|
||||
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get all certificates from config")
|
||||
return nil
|
||||
}
|
||||
|
||||
cert, err := cryptutil.GetCertificateForDomain(certs, domain)
|
||||
cert, err := cryptutil.GetCertificateForServerName(certs, serverName)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Str("domain", domain).Err(err).Msg("failed to get certificate for domain")
|
||||
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("failed to get certificate for domain")
|
||||
return nil
|
||||
}
|
||||
|
||||
err = validateCertificate(cert)
|
||||
if err != nil {
|
||||
log.Warn(ctx).Str("domain", domain).Err(err).Msg("invalid certificate for domain")
|
||||
log.Warn(ctx).Str("domain", serverName).Err(err).Msg("invalid certificate for domain")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -584,14 +546,14 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context,
|
|||
TlsParams: tlsParams,
|
||||
TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert},
|
||||
AlpnProtocols: alpnProtocols,
|
||||
ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, domain),
|
||||
ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, serverName),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
||||
cfg *config.Config,
|
||||
domain string,
|
||||
serverName string,
|
||||
) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext {
|
||||
needsClientCert := false
|
||||
|
||||
|
@ -599,7 +561,7 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
|||
needsClientCert = true
|
||||
}
|
||||
if !needsClientCert {
|
||||
for _, p := range getPoliciesForDomain(cfg.Options, domain) {
|
||||
for _, p := range getPoliciesForServerName(cfg.Options, serverName) {
|
||||
if p.TLSDownstreamClientCA != "" {
|
||||
needsClientCert = true
|
||||
break
|
||||
|
@ -632,40 +594,41 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
|
|||
return vc
|
||||
}
|
||||
|
||||
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
|
||||
allDomains := sets.NewSorted[string]()
|
||||
func getAllRouteableHosts(options *config.Options, addr string) ([]string, error) {
|
||||
allHosts := sets.NewSorted[string]()
|
||||
|
||||
if addr == options.Addr {
|
||||
domains, err := options.GetAllRouteableHTTPDomains()
|
||||
hosts, err := options.GetAllRouteableHTTPHosts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allDomains.Add(domains...)
|
||||
allHosts.Add(hosts...)
|
||||
}
|
||||
|
||||
if addr == options.GetGRPCAddr() {
|
||||
domains, err := options.GetAllRouteableGRPCDomains()
|
||||
hosts, err := options.GetAllRouteableGRPCHosts()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
allDomains.Add(domains...)
|
||||
allHosts.Add(hosts...)
|
||||
}
|
||||
|
||||
return allDomains.ToSlice(), nil
|
||||
return allHosts.ToSlice(), nil
|
||||
}
|
||||
|
||||
func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
|
||||
domains := sets.NewSorted[string]()
|
||||
func getAllServerNames(cfg *config.Config, addr string) ([]string, error) {
|
||||
serverNames := sets.NewSorted[string]()
|
||||
serverNames.Add("*")
|
||||
|
||||
routeableDomains, err := getAllRouteableDomains(cfg.Options, addr)
|
||||
routeableHosts, err := getAllRouteableHosts(cfg.Options, addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, hp := range routeableDomains {
|
||||
if d, _, err := net.SplitHostPort(hp); err == nil {
|
||||
domains.Add(d)
|
||||
for _, hp := range routeableHosts {
|
||||
if h, _, err := net.SplitHostPort(hp); err == nil {
|
||||
serverNames.Add(h)
|
||||
} else {
|
||||
domains.Add(hp)
|
||||
serverNames.Add(hp)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -674,24 +637,24 @@ func getAllTLSDomains(cfg *config.Config, addr string) ([]string, error) {
|
|||
return nil, err
|
||||
}
|
||||
for i := range certs {
|
||||
for _, domain := range cryptutil.GetCertificateDomains(&certs[i]) {
|
||||
domains.Add(domain)
|
||||
for _, domain := range cryptutil.GetCertificateServerNames(&certs[i]) {
|
||||
serverNames.Add(domain)
|
||||
}
|
||||
}
|
||||
|
||||
return domains.ToSlice(), nil
|
||||
return serverNames.ToSlice(), nil
|
||||
}
|
||||
|
||||
func hostsMatchDomain(urls []*url.URL, host string) bool {
|
||||
func urlsMatchHost(urls []*url.URL, host string) bool {
|
||||
for _, u := range urls {
|
||||
if hostMatchesDomain(u, host) {
|
||||
if urlMatchesHost(u, host) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func hostMatchesDomain(u *url.URL, host string) bool {
|
||||
func urlMatchesHost(u *url.URL, host string) bool {
|
||||
if u == nil {
|
||||
return false
|
||||
}
|
||||
|
@ -718,10 +681,10 @@ func hostMatchesDomain(u *url.URL, host string) bool {
|
|||
return h1 == h2 && p1 == p2
|
||||
}
|
||||
|
||||
func getPoliciesForDomain(options *config.Options, domain string) []config.Policy {
|
||||
func getPoliciesForServerName(options *config.Options, serverName string) []config.Policy {
|
||||
var policies []config.Policy
|
||||
for _, p := range options.GetAllPolicies() {
|
||||
if p.Source != nil && p.Source.URL.Hostname() == domain {
|
||||
if p.Source != nil && urlutil.MatchesServerName(*p.Source.URL, serverName) {
|
||||
policies = append(policies, p)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -129,7 +129,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
|||
options := config.NewDefaultOptions()
|
||||
options.SkipXffAppend = true
|
||||
options.XffNumTrustedHops = 1
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(options, []string{"example.com"}, true)
|
||||
options.AuthenticateURLString = "https://authenticate.example.com"
|
||||
filter, err := b.buildMainHTTPConnectionManagerFilter(options, true)
|
||||
require.NoError(t, err)
|
||||
testutil.AssertProtoJSONEqual(t, `{
|
||||
"name": "envoy.filters.network.http_connection_manager",
|
||||
|
@ -220,8 +221,8 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
|||
"name": "main",
|
||||
"virtualHosts": [
|
||||
{
|
||||
"name": "example.com",
|
||||
"domains": ["example.com"],
|
||||
"name": "authenticate.example.com",
|
||||
"domains": ["authenticate.example.com"],
|
||||
"responseHeadersToAdd": [{
|
||||
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
|
||||
"header": {
|
||||
|
@ -366,6 +367,216 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) {
|
|||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/oauth2/callback",
|
||||
"match": {
|
||||
"path": "/oauth2/callback"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/",
|
||||
"match": {
|
||||
"path": "/"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "authenticate.example.com:443",
|
||||
"domains": ["authenticate.example.com:443"],
|
||||
"responseHeadersToAdd": [{
|
||||
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
|
||||
"header": {
|
||||
"key": "Strict-Transport-Security",
|
||||
"value": "max-age=31536000; includeSubDomains; preload"
|
||||
}
|
||||
},
|
||||
{
|
||||
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
|
||||
"header": {
|
||||
"key": "X-Frame-Options",
|
||||
"value": "SAMEORIGIN"
|
||||
}
|
||||
},
|
||||
{
|
||||
"appendAction": "OVERWRITE_IF_EXISTS_OR_ADD",
|
||||
"header": {
|
||||
"key": "X-XSS-Protection",
|
||||
"value": "1; mode=block"
|
||||
}
|
||||
}],
|
||||
"routes": [
|
||||
{
|
||||
"name": "pomerium-path-/.pomerium/jwt",
|
||||
"match": {
|
||||
"path": "/.pomerium/jwt"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/.pomerium/webauthn",
|
||||
"match": {
|
||||
"path": "/.pomerium/webauthn"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/ping",
|
||||
"match": {
|
||||
"path": "/ping"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/healthz",
|
||||
"match": {
|
||||
"path": "/healthz"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/.pomerium",
|
||||
"match": {
|
||||
"path": "/.pomerium"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-prefix-/.pomerium/",
|
||||
"match": {
|
||||
"prefix": "/.pomerium/"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/.well-known/pomerium",
|
||||
"match": {
|
||||
"path": "/.well-known/pomerium"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-prefix-/.well-known/pomerium/",
|
||||
"match": {
|
||||
"prefix": "/.well-known/pomerium/"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/robots.txt",
|
||||
"match": {
|
||||
"path": "/robots.txt"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/oauth2/callback",
|
||||
"match": {
|
||||
"path": "/oauth2/callback"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "pomerium-path-/",
|
||||
"match": {
|
||||
"path": "/"
|
||||
},
|
||||
"route": {
|
||||
"cluster": "pomerium-control-plane-http"
|
||||
},
|
||||
"typedPerFilterConfig": {
|
||||
"envoy.filters.http.ext_authz": {
|
||||
"@type": "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute",
|
||||
"disabled": true
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
|
@ -779,7 +990,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
}
|
||||
t.Run("routable", func(t *testing.T) {
|
||||
t.Run("http", func(t *testing.T) {
|
||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9000")
|
||||
actual, err := getAllRouteableHosts(options, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"a.example.com",
|
||||
|
@ -794,7 +1005,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
t.Run("grpc", func(t *testing.T) {
|
||||
actual, err := getAllRouteableDomains(options, "127.0.0.1:9001")
|
||||
actual, err := getAllRouteableHosts(options, "127.0.0.1:9001")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"authorize.example.com:9001",
|
||||
|
@ -805,7 +1016,7 @@ func Test_getAllDomains(t *testing.T) {
|
|||
t.Run("both", func(t *testing.T) {
|
||||
newOptions := *options
|
||||
newOptions.GRPCAddr = newOptions.Addr
|
||||
actual, err := getAllRouteableDomains(&newOptions, "127.0.0.1:9000")
|
||||
actual, err := getAllRouteableHosts(&newOptions, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"a.example.com",
|
||||
|
@ -824,9 +1035,10 @@ func Test_getAllDomains(t *testing.T) {
|
|||
})
|
||||
t.Run("tls", func(t *testing.T) {
|
||||
t.Run("http", func(t *testing.T) {
|
||||
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9000")
|
||||
actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9000")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"*",
|
||||
"*.unknown.example.com",
|
||||
"a.example.com",
|
||||
"authenticate.example.com",
|
||||
|
@ -836,9 +1048,10 @@ func Test_getAllDomains(t *testing.T) {
|
|||
assert.Equal(t, expect, actual)
|
||||
})
|
||||
t.Run("grpc", func(t *testing.T) {
|
||||
actual, err := getAllTLSDomains(&config.Config{Options: options}, "127.0.0.1:9001")
|
||||
actual, err := getAllServerNames(&config.Config{Options: options}, "127.0.0.1:9001")
|
||||
require.NoError(t, err)
|
||||
expect := []string{
|
||||
"*",
|
||||
"*.unknown.example.com",
|
||||
"authorize.example.com",
|
||||
"cache.example.com",
|
||||
|
@ -848,14 +1061,31 @@ func Test_getAllDomains(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func Test_hostMatchesDomain(t *testing.T) {
|
||||
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com"))
|
||||
assert.True(t, hostMatchesDomain(mustParseURL(t, "http://example.com"), "example.com:80"))
|
||||
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com"), "example.com:443"))
|
||||
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com:443"))
|
||||
assert.True(t, hostMatchesDomain(mustParseURL(t, "https://example.com:443"), "example.com"))
|
||||
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com"))
|
||||
assert.False(t, hostMatchesDomain(mustParseURL(t, "http://example.com:81"), "example.com:80"))
|
||||
func Test_urlMatchesHost(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
sourceURL string
|
||||
host string
|
||||
matches bool
|
||||
}{
|
||||
{"no port", "http://example.com", "example.com", true},
|
||||
{"host http port", "http://example.com", "example.com:80", true},
|
||||
{"host https port", "https://example.com", "example.com:443", true},
|
||||
{"with port", "https://example.com:443", "example.com:443", true},
|
||||
{"url port", "https://example.com:443", "example.com", true},
|
||||
{"non standard port", "http://example.com:81", "example.com", false},
|
||||
{"non standard host port", "http://example.com:81", "example.com:80", false},
|
||||
} {
|
||||
tc := tc
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
assert.Equal(t, tc.matches, urlMatchesHost(mustParseURL(t, tc.sourceURL), tc.host),
|
||||
"urlMatchesHost(%s,%s)", tc.sourceURL, tc.host)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildRouteConfiguration(t *testing.T) {
|
||||
|
|
|
@ -47,12 +47,12 @@ func (b *Builder) buildGRPCRoutes() ([]*envoy_config_route_v3.Route, error) {
|
|||
}}, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
||||
func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
|
||||
// if this is the pomerium proxy in front of the the authenticate service, don't add
|
||||
// these routes since they will be handled by authenticate
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -70,7 +70,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
|
|||
b.buildControlPlanePrefixRoute("/.well-known/pomerium/", false),
|
||||
)
|
||||
// per #837, only add robots.txt if there are no unauthenticated routes
|
||||
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: domain, Path: "/robots.txt"}) {
|
||||
if !hasPublicPolicyMatchingURL(options, url.URL{Scheme: "https", Host: host, Path: "/robots.txt"}) {
|
||||
routes = append(routes, b.buildControlPlanePathRoute("/robots.txt", false))
|
||||
}
|
||||
}
|
||||
|
@ -79,7 +79,7 @@ func (b *Builder) buildPomeriumHTTPRoutes(options *config.Options, domain string
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
if config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
|
||||
routes = append(routes,
|
||||
b.buildControlPlanePathRoute(options.AuthenticateCallbackPath, false),
|
||||
b.buildControlPlanePathRoute("/", false),
|
||||
|
@ -151,12 +151,12 @@ func getClusterStatsName(policy *config.Policy) string {
|
|||
return ""
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*envoy_config_route_v3.Route, error) {
|
||||
func (b *Builder) buildPolicyRoutes(options *config.Options, host string) ([]*envoy_config_route_v3.Route, error) {
|
||||
var routes []*envoy_config_route_v3.Route
|
||||
|
||||
for i, p := range options.GetAllPolicies() {
|
||||
policy := p
|
||||
if !hostMatchesDomain(policy.Source.URL, domain) {
|
||||
if !urlMatchesHost(policy.Source.URL, host) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -188,7 +188,7 @@ func (b *Builder) buildPolicyRoutes(options *config.Options, domain string) ([]*
|
|||
}
|
||||
|
||||
// disable authentication entirely when the proxy is fronting authenticate
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, domain)
|
||||
isFrontingAuthenticate, err := isProxyFrontingAuthenticate(options, host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -497,13 +497,13 @@ func hasPublicPolicyMatchingURL(options *config.Options, requestURL url.URL) boo
|
|||
return false
|
||||
}
|
||||
|
||||
func isProxyFrontingAuthenticate(options *config.Options, domain string) (bool, error) {
|
||||
func isProxyFrontingAuthenticate(options *config.Options, host string) (bool, error) {
|
||||
authenticateURL, err := options.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if !config.IsAuthenticate(options.Services) && hostMatchesDomain(authenticateURL, domain) {
|
||||
if !config.IsAuthenticate(options.Services) && urlMatchesHost(authenticateURL, host) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -1015,15 +1015,9 @@ func (o *Options) GetCodecType() CodecType {
|
|||
return o.CodecType
|
||||
}
|
||||
|
||||
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
|
||||
return o.GetAllRouteableGRPCDomainsForTLSServerName("")
|
||||
}
|
||||
|
||||
// GetAllRouteableGRPCDomainsForTLSServerName returns all the possible gRPC domains handled by the Pomerium options
|
||||
// for the given TLS server name.
|
||||
func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName string) ([]string, error) {
|
||||
domains := sets.NewSorted[string]()
|
||||
// GetAllRouteableGRPCHosts returns all the possible gRPC hosts handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableGRPCHosts() ([]string, error) {
|
||||
hosts := sets.NewSorted[string]()
|
||||
|
||||
// authorize urls
|
||||
if IsAll(o.Services) {
|
||||
|
@ -1032,11 +1026,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
}
|
||||
} else if IsAuthorize(o.Services) {
|
||||
authorizeURLs, err := o.GetInternalAuthorizeURLs()
|
||||
|
@ -1044,11 +1034,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range authorizeURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1059,11 +1045,7 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
}
|
||||
} else if IsDataBroker(o.Services) {
|
||||
dataBrokerURLs, err := o.GetInternalDataBrokerURLs()
|
||||
|
@ -1071,71 +1053,42 @@ func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName strin
|
|||
return nil, err
|
||||
}
|
||||
for _, u := range dataBrokerURLs {
|
||||
for _, h := range urlutil.GetDomainsForURL(*u) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*u)...)
|
||||
}
|
||||
}
|
||||
|
||||
return domains.ToSlice(), nil
|
||||
return hosts.ToSlice(), nil
|
||||
}
|
||||
|
||||
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
|
||||
return o.GetAllRouteableHTTPDomainsForTLSServerName("")
|
||||
}
|
||||
|
||||
// GetAllRouteableHTTPDomainsForTLSServerName returns all the possible HTTP domains handled by the Pomerium options
|
||||
// for the given TLS server name.
|
||||
func (o *Options) GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName string) ([]string, error) {
|
||||
domains := sets.NewSorted[string]()
|
||||
// GetAllRouteableHTTPHosts returns all the possible HTTP hosts handled by the Pomerium options.
|
||||
func (o *Options) GetAllRouteableHTTPHosts() ([]string, error) {
|
||||
hosts := sets.NewSorted[string]()
|
||||
if IsAuthenticate(o.Services) {
|
||||
authenticateURL, err := o.GetInternalAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
|
||||
|
||||
authenticateURL, err = o.GetAuthenticateURL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
|
||||
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*authenticateURL)...)
|
||||
}
|
||||
|
||||
// policy urls
|
||||
if IsProxy(o.Services) {
|
||||
for _, policy := range o.GetAllPolicies() {
|
||||
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
|
||||
if tlsServerName == "" ||
|
||||
policy.TLSDownstreamServerName == tlsServerName ||
|
||||
urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*policy.Source.URL)...)
|
||||
if policy.TLSDownstreamServerName != "" {
|
||||
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
|
||||
for _, h := range urlutil.GetDomainsForURL(*tlsURL) {
|
||||
if tlsServerName == "" ||
|
||||
urlutil.StripPort(h) == tlsServerName {
|
||||
domains.Add(h)
|
||||
}
|
||||
}
|
||||
hosts.Add(urlutil.GetDomainsForURL(*tlsURL)...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return domains.ToSlice(), nil
|
||||
return hosts.ToSlice(), nil
|
||||
}
|
||||
|
||||
// GetClientSecret gets the client secret.
|
||||
|
|
|
@ -666,14 +666,14 @@ func TestOptions_GetOauthOptions(t *testing.T) {
|
|||
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
|
||||
}
|
||||
|
||||
func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
|
||||
func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
|
||||
opts := &Options{
|
||||
AuthenticateURLString: "https://authenticate.example.com",
|
||||
AuthorizeURLString: "https://authorize.example.com",
|
||||
DataBrokerURLString: "https://databroker.example.com",
|
||||
Services: "all",
|
||||
}
|
||||
domains, err := opts.GetAllRouteableGRPCDomains()
|
||||
hosts, err := opts.GetAllRouteableGRPCHosts()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{
|
||||
|
@ -681,10 +681,10 @@ func TestOptions_GetAllRouteableGRPCDomains(t *testing.T) {
|
|||
"authorize.example.com:443",
|
||||
"databroker.example.com",
|
||||
"databroker.example.com:443",
|
||||
}, domains)
|
||||
}, hosts)
|
||||
}
|
||||
|
||||
func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
||||
func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
|
||||
p1 := Policy{From: "https://from1.example.com"}
|
||||
p1.Validate()
|
||||
p2 := Policy{From: "https://from2.example.com"}
|
||||
|
@ -699,7 +699,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
|||
Policies: []Policy{p1, p2, p3},
|
||||
Services: "all",
|
||||
}
|
||||
domains, err := opts.GetAllRouteableHTTPDomains()
|
||||
hosts, err := opts.GetAllRouteableHTTPHosts()
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, []string{
|
||||
|
@ -713,7 +713,7 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
|
|||
"from2.example.com:443",
|
||||
"from3.example.com",
|
||||
"from3.example.com:443",
|
||||
}, domains)
|
||||
}, hosts)
|
||||
}
|
||||
|
||||
func TestOptions_ApplySettings(t *testing.T) {
|
||||
|
|
|
@ -8,6 +8,8 @@ import (
|
|||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/caddyserver/certmagic"
|
||||
)
|
||||
|
||||
const (
|
||||
|
@ -160,3 +162,8 @@ func GetExternalRequest(internalURL, externalURL *url.URL, r *http.Request) *htt
|
|||
}
|
||||
return er
|
||||
}
|
||||
|
||||
// MatchesServerName returnes true if the url's host matches the given server name.
|
||||
func MatchesServerName(u url.URL, serverName string) bool {
|
||||
return certmagic.MatchWildcard(u.Hostname(), serverName)
|
||||
}
|
||||
|
|
|
@ -166,3 +166,9 @@ func TestJoin(t *testing.T) {
|
|||
assert.Equal(t, "/x/y/z/", Join("/x", "/y/z/"))
|
||||
assert.Equal(t, "/x/y/z/", Join("/x/", "/y/z/"))
|
||||
}
|
||||
|
||||
func TestMatchesServerName(t *testing.T) {
|
||||
t.Run("wildcard", func(t *testing.T) {
|
||||
assert.True(t, MatchesServerName(MustParseAndValidateURL("https://domain.example.com"), "*.example.com"))
|
||||
})
|
||||
}
|
||||
|
|
|
@ -44,36 +44,36 @@ func GetCertPool(ca, caFile string) (*x509.CertPool, error) {
|
|||
return rootCAs, nil
|
||||
}
|
||||
|
||||
// GetCertificateForDomain returns the tls Certificate which matches the given domain name.
|
||||
// GetCertificateForServerName returns the tls Certificate which matches the given server name.
|
||||
// It should handle both exact matches and wildcard matches. If none of those match, the first certificate will be used.
|
||||
// Finally if there are no matching certificates one will be generated.
|
||||
func GetCertificateForDomain(certificates []tls.Certificate, domain string) (*tls.Certificate, error) {
|
||||
func GetCertificateForServerName(certificates []tls.Certificate, serverName string) (*tls.Certificate, error) {
|
||||
// first try a direct name match
|
||||
for i := range certificates {
|
||||
if matchesDomain(&certificates[i], domain) {
|
||||
if matchesServerName(&certificates[i], serverName) {
|
||||
return &certificates[i], nil
|
||||
}
|
||||
}
|
||||
|
||||
log.WarnNoTLSCertificate(domain)
|
||||
log.WarnNoTLSCertificate(serverName)
|
||||
|
||||
// finally fall back to a generated, self-signed certificate
|
||||
return GenerateSelfSignedCertificate(domain)
|
||||
return GenerateSelfSignedCertificate(serverName)
|
||||
}
|
||||
|
||||
// HasCertificateForDomain returns true if a TLS certificate matches the given domain.
|
||||
func HasCertificateForDomain(certificates []tls.Certificate, domain string) bool {
|
||||
// HasCertificateForServerName returns true if a TLS certificate matches the given server name.
|
||||
func HasCertificateForServerName(certificates []tls.Certificate, serverName string) bool {
|
||||
for i := range certificates {
|
||||
if matchesDomain(&certificates[i], domain) {
|
||||
if matchesServerName(&certificates[i], serverName) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetCertificateDomains gets all the certificate's matching domain names.
|
||||
// GetCertificateServerNames gets all the certificate's server names.
|
||||
// Will return an empty slice if certificate is nil, empty, or x509 parsing fails.
|
||||
func GetCertificateDomains(cert *tls.Certificate) []string {
|
||||
func GetCertificateServerNames(cert *tls.Certificate) []string {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
@ -83,19 +83,19 @@ func GetCertificateDomains(cert *tls.Certificate) []string {
|
|||
return nil
|
||||
}
|
||||
|
||||
var domains []string
|
||||
var serverNames []string
|
||||
if xcert.Subject.CommonName != "" {
|
||||
domains = append(domains, xcert.Subject.CommonName)
|
||||
serverNames = append(serverNames, xcert.Subject.CommonName)
|
||||
}
|
||||
for _, dnsName := range xcert.DNSNames {
|
||||
if dnsName != "" {
|
||||
domains = append(domains, dnsName)
|
||||
serverNames = append(serverNames, dnsName)
|
||||
}
|
||||
}
|
||||
return domains
|
||||
return serverNames
|
||||
}
|
||||
|
||||
func matchesDomain(cert *tls.Certificate, domain string) bool {
|
||||
func matchesServerName(cert *tls.Certificate, serverName string) bool {
|
||||
if cert == nil || len(cert.Certificate) == 0 {
|
||||
return false
|
||||
}
|
||||
|
@ -105,12 +105,12 @@ func matchesDomain(cert *tls.Certificate, domain string) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
if certmagic.MatchWildcard(domain, xcert.Subject.CommonName) {
|
||||
if certmagic.MatchWildcard(serverName, xcert.Subject.CommonName) {
|
||||
return true
|
||||
}
|
||||
|
||||
for _, san := range xcert.DNSNames {
|
||||
if certmagic.MatchWildcard(domain, san) {
|
||||
if certmagic.MatchWildcard(serverName, san) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
|
|
@ -8,10 +8,10 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGetCertificateForDomain(t *testing.T) {
|
||||
gen := func(t *testing.T, domain string) *tls.Certificate {
|
||||
cert, err := GenerateSelfSignedCertificate(domain)
|
||||
if !assert.NoError(t, err, "error generating certificate for: %s", domain) {
|
||||
func TestGetCertificateForServerName(t *testing.T) {
|
||||
gen := func(t *testing.T, serverName string) *tls.Certificate {
|
||||
cert, err := GenerateSelfSignedCertificate(serverName)
|
||||
if !assert.NoError(t, err, "error generating certificate for: %s", serverName) {
|
||||
t.FailNow()
|
||||
}
|
||||
return cert
|
||||
|
@ -23,7 +23,7 @@ func TestGetCertificateForDomain(t *testing.T) {
|
|||
*gen(t, "b.example.com"),
|
||||
}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func TestGetCertificateForDomain(t *testing.T) {
|
|||
*gen(t, "*.example.com"),
|
||||
}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ func TestGetCertificateForDomain(t *testing.T) {
|
|||
*gen(t, "a.example.com"),
|
||||
}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -56,7 +56,7 @@ func TestGetCertificateForDomain(t *testing.T) {
|
|||
t.Run("generate", func(t *testing.T) {
|
||||
certs := []tls.Certificate{}
|
||||
|
||||
found, err := GetCertificateForDomain(certs, "b.example.com")
|
||||
found, err := GetCertificateForServerName(certs, "b.example.com")
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
|
@ -64,8 +64,8 @@ func TestGetCertificateForDomain(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestGetCertificateDomains(t *testing.T) {
|
||||
func TestGetCertificateServerNames(t *testing.T) {
|
||||
cert, err := GenerateSelfSignedCertificate("www.example.com")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, []string{"www.example.com"}, GetCertificateDomains(cert))
|
||||
assert.Equal(t, []string{"www.example.com"}, GetCertificateServerNames(cert))
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue