From 488bcd6f72f8a91f01fde67a7628200146630158 Mon Sep 17 00:00:00 2001 From: Denis Mishin Date: Thu, 5 Jan 2023 16:35:58 -0500 Subject: [PATCH] auto tls (#3856) --- config/config.go | 40 ++++++++++++ config/envoyconfig/clusters.go | 44 ++++++------- config/envoyconfig/clusters_test.go | 46 +++++++------- config/envoyconfig/envoyconfig.go | 31 +++------ config/layered.go | 68 ++++++++++++++++++++ config/layered_test.go | 44 +++++++++++++ config/options.go | 12 ++++ internal/fileutil/fileutil.go | 37 +++++++++++ internal/fileutil/fileutil_test.go | 33 ++++++++++ pkg/cmd/pomerium/pomerium.go | 11 +++- pkg/derivecert/config/builder.go | 91 +++++++++++++++++++++++++++ pkg/derivecert/config/builder_test.go | 57 +++++++++++++++++ 12 files changed, 447 insertions(+), 67 deletions(-) create mode 100644 config/layered.go create mode 100644 config/layered_test.go create mode 100644 pkg/derivecert/config/builder.go create mode 100644 pkg/derivecert/config/builder_test.go diff --git a/config/config.go b/config/config.go index c1cd2effa..5be5cfe89 100644 --- a/config/config.go +++ b/config/config.go @@ -1,8 +1,11 @@ package config import ( + "bytes" "crypto/tls" + "encoding/base64" + "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/hashutil" "github.com/pomerium/pomerium/internal/telemetry/metrics" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -17,6 +20,12 @@ type Config struct { AutoCertificates []tls.Certificate EnvoyVersion string + // DerivedCertificates are TLS certificates derived from the shared secret + DerivedCertificates []tls.Certificate + // DerivedCAPEM is a PEM-encoded certificate authority + // derived from the shared secret + DerivedCAPEM []byte + // GRPCPort is the port the gRPC server is running on. GRPCPort string // HTTPPort is the port the HTTP server is running on. @@ -57,9 +66,39 @@ func (cfg *Config) Clone() *Config { ACMETLSALPNPort: cfg.ACMETLSALPNPort, MetricsScrapeEndpoints: endpoints, + + DerivedCertificates: cfg.DerivedCertificates, + DerivedCAPEM: cfg.DerivedCAPEM, } } +// AllCertificateAuthoritiesPEM returns all CAs as PEM bundle bytes +func (cfg *Config) AllCertificateAuthoritiesPEM() ([]byte, error) { + var combined bytes.Buffer + if cfg.Options.CA != "" { + bs, err := base64.StdEncoding.DecodeString(cfg.Options.CA) + if err != nil { + return nil, err + } + _, _ = combined.Write(bs) + _, _ = combined.WriteRune('\n') + } + + if cfg.Options.CAFile != "" { + if err := fileutil.CopyFileUpTo(&combined, cfg.Options.CAFile, 5<<20); err != nil { + return nil, err + } + _, _ = combined.WriteRune('\n') + } + + if cfg.DerivedCAPEM != nil { + _, _ = combined.Write(cfg.DerivedCAPEM) + _, _ = combined.WriteRune('\n') + } + + return combined.Bytes(), nil +} + // AllCertificates returns all the certificates in the config. func (cfg *Config) AllCertificates() ([]tls.Certificate, error) { optionCertificates, err := cfg.Options.GetCertificates() @@ -70,6 +109,7 @@ func (cfg *Config) AllCertificates() ([]tls.Certificate, error) { var certs []tls.Certificate certs = append(certs, optionCertificates...) certs = append(certs, cfg.AutoCertificates...) + certs = append(certs, cfg.DerivedCertificates...) return certs, nil } diff --git a/config/envoyconfig/clusters.go b/config/envoyconfig/clusters.go index 44f08cbf4..d4c75f9a7 100644 --- a/config/envoyconfig/clusters.go +++ b/config/envoyconfig/clusters.go @@ -51,22 +51,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env } } - controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2) + controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2) if err != nil { return nil, err } - controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) + controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto) if err != nil { return nil, err } - controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto) + controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto) if err != nil { return nil, err } - authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2) + authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2) if err != nil { return nil, err } @@ -75,7 +75,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env authorizeCluster.OutlierDetection = grpcOutlierDetection() } - databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2) + databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2) if err != nil { return nil, err } @@ -113,7 +113,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env policy.EnvoyOpts = newDefaultEnvoyClusterConfig() } if len(policy.To) > 0 { - cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy) + cluster, err := b.buildPolicyCluster(ctx, cfg, &policy) if err != nil { return nil, fmt.Errorf("policy #%d: %w", i, err) } @@ -131,16 +131,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env func (b *Builder) buildInternalCluster( ctx context.Context, - options *config.Options, + cfg *config.Config, name string, dsts []*url.URL, upstreamProtocol upstreamProtocolConfig, ) (*envoy_config_cluster_v3.Cluster, error) { cluster := newDefaultEnvoyClusterConfig() - cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily) + cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily) var endpoints []Endpoint for _, dst := range dsts { - ts, err := b.buildInternalTransportSocket(ctx, options, dst) + ts, err := b.buildInternalTransportSocket(ctx, cfg, dst) if err != nil { return nil, err } @@ -153,10 +153,12 @@ func (b *Builder) buildInternalCluster( return cluster, nil } -func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { +func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) { cluster := new(envoy_config_cluster_v3.Cluster) proto.Merge(cluster, policy.EnvoyOpts) + options := cfg.Options + if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" { cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig) if options.EnvoyBindConfigFreebind.IsSet() { @@ -183,7 +185,7 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) name := getClusterID(policy) - endpoints, err := b.buildPolicyEndpoints(ctx, options, policy) + endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy) if err != nil { return nil, err } @@ -205,12 +207,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option func (b *Builder) buildPolicyEndpoints( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, ) ([]Endpoint, error) { var endpoints []Endpoint for _, dst := range policy.To { - ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL) + ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL) if err != nil { return nil, err } @@ -221,7 +223,7 @@ func (b *Builder) buildPolicyEndpoints( func (b *Builder) buildInternalTransportSocket( ctx context.Context, - options *config.Options, + cfg *config.Config, endpoint *url.URL, ) (*envoy_config_core_v3.TransportSocket, error) { if endpoint.Scheme != "https" { @@ -230,10 +232,10 @@ func (b *Builder) buildInternalTransportSocket( validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{ MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{ - b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName), + b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName), }, } - bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) + bs, err := getCombinedCertificateAuthority(cfg) if err != nil { log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { @@ -246,7 +248,7 @@ func (b *Builder) buildInternalTransportSocket( ValidationContext: validationContext, }, }, - Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName), + Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName), } tlsConfig := marshalAny(tlsContext) return &envoy_config_core_v3.TransportSocket{ @@ -259,7 +261,7 @@ func (b *Builder) buildInternalTransportSocket( func (b *Builder) buildPolicyTransportSocket( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, dst url.URL, ) (*envoy_config_core_v3.TransportSocket, error) { @@ -269,7 +271,7 @@ func (b *Builder) buildPolicyTransportSocket( upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy) - vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst) + vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst) if err != nil { return nil, err } @@ -331,7 +333,7 @@ func (b *Builder) buildPolicyTransportSocket( func (b *Builder) buildPolicyValidationContext( ctx context.Context, - options *config.Options, + cfg *config.Config, policy *config.Policy, dst url.URL, ) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) { @@ -356,7 +358,7 @@ func (b *Builder) buildPolicyValidationContext( } validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs) } else { - bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile) + bs, err := getCombinedCertificateAuthority(cfg) if err != nil { log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found") } else { diff --git a/config/envoyconfig/clusters_test.go b/config/envoyconfig/clusters_test.go index e423380f3..c471e99fb 100644 --- a/config/envoyconfig/clusters_test.go +++ b/config/envoyconfig/clusters_test.go @@ -26,25 +26,25 @@ func Test_buildPolicyTransportSocket(t *testing.T) { customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem") b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) - rootCABytes, _ := getCombinedCertificateAuthority("", "") + rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() o2 := config.NewDefaultOptions() o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}) - combinedCABytes, _ := getCombinedCertificateAuthority(o2.CA, "") + combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}}) combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename() t.Run("insecure", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com"), }, *mustParseURL(t, "http://example.com")) require.NoError(t, err) assert.Nil(t, ts) }) t.Run("host as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), }, *mustParseURL(t, "https://example.com")) require.NoError(t, err) @@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_server_name as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSServerName: "use-this-name.example.com", }, *mustParseURL(t, "https://example.com")) @@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_upstream_server_name as sni", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSUpstreamServerName: "use-this-name.example.com", }, *mustParseURL(t, "https://example.com")) @@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("tls_skip_verify", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSSkipVerify: true, }, *mustParseURL(t, "https://example.com")) @@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("custom ca", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}), }, *mustParseURL(t, "https://example.com")) @@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("options custom ca", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o2}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), }, *mustParseURL(t, "https://example.com")) require.NoError(t, err) @@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { }) t.Run("client certificate", func(t *testing.T) { clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey) - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), ClientCertificate: clientCert, }, *mustParseURL(t, "https://example.com")) @@ -430,7 +430,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) { `, ts) }) t.Run("allow renegotiation", func(t *testing.T) { - ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{ + ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com"), TLSUpstreamAllowRenegotiation: true, }, *mustParseURL(t, "https://example.com")) @@ -489,11 +489,11 @@ func Test_buildPolicyTransportSocket(t *testing.T) { func Test_buildCluster(t *testing.T) { ctx := context.Background() b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) - rootCABytes, _ := getCombinedCertificateAuthority("", "") + rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}}) rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename() o1 := config.NewDefaultOptions() t.Run("insecure", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"), }) require.NoError(t, err) @@ -550,7 +550,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("secure", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "https://example.com", "https://example.com", @@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("ip addresses", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"), }) require.NoError(t, err) @@ -773,7 +773,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("weights", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"), }) require.NoError(t, err) @@ -830,7 +830,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("localhost", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://localhost"), }) require.NoError(t, err) @@ -876,7 +876,7 @@ func Test_buildCluster(t *testing.T) { `, cluster) }) t.Run("outlier", func(t *testing.T) { - endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{ + endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{ To: mustParseWeightedURLs(t, "http://example.com"), }) require.NoError(t, err) @@ -959,7 +959,7 @@ func Test_bindConfig(t *testing.T) { b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil) t.Run("no bind config", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{ + cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) @@ -967,9 +967,9 @@ func Test_bindConfig(t *testing.T) { assert.Nil(t, cluster.UpstreamBindConfig) }) t.Run("freebind", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{ + cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{ EnvoyBindConfigFreebind: null.BoolFrom(true), - }, &config.Policy{ + }}, &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) @@ -985,9 +985,9 @@ func Test_bindConfig(t *testing.T) { `, cluster.UpstreamBindConfig) }) t.Run("source address", func(t *testing.T) { - cluster, err := b.buildPolicyCluster(ctx, &config.Options{ + cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{ EnvoyBindConfigSourceAddress: "192.168.0.1", - }, &config.Policy{ + }}, &config.Policy{ From: "https://from.example.com", To: mustParseWeightedURLs(t, "https://to.example.com"), }) diff --git a/config/envoyconfig/envoyconfig.go b/config/envoyconfig/envoyconfig.go index 0edcfbb08..d92f92006 100644 --- a/config/envoyconfig/envoyconfig.go +++ b/config/envoyconfig/envoyconfig.go @@ -6,7 +6,6 @@ import ( "context" "crypto/tls" "crypto/x509" - "encoding/base64" "encoding/pem" "errors" "fmt" @@ -32,6 +31,7 @@ import ( "google.golang.org/protobuf/types/known/wrapperspb" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/fileutil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -211,36 +211,25 @@ func getRootCertificateAuthority() (string, error) { return rootCABundle.value, nil } -func getCombinedCertificateAuthority(customCA, customCAFile string) ([]byte, error) { +func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) { rootFile, err := getRootCertificateAuthority() if err != nil { return nil, err } - combined, err := os.ReadFile(rootFile) - if err != nil { + var buf bytes.Buffer + if err := fileutil.CopyFileUpTo(&buf, rootFile, 5<<20); err != nil { return nil, fmt.Errorf("error reading root certificates: %w", err) } + buf.WriteRune('\n') - if customCA != "" { - bs, err := base64.StdEncoding.DecodeString(customCA) - if err != nil { - return nil, err - } - combined = append(combined, '\n') - combined = append(combined, bs...) + all, err := cfg.AllCertificateAuthoritiesPEM() + if err != nil { + return nil, fmt.Errorf("get all CA: %w", err) } + buf.Write(all) - if customCAFile != "" { - bs, err := os.ReadFile(customCAFile) - if err != nil { - return nil, err - } - combined = append(combined, '\n') - combined = append(combined, bs...) - } - - return combined, nil + return buf.Bytes(), nil } func marshalAny(msg proto.Message) *anypb.Any { diff --git a/config/layered.go b/config/layered.go new file mode 100644 index 000000000..5cdba509e --- /dev/null +++ b/config/layered.go @@ -0,0 +1,68 @@ +package config + +import ( + "context" + "fmt" + "sync" + + "github.com/pomerium/pomerium/internal/log" +) + +// LayeredSource is an abstraction for a ConfigSource that depends on an underlying config, +// and uses a builder to build the relevant part of the configuration +type LayeredSource struct { + mx sync.Mutex + + cfg *Config + underlying Source + builder func(*Config) error + + ChangeDispatcher +} + +var ( + _ = Source(&LayeredSource{}) +) + +// NewLayeredSource creates a new config source that is watching the underlying source for changes +func NewLayeredSource(ctx context.Context, underlying Source, builder func(*Config) error) (*LayeredSource, error) { + cfg := underlying.GetConfig().Clone() + src := LayeredSource{ + cfg: cfg, + underlying: underlying, + builder: builder, + } + + if err := builder(cfg); err != nil { + return nil, fmt.Errorf("build initial config: %w", err) + } + + underlying.OnConfigChange(ctx, src.onUnderlyingConfigChange) + + return &src, nil +} + +func (src *LayeredSource) onUnderlyingConfigChange(ctx context.Context, next *Config) { + cfg := src.rebuild(ctx, next) + src.Trigger(ctx, cfg) +} + +func (src *LayeredSource) rebuild(ctx context.Context, next *Config) *Config { + src.mx.Lock() + defer src.mx.Unlock() + + cfg := next.Clone() + if err := src.builder(cfg); err != nil { + log.Error(ctx).Err(err).Msg("building config") + cfg = next + } + src.cfg = cfg + return cfg +} + +// GetConfig returns currently stored config +func (src *LayeredSource) GetConfig() *Config { + src.mx.Lock() + defer src.mx.Unlock() + return src.cfg +} diff --git a/config/layered_test.go b/config/layered_test.go new file mode 100644 index 000000000..a8da007c8 --- /dev/null +++ b/config/layered_test.go @@ -0,0 +1,44 @@ +package config_test + +import ( + "context" + "errors" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/config" +) + +func TestLayeredConfig(t *testing.T) { + ctx := context.Background() + + t.Run("error on initial build", func(t *testing.T) { + underlying := config.NewStaticSource(&config.Config{}) + _, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error { + return errors.New("error") + }) + require.Error(t, err) + }) + + t.Run("propagate new config on error", func(t *testing.T) { + underlying := config.NewStaticSource(&config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("a.com")}}) + layered, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error { + if c.Options.GetDeriveInternalDomain() == "b.com" { + return errors.New("reject update") + } + return nil + }) + require.NoError(t, err) + + var dst *config.Config + layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) { + dst = c + }) + + underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}}) + assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain()) + }) +} diff --git a/config/options.go b/config/options.go index 159bf4f09..22500a889 100644 --- a/config/options.go +++ b/config/options.go @@ -162,6 +162,10 @@ type Options struct { CA string `mapstructure:"certificate_authority" yaml:"certificate_authority,omitempty"` CAFile string `mapstructure:"certificate_authority_file" yaml:"certificate_authority_file,omitempty"` + // DeriveInternalDomainCert is an option that would derive certificate authority + // and domain certificates from the shared key and use them for internal communication + DeriveInternalDomainCert *string `mapstructure:"derive_tls" yaml:"derive_tls,omitempty"` + // SigningKey is the private key used to add a JWT-signature to upstream requests. // https://www.pomerium.com/docs/topics/getting-users-identity.html SigningKey string `mapstructure:"signing_key" yaml:"signing_key,omitempty"` @@ -728,6 +732,14 @@ func (o *Options) Validate() error { return nil } +// GetDeriveInternalDomain returns an optional internal domain name to use for gRPC endpoint +func (o *Options) GetDeriveInternalDomain() string { + if o.DeriveInternalDomainCert == nil { + return "" + } + return *o.DeriveInternalDomainCert +} + // GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1. func (o *Options) GetAuthenticateURL() (*url.URL, error) { rawurl := o.AuthenticateURLString diff --git a/internal/fileutil/fileutil.go b/internal/fileutil/fileutil.go index 9e30acfed..6171a2399 100644 --- a/internal/fileutil/fileutil.go +++ b/internal/fileutil/fileutil.go @@ -3,7 +3,10 @@ package fileutil import ( + "bytes" "errors" + "fmt" + "io" "os" ) @@ -46,3 +49,37 @@ func Getwd() string { } return p } + +// ReadFileUpTo reads file up to given size +// it returns an error if file is larger than allowed maximum +func ReadFileUpTo(fname string, maxSize int64) ([]byte, error) { + var buf bytes.Buffer + if err := CopyFileUpTo(&buf, fname, maxSize); err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// CopyFileUpTo copies content of the file up to maxBytes +// it returns an error if file is larger than allowed maximum +func CopyFileUpTo(dst io.Writer, fname string, maxBytes int64) error { + fd, err := os.Open(fname) + if err != nil { + return fmt.Errorf("open %s: %w", fname, err) + } + defer func() { _ = fd.Close() }() + + fi, err := fd.Stat() + if err != nil { + return fmt.Errorf("stat %s: %w", fname, err) + } + if fi.Size() > maxBytes { + return fmt.Errorf("file %s size %d > max %d", fname, fi.Size(), maxBytes) + } + + if _, err := io.Copy(dst, fd); err != nil { + return fmt.Errorf("read %s: %w", fname, err) + } + + return nil +} diff --git a/internal/fileutil/fileutil_test.go b/internal/fileutil/fileutil_test.go index 4361827a1..aebab7c53 100644 --- a/internal/fileutil/fileutil_test.go +++ b/internal/fileutil/fileutil_test.go @@ -1,8 +1,15 @@ package fileutil import ( + "bytes" + "fmt" + "os" + "path" "strings" "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestIsReadableFile(t *testing.T) { @@ -45,3 +52,29 @@ func TestGetwd(t *testing.T) { }) } } + +func TestReadFileUpTo(t *testing.T) { + d := t.TempDir() + input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9} + fname := path.Join(d, "test") + require.NoError(t, os.WriteFile(fname, input, 0600)) + + for _, tc := range []struct { + size int + expectError bool + }{ + {len(input) - 1, true}, + {len(input), false}, + {len(input) + 1, false}, + } { + t.Run(fmt.Sprint(tc), func(t *testing.T) { + out, err := ReadFileUpTo(fname, int64(tc.size)) + if tc.expectError { + require.Error(t, err) + return + } + require.NoError(t, err) + assert.True(t, bytes.Equal(input, out)) + }) + } +} diff --git a/pkg/cmd/pomerium/pomerium.go b/pkg/cmd/pomerium/pomerium.go index 50712d025..c732463c3 100644 --- a/pkg/cmd/pomerium/pomerium.go +++ b/pkg/cmd/pomerium/pomerium.go @@ -23,6 +23,7 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/registry" "github.com/pomerium/pomerium/internal/version" + derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config" "github.com/pomerium/pomerium/pkg/envoy" "github.com/pomerium/pomerium/pkg/envoy/files" "github.com/pomerium/pomerium/proxy" @@ -35,6 +36,10 @@ func Run(ctx context.Context, src config.Source) error { Str("version", version.FullVersion()). Msg("cmd/pomerium") + src, err := config.NewLayeredSource(ctx, src, derivecert_config.NewBuilder()) + if err != nil { + return err + } src = databroker.NewConfigSource(ctx, src) logMgr := config.NewLogManager(ctx, src) defer logMgr.Close() @@ -42,7 +47,7 @@ func Run(ctx context.Context, src config.Source) error { // trigger changes when underlying files are changed src = config.NewFileWatcherSource(src) - src, err := autocert.New(src) + src, err = autocert.New(src) if err != nil { return err } @@ -57,8 +62,10 @@ func Run(ctx context.Context, src config.Source) error { eventsMgr := events.New() + cfg := src.GetConfig() + // setup the control plane - controlPlane, err := controlplane.NewServer(src.GetConfig(), metricsMgr, eventsMgr) + controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr) if err != nil { return fmt.Errorf("error creating control plane: %w", err) } diff --git a/pkg/derivecert/config/builder.go b/pkg/derivecert/config/builder.go new file mode 100644 index 000000000..3a43ddf0d --- /dev/null +++ b/pkg/derivecert/config/builder.go @@ -0,0 +1,91 @@ +// Package config implements derived certs in the Pomerium Configuration +package config + +import ( + "bytes" + "crypto/tls" + "fmt" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/derivecert" +) + +type builder struct { + psk []byte + ca *derivecert.CA + caCertPEM []byte + + domain string + certs []tls.Certificate +} + +// NewBuilder returns a new derived certs config builder with caching +func NewBuilder() func(*config.Config) error { + return new(builder).Build +} + +func (x *builder) Build(cfg *config.Config) error { + if cfg.Options.DeriveInternalDomainCert == nil { + return nil + } + + psk, err := cfg.Options.GetSharedKey() + if err != nil { + return fmt.Errorf("shared key: %w", err) + } + + if err = x.buildCA(psk); err != nil { + return err + } + + if err = x.buildCert(*cfg.Options.DeriveInternalDomainCert); err != nil { + return err + } + + cfg.DerivedCAPEM = x.caCertPEM + cfg.DerivedCertificates = x.certs + return nil +} + +func (x *builder) buildCA(psk []byte) error { + if bytes.Equal(x.psk, psk) { + return nil + } + + ca, err := derivecert.NewCA(psk) + if err != nil { + return fmt.Errorf("building certificate authority from shared key: %w", err) + } + + pem, err := ca.PEM() + if err != nil { + return fmt.Errorf("encode derived CA to PEM: %w", err) + } + + x.psk = psk + x.ca = ca + x.caCertPEM = pem.Cert + + return nil +} + +func (x *builder) buildCert(domain string) error { + if x.domain == domain { + return nil + } + + certPEM, err := x.ca.NewServerCert([]string{domain}) + if err != nil { + return fmt.Errorf("generate cert: %w", err) + } + + cert, err := certPEM.TLS() + if err != nil { + return fmt.Errorf("parse TLS cert: %w", err) + } + + x.domain = domain + x.certs = []tls.Certificate{cert} + + return nil +} diff --git a/pkg/derivecert/config/builder_test.go b/pkg/derivecert/config/builder_test.go new file mode 100644 index 000000000..8a07a4e3e --- /dev/null +++ b/pkg/derivecert/config/builder_test.go @@ -0,0 +1,57 @@ +package config_test + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/cryptutil" + dcfg "github.com/pomerium/pomerium/pkg/derivecert/config" +) + +func TestBuild(t *testing.T) { + build := dcfg.NewBuilder() + + key := cryptutil.NewBase64Key() + + cfgA := config.Config{Options: &config.Options{SharedKey: key}} + t.Run("no domain requested", func(t *testing.T) { + require.NoError(t, build(&cfgA)) + assert.Empty(t, cfgA.DerivedCAPEM) + assert.Empty(t, cfgA.DerivedCertificates) + }) + + cfgA.Options.DeriveInternalDomainCert = proto.String("example.com") + t.Run("generate server cert", func(t *testing.T) { + require.NoError(t, build(&cfgA)) + assert.NotEmpty(t, cfgA.DerivedCAPEM) + assert.Len(t, cfgA.DerivedCertificates, 1) + }) + + cfgB := config.Config{Options: &config.Options{ + SharedKey: key, + DeriveInternalDomainCert: proto.String("example.com"), + }} + t.Run("caching", func(t *testing.T) { + require.NoError(t, build(&cfgB)) + assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM) + assert.Equal(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate) + }) + + t.Run("no domain requested after run", func(t *testing.T) { + cfg := config.Config{Options: &config.Options{SharedKey: key}} + require.NoError(t, build(&cfg)) + assert.Empty(t, cfg.DerivedCAPEM) + assert.Empty(t, cfg.DerivedCertificates) + }) + + cfgB.Options.DeriveInternalDomainCert = proto.String("example2.com") + t.Run("ca caching", func(t *testing.T) { + require.NoError(t, build(&cfgB)) + assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM) + assert.NotEqual(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate) + }) +}