config: add support for downstream TLS server name (#3243)

* config: add support for downstream TLS server name

* fix whitespace

* fix whitespace

* add docs

* add tls_upstream_server_name and tls_downstream_server_name to config

* Update docs/reference/settings.yaml

Co-authored-by: Alex Fornuto <afornuto@pomerium.com>

* Update docs/reference/readme.md

Co-authored-by: Alex Fornuto <afornuto@pomerium.com>

* add deprecation notice

Co-authored-by: Alex Fornuto <afornuto@pomerium.com>
This commit is contained in:
Caleb Doxsey 2022-04-06 07:48:45 -06:00 committed by GitHub
parent e1403e33b4
commit b79f1e379f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 837 additions and 614 deletions

View file

@ -267,6 +267,9 @@ func (b *Builder) buildPolicyTransportSocket(
if policy.TLSServerName != "" {
sni = policy.TLSServerName
}
if policy.TLSUpstreamServerName != "" {
sni = policy.TLSUpstreamServerName
}
tlsContext := &envoy_extensions_transport_sockets_tls_v3.UpstreamTlsContext{
CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{
TlsParams: &envoy_extensions_transport_sockets_tls_v3.TlsParameters{
@ -320,9 +323,16 @@ func (b *Builder) buildPolicyValidationContext(
policy *config.Policy,
dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
overrideName := ""
if policy.TLSServerName != "" {
overrideName = policy.TLSServerName
}
if policy.TLSUpstreamServerName != "" {
overrideName = policy.TLSUpstreamServerName
}
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
b.buildSubjectAltNameMatcher(&dst, policy.TLSServerName),
b.buildSubjectAltNameMatcher(&dst, overrideName),
},
}
if policy.TLSCustomCAFile != "" {

View file

@ -111,21 +111,75 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
"alpnProtocols": ["h2", "http/1.1"],
"tlsParams": {
"cipherSuites": [
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-AES128-GCM-SHA256",
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES128-SHA",
"ECDHE-RSA-AES128-SHA",
"AES128-GCM-SHA256",
"AES128-SHA",
"ECDHE-ECDSA-AES256-SHA",
"ECDHE-RSA-AES256-SHA",
"AES256-GCM-SHA384",
"AES256-SHA"
],
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-AES128-GCM-SHA256",
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES128-SHA",
"ECDHE-RSA-AES128-SHA",
"AES128-GCM-SHA256",
"AES128-SHA",
"ECDHE-ECDSA-AES256-SHA",
"ECDHE-RSA-AES256-SHA",
"AES256-GCM-SHA384",
"AES256-SHA"
],
"ecdhCurves": [
"X25519",
"P-256",
"P-384",
"P-521"
]
},
"validationContext": {
"matchTypedSubjectAltNames": [{
"sanType": "DNS",
"matcher": {
"exact": "use-this-name.example.com"
}
}],
"trustedCa": {
"filename": "`+rootCA+`"
}
}
},
"sni": "use-this-name.example.com"
}
}
`, ts)
})
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSUpstreamServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
testutil.AssertProtoJSONEqual(t, `
{
"name": "tls",
"typedConfig": {
"@type": "type.googleapis.com/envoy.extensions.transport_sockets.tls.v3.UpstreamTlsContext",
"commonTlsContext": {
"alpnProtocols": ["h2", "http/1.1"],
"tlsParams": {
"cipherSuites": [
"ECDHE-ECDSA-AES256-GCM-SHA384",
"ECDHE-RSA-AES256-GCM-SHA384",
"ECDHE-ECDSA-AES128-GCM-SHA256",
"ECDHE-RSA-AES128-GCM-SHA256",
"ECDHE-ECDSA-CHACHA20-POLY1305",
"ECDHE-RSA-CHACHA20-POLY1305",
"ECDHE-ECDSA-AES128-SHA",
"ECDHE-RSA-AES128-SHA",
"AES128-GCM-SHA256",
"AES128-SHA",
"ECDHE-ECDSA-AES256-SHA",
"ECDHE-RSA-AES256-SHA",
"AES256-GCM-SHA384",
"AES256-SHA"
],
"ecdhCurves": [
"X25519",
"P-256",

View file

@ -6,7 +6,6 @@ import (
"fmt"
"net"
"net/url"
"sort"
"time"
envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3"
@ -20,7 +19,6 @@ import (
envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3"
"github.com/golang/protobuf/ptypes/any"
"github.com/golang/protobuf/ptypes/wrappers"
"github.com/scylladb/go-set"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/emptypb"
"google.golang.org/protobuf/types/known/wrapperspb"
@ -28,7 +26,7 @@ import (
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
@ -268,7 +266,7 @@ func (b *Builder) buildFilterChains(
var chains []*envoy_config_listener_v3.FilterChain
for _, domain := range tlsDomains {
routeableDomains, err := getRouteableDomainsForTLSDomain(options, addr, domain)
routeableDomains, err := getRouteableDomainsForTLSServerName(options, addr, domain)
if err != nil {
return nil, err
}
@ -718,23 +716,30 @@ func (b *Builder) buildDownstreamValidationContext(ctx context.Context,
return vc
}
func getRouteableDomainsForTLSDomain(options *config.Options, addr string, tlsDomain string) ([]string, error) {
allDomains, err := getAllRouteableDomains(options, addr)
if err != nil {
return nil, err
func getRouteableDomainsForTLSServerName(options *config.Options, addr string, tlsServerName string) ([]string, error) {
allDomains := sets.NewSortedString()
if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName)
if err != nil {
return nil, err
}
allDomains.Add(domains...)
}
var filtered []string
for _, domain := range allDomains {
if urlutil.StripPort(domain) == tlsDomain {
filtered = append(filtered, domain)
if addr == options.GetGRPCAddr() {
domains, err := options.GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName)
if err != nil {
return nil, err
}
allDomains.Add(domains...)
}
return filtered, nil
return allDomains.ToSlice(), nil
}
func getAllRouteableDomains(options *config.Options, addr string) ([]string, error) {
allDomains := set.NewStringSet()
allDomains := sets.NewSortedString()
if addr == options.Addr {
domains, err := options.GetAllRouteableHTTPDomains()
@ -752,10 +757,7 @@ func getAllRouteableDomains(options *config.Options, addr string) ([]string, err
allDomains.Add(domains...)
}
domains := allDomains.List()
sort.Strings(domains)
return domains, nil
return allDomains.ToSlice(), nil
}
func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
@ -764,22 +766,16 @@ func getAllTLSDomains(options *config.Options, addr string) ([]string, error) {
return nil, err
}
lookup := map[string]struct{}{}
domains := sets.NewSortedString()
for _, hp := range allDomains {
if d, _, err := net.SplitHostPort(hp); err == nil {
lookup[d] = struct{}{}
domains.Add(d)
} else {
lookup[hp] = struct{}{}
domains.Add(hp)
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
return domains.ToSlice(), nil
}
func hostsMatchDomain(urls []*url.URL, host string) bool {

View file

@ -117,6 +117,10 @@ func NewPolicyHTTPTransport(options *Options, policy *Policy, disableHTTP2 bool)
tlsClientConfig.ServerName = policy.TLSServerName
isCustomClientConfig = true
}
if policy.TLSUpstreamServerName != "" {
tlsClientConfig.ServerName = policy.TLSUpstreamServerName
isCustomClientConfig = true
}
// We avoid setting a custom client config unless we have to as
// if TLSClientConfig is nil, the default configuration is used.

View file

@ -11,7 +11,6 @@ import (
"os"
"path/filepath"
"reflect"
"sort"
"strings"
"sync/atomic"
"time"
@ -29,6 +28,7 @@ import (
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/identity/oauth"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/sets"
"github.com/pomerium/pomerium/internal/telemetry"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/internal/urlutil"
@ -1041,7 +1041,13 @@ func (o *Options) GetCodecType() CodecType {
// GetAllRouteableGRPCDomains returns all the possible gRPC domains handled by the Pomerium options.
func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
lookup := map[string]struct{}{}
return o.GetAllRouteableGRPCDomainsForTLSServerName("")
}
// GetAllRouteableGRPCDomainsForTLSServerName returns all the possible gRPC domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableGRPCDomainsForTLSServerName(tlsServerName string) ([]string, error) {
domains := sets.NewSortedString()
// authorize urls
if IsAll(o.Services) {
@ -1051,7 +1057,9 @@ func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
} else if IsAuthorize(o.Services) {
@ -1061,7 +1069,9 @@ func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
}
for _, u := range authorizeURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
}
@ -1074,7 +1084,9 @@ func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
} else if IsDataBroker(o.Services) {
@ -1084,35 +1096,39 @@ func (o *Options) GetAllRouteableGRPCDomains() ([]string, error) {
}
for _, u := range dataBrokerURLs {
for _, h := range urlutil.GetDomainsForURL(*u) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
return domains.ToSlice(), nil
}
// GetAllRouteableHTTPDomains returns all the possible HTTP domains handled by the Pomerium options.
func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
return o.GetAllRouteableHTTPDomainsForTLSServerName("")
}
// GetAllRouteableHTTPDomainsForTLSServerName returns all the possible HTTP domains handled by the Pomerium options
// for the given TLS server name.
func (o *Options) GetAllRouteableHTTPDomainsForTLSServerName(tlsServerName string) ([]string, error) {
forwardAuthURL, err := o.GetForwardAuthURL()
if err != nil {
return nil, err
}
lookup := map[string]struct{}{}
domains := sets.NewSortedString()
if IsAuthenticate(o.Services) {
authenticateURL, err := o.GetInternalAuthenticateURL()
if err != nil {
return nil, err
}
for _, h := range urlutil.GetDomainsForURL(*authenticateURL) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
@ -1120,23 +1136,32 @@ func (o *Options) GetAllRouteableHTTPDomains() ([]string, error) {
if IsProxy(o.Services) {
for _, policy := range o.GetAllPolicies() {
for _, h := range urlutil.GetDomainsForURL(*policy.Source.URL) {
lookup[h] = struct{}{}
if tlsServerName == "" ||
policy.TLSDownstreamServerName == tlsServerName ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
if policy.TLSDownstreamServerName != "" {
tlsURL := policy.Source.URL.ResolveReference(&url.URL{Host: policy.TLSDownstreamServerName})
for _, h := range urlutil.GetDomainsForURL(*tlsURL) {
if tlsServerName == "" ||
urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
}
if forwardAuthURL != nil {
for _, h := range urlutil.GetDomainsForURL(*forwardAuthURL) {
lookup[h] = struct{}{}
if tlsServerName == "" || urlutil.StripPort(h) == tlsServerName {
domains.Add(h)
}
}
}
}
domains := make([]string, 0, len(lookup))
for domain := range lookup {
domains = append(domains, domain)
}
sort.Strings(domains)
return domains, nil
return domains.ToSlice(), nil
}
// Checksum returns the checksum of the current options struct

View file

@ -707,12 +707,14 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
p1.Validate()
p2 := Policy{From: "https://from2.example.com"}
p2.Validate()
p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com"}
p3.Validate()
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Policies: []Policy{p1, p2},
Policies: []Policy{p1, p2, p3},
Services: "all",
}
domains, err := opts.GetAllRouteableHTTPDomains()
@ -721,10 +723,14 @@ func TestOptions_GetAllRouteableHTTPDomains(t *testing.T) {
assert.Equal(t, []string{
"authenticate.example.com",
"authenticate.example.com:443",
"from.example.com",
"from.example.com:443",
"from1.example.com",
"from1.example.com:443",
"from2.example.com",
"from2.example.com:443",
"from3.example.com",
"from3.example.com:443",
}, domains)
}

View file

@ -99,7 +99,9 @@ type Policy struct {
// if your backend is an HTTPS server with a valid certificate, but you
// want to communicate to the backend with an internal hostname (e.g.
// Docker container name).
TLSServerName string `mapstructure:"tls_server_name" yaml:"tls_server_name,omitempty"`
TLSServerName string `mapstructure:"tls_server_name" yaml:"tls_server_name,omitempty"`
TLSDownstreamServerName string `mapstructure:"tls_downstream_server_name" yaml:"tls_downstream_server_name,omitempty"`
TLSUpstreamServerName string `mapstructure:"tls_upstream_server_name" yaml:"tls_upstream_server_name,omitempty"`
// TLSCustomCA defines the root certificate to use with a given
// route when verifying server certificates.
@ -239,6 +241,8 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
AllowSPDY: pb.GetAllowSpdy(),
TLSSkipVerify: pb.GetTlsSkipVerify(),
TLSServerName: pb.GetTlsServerName(),
TLSDownstreamServerName: pb.GetTlsDownstreamServerName(),
TLSUpstreamServerName: pb.GetTlsUpstreamServerName(),
TLSCustomCA: pb.GetTlsCustomCa(),
TLSCustomCAFile: pb.GetTlsCustomCaFile(),
TLSClientCert: pb.GetTlsClientCert(),
@ -360,6 +364,8 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
AllowSpdy: p.AllowSPDY,
TlsSkipVerify: p.TLSSkipVerify,
TlsServerName: p.TLSServerName,
TlsUpstreamServerName: p.TLSUpstreamServerName,
TlsDownstreamServerName: p.TLSDownstreamServerName,
TlsCustomCa: p.TLSCustomCA,
TlsCustomCaFile: p.TLSCustomCAFile,
TlsClientCert: p.TLSClientCert,

View file

@ -1725,7 +1725,23 @@ TLS Skip Verification controls whether the Pomerium Proxy Service verifies the u
- Type: `string`
- Optional
TLS Server Name overrides the hostname specified in the `to` field. If set, this server name will be used to verify the certificate name. This is useful when the backend of your service is an TLS server with a valid certificate, but mismatched name.
**Deprecated**: this key has been replaced with `tls_upstream_server_name`.
### TLS Upstream Server Name
- Config File Key: `tls_upstream_server_name`
- Type: `string`
- Optional
TLS Upstream Server Name overrides the hostname specified in the `to` field. If set, this server name will be used to verify the certificate name. This is useful when the backend of your service is a TLS server with a valid certificate, but mismatched name.
### TLS Downstream Server Name
- Config File Key: `tls_downstream_server_name`
- Type: `string`
- Optional
TLS Downstream Server Name overrides the hostname specified in the `from` field. When a connection to Pomerium is made via TLS the `tls_downstream_server_name` will be used as the expected Server Name Indication, whereas the host part of the `from` field, will be expected to match the `Host` or `:authority` headers of the HTTP request.
### To

View file

@ -1882,7 +1882,23 @@ settings:
- Type: `string`
- Optional
doc: |
TLS Server Name overrides the hostname specified in the `to` field. If set, this server name will be used to verify the certificate name. This is useful when the backend of your service is an TLS server with a valid certificate, but mismatched name.
**Deprecated**: this key has been replaced with `tls_upstream_server_name`.
- name: "TLS Upstream Server Name"
keys: ["tls_upstream_server_name"]
attributes: |
- Config File Key: `tls_upstream_server_name`
- Type: `string`
- Optional
doc: |
TLS Upstream Server Name overrides the hostname specified in the `to` field. If set, this server name will be used to verify the certificate name. This is useful when the backend of your service is a TLS server with a valid certificate, but mismatched name.
- name: "TLS Downstream Server Name"
keys: ["tls_downstream_server_name"]
attributes: |
- Config File Key: `tls_downstream_server_name`
- Type: `string`
- Optional
doc: |
TLS Downstream Server Name overrides the hostname specified in the `from` field. When a connection to Pomerium is made via TLS the `tls_downstream_server_name` will be used as the expected Server Name Indication, whereas the host part of the `from` field, will be expected to match the `Host` or `:authority` headers of the HTTP request.
- name: "To"
keys: ["to"]
attributes: |

65
internal/sets/sorted.go Normal file
View file

@ -0,0 +1,65 @@
package sets
import "github.com/google/btree"
type stringItem string
func (item stringItem) Less(than btree.Item) bool {
return item < than.(stringItem)
}
// A SortedString is a set of strings with sorted iteration.
type SortedString struct {
b *btree.BTree
}
// NewSortedString creates a new sorted string set.
func NewSortedString() *SortedString {
return &SortedString{
b: btree.New(8),
}
}
// Add adds a string to the set.
func (s *SortedString) Add(elements ...string) {
for _, element := range elements {
s.b.ReplaceOrInsert(stringItem(element))
}
}
// Clear clears the set.
func (s *SortedString) Clear() {
s.b = btree.New(8)
}
// Delete deletes an element from the set.
func (s *SortedString) Delete(element string) {
s.b.Delete(stringItem(element))
}
// ForEach iterates over the set in ascending order.
func (s *SortedString) ForEach(callback func(element string) bool) {
s.b.Ascend(func(i btree.Item) bool {
return callback(string(i.(stringItem)))
})
}
// Has returns true if the elment is in the set.
func (s *SortedString) Has(element string) bool {
return s.b.Has(stringItem(element))
}
// Size returns the size of the set.
func (s *SortedString) Size() int {
return s.b.Len()
}
// ToSlice returns a slice of all the elements in the set.
func (s *SortedString) ToSlice() []string {
arr := make([]string, 0, s.Size())
s.b.Ascend(func(i btree.Item) bool {
arr = append(arr, string(i.(stringItem)))
return true
})
return arr
}

File diff suppressed because it is too large Load diff

View file

@ -78,6 +78,8 @@ message Route {
bool tls_skip_verify = 14;
string tls_server_name = 15;
string tls_upstream_server_name = 57;
string tls_downstream_server_name = 58;
string tls_custom_ca = 16;
string tls_custom_ca_file = 17;