diff --git a/config/envoyconfig/listeners.go b/config/envoyconfig/listeners.go index 03589620e..8e159cf96 100644 --- a/config/envoyconfig/listeners.go +++ b/config/envoyconfig/listeners.go @@ -1,6 +1,7 @@ package envoyconfig import ( + "bytes" "context" "encoding/base64" "fmt" @@ -150,7 +151,7 @@ func (b *Builder) buildMainListener(ctx context.Context, cfg *config.Config) (*e ServerNames: []string{tlsDomain}, } } - tlsContext := b.buildDownstreamTLSContext(ctx, cfg, tlsDomain) + tlsContext := b.buildDownstreamTLSContextWithValidation(ctx, cfg, tlsDomain) if tlsContext != nil { tlsConfig := marshalAny(tlsContext) filterChain.TransportSocket = &envoy_config_core_v3.TransportSocket{ @@ -670,40 +671,60 @@ func (b *Builder) buildDownstreamTLSContext(ctx context.Context, envoyCert := b.envoyTLSCertificateFromGoTLSCertificate(ctx, cert) return &envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext{ CommonTlsContext: &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext{ - TlsParams: tlsParams, - TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert}, - AlpnProtocols: alpnProtocols, - ValidationContextType: b.buildDownstreamValidationContext(ctx, cfg, domain), + TlsParams: tlsParams, + TlsCertificates: []*envoy_extensions_transport_sockets_tls_v3.TlsCertificate{envoyCert}, + AlpnProtocols: alpnProtocols, }, } } -func (b *Builder) buildDownstreamValidationContext(ctx context.Context, +func (b *Builder) buildDownstreamTLSContextWithValidation( + ctx context.Context, cfg *config.Config, domain string, -) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { - needsClientCert := false - - if ca, _ := cfg.Options.GetClientCA(); len(ca) > 0 { - needsClientCert = true +) *envoy_extensions_transport_sockets_tls_v3.DownstreamTlsContext { + dtc := b.buildDownstreamTLSContext(ctx, cfg, domain) + if clientCA := clientCAForDomain(ctx, cfg, domain); len(clientCA) > 0 { + dtc.CommonTlsContext.ValidationContextType = b.buildDownstreamValidationContext(ctx, cfg, clientCA) + dtc.RequireClientCertificate = wrapperspb.Bool(true) } - if !needsClientCert { - for _, p := range getPoliciesForDomain(cfg.Options, domain) { - if p.TLSDownstreamClientCA != "" { - needsClientCert = true - break - } + return dtc +} + +// clientCAForDomain returns a bundle of all per-route client CAs configured +// for the given domain, or else the globally configured client CA. +func clientCAForDomain(ctx context.Context, cfg *config.Config, domain string) []byte { + var bundle bytes.Buffer + for _, p := range getPoliciesForDomain(cfg.Options, domain) { + if p.TLSDownstreamClientCA == "" { + continue + } + ca, err := base64.StdEncoding.DecodeString(p.TLSDownstreamClientCA) + if err != nil { + log.Error(ctx).Err(err).Msg("invalid client CA") + continue + } + bundle.Write(ca) + // In case there are multiple CAs, make sure they are separated by a newline. + if ca[len(ca)-1] != '\n' { + bundle.WriteByte('\n') } } - - if !needsClientCert { - return nil + if bundle.Len() > 0 { + return bundle.Bytes() } + ca, _ := cfg.Options.GetClientCA() + return ca +} - // trusted_ca is left blank because we verify the client certificate in the authorize service +func (b *Builder) buildDownstreamValidationContext( + ctx context.Context, + cfg *config.Config, + clientCA []byte, +) *envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext { vc := &envoy_extensions_transport_sockets_tls_v3.CommonTlsContext_ValidationContext{ ValidationContext: &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ - TrustChainVerification: envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext_ACCEPT_UNTRUSTED, + TrustedCa: b.filemgr.BytesDataSource("client-ca.pem", clientCA), }, } diff --git a/config/envoyconfig/listeners_test.go b/config/envoyconfig/listeners_test.go index f5d265f69..8d54ee8a9 100644 --- a/config/envoyconfig/listeners_test.go +++ b/config/envoyconfig/listeners_test.go @@ -2,6 +2,7 @@ package envoyconfig import ( "context" + "encoding/base64" "os" "path/filepath" "testing" @@ -562,15 +563,16 @@ func Test_buildMainHTTPConnectionManagerFilter(t *testing.T) { }`, filter) } -func Test_buildDownstreamTLSContext(t *testing.T) { +func Test_buildDownstreamTLSContextWithValidation(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) cacheDir, _ := os.UserCacheDir() certFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-crt-354e49305a5a39414a545530374e58454e48334148524c4e324258463837364355564c4e4532464b54355139495547514a38.pem") keyFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "tls-key-3350415a38414e4e4a4655424e55393430474147324651433949384e485341334b5157364f424b4c5856365a545937383735.pem") + clientCAFileName := filepath.Join(cacheDir, "pomerium", "envoy", "files", "client-ca-3533485838304b593757424e3354425157494c4747433534384f474f3631364d5332554c3332485a483834334d50454c344a.pem") t.Run("no-validation", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContextWithValidation(context.Background(), &config.Config{Options: &config.Options{ Cert: aExampleComCert, Key: aExampleComKey, }}, "a.example.com") @@ -603,10 +605,10 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("client-ca", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContextWithValidation(context.Background(), &config.Config{Options: &config.Options{ Cert: aExampleComCert, Key: aExampleComKey, - ClientCA: "TEST", + ClientCA: "VEVTVAo=", // "TEST\n" (with a trailing newline) }}, "a.example.com") testutil.AssertProtoJSONEqual(t, `{ @@ -634,19 +636,22 @@ func Test_buildDownstreamTLSContext(t *testing.T) { } ], "validationContext": { - "trustChainVerification": "ACCEPT_UNTRUSTED" + "trustedCa": { + "filename": "`+clientCAFileName+`" + } } - } + }, + "requireClientCertificate": true }`, downstreamTLSContext) }) t.Run("policy-client-ca", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContextWithValidation(context.Background(), &config.Config{Options: &config.Options{ Cert: aExampleComCert, Key: aExampleComKey, Policies: []config.Policy{ { Source: &config.StringURL{URL: mustParseURL(t, "https://a.example.com:1234")}, - TLSDownstreamClientCA: "TEST", + TLSDownstreamClientCA: "VEVTVA==", // "TEST" (no trailing newline) }, }, }}, "a.example.com") @@ -676,13 +681,16 @@ func Test_buildDownstreamTLSContext(t *testing.T) { } ], "validationContext": { - "trustChainVerification": "ACCEPT_UNTRUSTED" + "trustedCa": { + "filename": "`+clientCAFileName+`" + } } - } + }, + "requireClientCertificate": true }`, downstreamTLSContext) }) t.Run("http1", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContextWithValidation(context.Background(), &config.Config{Options: &config.Options{ Cert: aExampleComCert, Key: aExampleComKey, CodecType: config.CodecTypeHTTP1, @@ -716,7 +724,7 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }`, downstreamTLSContext) }) t.Run("http2", func(t *testing.T) { - downstreamTLSContext := b.buildDownstreamTLSContext(context.Background(), &config.Config{Options: &config.Options{ + downstreamTLSContext := b.buildDownstreamTLSContextWithValidation(context.Background(), &config.Config{Options: &config.Options{ Cert: aExampleComCert, Key: aExampleComKey, CodecType: config.CodecTypeHTTP2, @@ -751,6 +759,128 @@ func Test_buildDownstreamTLSContext(t *testing.T) { }) } +func Test_clientCAForDomain_globalAndPerRoute(t *testing.T) { + clientCA1 := []byte("client CA 1\n") + clientCA2 := []byte("client CA 2\n") + clientCA3 := []byte("client CA 3\n") + clientCA2and3 := []byte("client CA 2\nclient CA 3\n") + + b64 := base64.StdEncoding.EncodeToString + cfg := &config.Config{Options: &config.Options{ + ClientCA: b64(clientCA1), + Policies: []config.Policy{ + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://a.example.com:1234")}, + TLSDownstreamClientCA: b64(clientCA2), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://a.example.com:4567")}, + TLSDownstreamClientCA: b64(clientCA3), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://b.example.com")}, + TLSDownstreamClientCA: b64(clientCA3), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://c.example.com")}, + }, + }, + }} + + cases := []struct { + domain string + expected []byte + }{ + {"a.example.com", clientCA2and3}, + {"b.example.com", clientCA3}, + {"c.example.com", clientCA1}, // no per-route client CA override + {"any-other-domain", clientCA1}, + } + for i := range cases { + c := &cases[i] + t.Run(c.domain, func(t *testing.T) { + actual := clientCAForDomain(context.Background(), cfg, c.domain) + assert.Equal(t, c.expected, actual) + }) + } +} + +func Test_clientCAForDomain_perRouteOnly(t *testing.T) { + clientCA1 := []byte("client CA 1\n") + clientCA2 := []byte("client CA 2\n") + + b64 := base64.StdEncoding.EncodeToString + cfg := &config.Config{Options: &config.Options{ + Policies: []config.Policy{ + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://a.example.com")}, + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://b.example.com")}, + TLSDownstreamClientCA: b64(clientCA2), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://c.example.com")}, + TLSDownstreamClientCA: b64(clientCA1), + }, + }, + }} + + cases := []struct { + domain string + expected []byte + }{ + {"a.example.com", nil}, + {"b.example.com", clientCA2}, + {"c.example.com", clientCA1}, + } + for i := range cases { + c := &cases[i] + t.Run(c.domain, func(t *testing.T) { + actual := clientCAForDomain(context.Background(), cfg, c.domain) + assert.Equal(t, c.expected, actual) + }) + } +} + +func Test_clientCAForDomain_newlines(t *testing.T) { + // Make sure multiple bundled per-route CAs are separated by newlines. + clientCA1 := []byte("client CA 1") + clientCA2 := []byte("client CA 2") + clientCA3 := []byte("client CA 3") + + b64 := base64.StdEncoding.EncodeToString + cfg := &config.Config{Options: &config.Options{ + Policies: []config.Policy{ + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://foo.example.com:123")}, + TLSDownstreamClientCA: b64(clientCA3), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://foo.example.com:456")}, + TLSDownstreamClientCA: b64(clientCA2), + }, + { + Source: &config.StringURL{ + URL: mustParseURL(t, "https://foo.example.com:789")}, + TLSDownstreamClientCA: b64(clientCA1), + }, + }, + }} + expected := []byte("client CA 3\nclient CA 2\nclient CA 1\n") + actual := clientCAForDomain(context.Background(), cfg, "foo.example.com") + assert.Equal(t, expected, actual) +} + func Test_getAllDomains(t *testing.T) { options := &config.Options{ Addr: "127.0.0.1:9000", diff --git a/integration/main_test.go b/integration/main_test.go index 736aeb5b5..fb2af82aa 100644 --- a/integration/main_test.go +++ b/integration/main_test.go @@ -170,3 +170,14 @@ func mustParseURL(str string) *url.URL { } return u } + +func loadCertificate(t *testing.T, certName string) tls.Certificate { + t.Helper() + certFile := filepath.Join(".", "tpl", "files", certName+".pem") + keyFile := filepath.Join(".", "tpl", "files", certName+"-key.pem") + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + t.Fatal(err) + } + return cert +} diff --git a/integration/policy_test.go b/integration/policy_test.go index fcbe563c5..297572470 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -11,8 +11,10 @@ import ( "github.com/gorilla/websocket" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/integration/flows" + "github.com/pomerium/pomerium/internal/httputil" ) func TestCORS(t *testing.T) { @@ -320,3 +322,166 @@ func TestLoadBalancer(t *testing.T) { distribution) }) } + +func TestDownstreamClientCA(t *testing.T) { + if ClusterType == "traefik" || ClusterType == "nginx" { + t.Skip() + return + } + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute*10) + defer clearTimeout() + + t.Run("no client cert", func(t *testing.T) { + req, err := http.NewRequestWithContext(ctx, "GET", + "https://client-cert-required.localhost.pomerium.io/", nil) + require.NoError(t, err) + + res, err := getClient().Do(req) + if assert.Error(t, err, "expected error when no certificate provided") { + assert.Contains(t, err.Error(), "remote error: tls: certificate required") + } else { + res.Body.Close() + } + }) + t.Run("untrusted client cert", func(t *testing.T) { + // Configure an http.Client with an untrusted client certificate. + cert := loadCertificate(t, "downstream-2-client") + client := *getClient() + tr := client.Transport.(*http.Transport).Clone() + // We need to use the GetClientCertificate callback here in order to + // present a certificate that doesn't match the advertised CA. + tr.TLSClientConfig.GetClientCertificate = + func(_ *tls.CertificateRequestInfo) (*tls.Certificate, error) { return &cert, nil } + client.Transport = tr + + req, err := http.NewRequestWithContext(ctx, "GET", + "https://client-cert-required.localhost.pomerium.io/", nil) + require.NoError(t, err) + + res, err := client.Do(req) + if assert.Error(t, err, "expected error for untrusted certificate") { + assert.Contains(t, err.Error(), "remote error: tls: unknown certificate authority") + } else { + res.Body.Close() + } + }) + t.Run("valid client cert", func(t *testing.T) { + // Configure an http.Client with a trusted client certificate. + cert := loadCertificate(t, "downstream-1-client") + client := *getClient() + tr := client.Transport.(*http.Transport).Clone() + tr.TLSClientConfig.Certificates = []tls.Certificate{cert} + client.Transport = tr + + res, err := flows.Authenticate(ctx, &client, + mustParseURL("https://client-cert-required.localhost.pomerium.io/"), + flows.WithEmail("user1@dogs.test")) + require.NoError(t, err, "unexpected http error") + defer res.Body.Close() + + var result struct { + Path string `json:"path"` + } + err = json.NewDecoder(res.Body).Decode(&result) + require.NoError(t, err) + assert.Equal(t, "/", result.Path) + }) +} + +func TestMultipleDownstreamClientCAs(t *testing.T) { + if ClusterType == "traefik" || ClusterType == "nginx" { + t.Skip() + return + } + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Minute*10) + defer clearTimeout() + + // Initializes a new http.Client with the given certificate. + newClientWithCert := func(certName string) *http.Client { + cert := loadCertificate(t, certName) + client := *getClient() + tr := client.Transport.(*http.Transport).Clone() + tr.TLSClientConfig.Certificates = []tls.Certificate{cert} + client.Transport = tr + return &client + } + + // Asserts that we get a successful JSON response from the httpdetails + // service, matching the given path. + assertOK := func(res *http.Response, err error, path string) { + require.NoError(t, err, "unexpected http error") + defer res.Body.Close() + + var result struct { + Path string `json:"path"` + } + err = json.NewDecoder(res.Body).Decode(&result) + require.NoError(t, err) + assert.Equal(t, path, result.Path) + } + + t.Run("cert1", func(t *testing.T) { + client := newClientWithCert("downstream-1-client") + + // With cert1, we should get a valid response for the /ca1 path. + res, err := flows.Authenticate(ctx, client, + mustParseURL("https://client-cert-overlap.localhost.pomerium.io/ca1"), + flows.WithEmail("user1@dogs.test")) + assertOK(res, err, "/ca1") + + // With cert1, we should get an HTTP error response for the /ca2 path. + req, err := http.NewRequestWithContext(ctx, "GET", + "https://client-cert-overlap.localhost.pomerium.io/ca2", nil) + require.NoError(t, err) + + res, err = client.Do(req) + require.NoError(t, err) + defer res.Body.Close() + assert.Equal(t, httputil.StatusInvalidClientCertificate, res.StatusCode) + }) + t.Run("cert2", func(t *testing.T) { + client := newClientWithCert("downstream-2-client") + + // With cert2, we should get an HTTP error response for the /ca1 path. + req, err := http.NewRequestWithContext(ctx, "GET", + "https://client-cert-overlap.localhost.pomerium.io/ca1", nil) + require.NoError(t, err) + + res, err := client.Do(req) + require.NoError(t, err, "unexpected http error") + defer res.Body.Close() + assert.Equal(t, httputil.StatusInvalidClientCertificate, res.StatusCode) + + // With cert2, we should get a valid response for the /ca2 path. + res, err = flows.Authenticate(ctx, client, + mustParseURL("https://client-cert-overlap.localhost.pomerium.io/ca2"), + flows.WithEmail("user1@dogs.test")) + assertOK(res, err, "/ca2") + }) + t.Run("no cert", func(t *testing.T) { + // Without a client certificate, connections should be rejected. + req, err := http.NewRequestWithContext(ctx, "GET", + "https://client-cert-overlap.localhost.pomerium.io/ca1", nil) + require.NoError(t, err) + + res, err := getClient().Do(req) + if assert.Error(t, err, "expected error when no certificate provided") { + assert.Contains(t, err.Error(), "remote error: tls: certificate required") + } else { + res.Body.Close() + } + + req, err = http.NewRequestWithContext(ctx, "GET", + "https://client-cert-overlap.localhost.pomerium.io/ca2", nil) + require.NoError(t, err) + + res, err = getClient().Do(req) + if assert.Error(t, err, "expected error when no certificate provided") { + assert.Contains(t, err.Error(), "remote error: tls: certificate required") + } else { + res.Body.Close() + } + }) +}