mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 18:06:34 +02:00
auto tls (#3856)
This commit is contained in:
parent
78fc4853db
commit
488bcd6f72
12 changed files with 447 additions and 67 deletions
|
@ -1,8 +1,11 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"encoding/base64"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/hashutil"
|
||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
|
@ -17,6 +20,12 @@ type Config struct {
|
|||
AutoCertificates []tls.Certificate
|
||||
EnvoyVersion string
|
||||
|
||||
// DerivedCertificates are TLS certificates derived from the shared secret
|
||||
DerivedCertificates []tls.Certificate
|
||||
// DerivedCAPEM is a PEM-encoded certificate authority
|
||||
// derived from the shared secret
|
||||
DerivedCAPEM []byte
|
||||
|
||||
// GRPCPort is the port the gRPC server is running on.
|
||||
GRPCPort string
|
||||
// HTTPPort is the port the HTTP server is running on.
|
||||
|
@ -57,9 +66,39 @@ func (cfg *Config) Clone() *Config {
|
|||
ACMETLSALPNPort: cfg.ACMETLSALPNPort,
|
||||
|
||||
MetricsScrapeEndpoints: endpoints,
|
||||
|
||||
DerivedCertificates: cfg.DerivedCertificates,
|
||||
DerivedCAPEM: cfg.DerivedCAPEM,
|
||||
}
|
||||
}
|
||||
|
||||
// AllCertificateAuthoritiesPEM returns all CAs as PEM bundle bytes
|
||||
func (cfg *Config) AllCertificateAuthoritiesPEM() ([]byte, error) {
|
||||
var combined bytes.Buffer
|
||||
if cfg.Options.CA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(cfg.Options.CA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _ = combined.Write(bs)
|
||||
_, _ = combined.WriteRune('\n')
|
||||
}
|
||||
|
||||
if cfg.Options.CAFile != "" {
|
||||
if err := fileutil.CopyFileUpTo(&combined, cfg.Options.CAFile, 5<<20); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, _ = combined.WriteRune('\n')
|
||||
}
|
||||
|
||||
if cfg.DerivedCAPEM != nil {
|
||||
_, _ = combined.Write(cfg.DerivedCAPEM)
|
||||
_, _ = combined.WriteRune('\n')
|
||||
}
|
||||
|
||||
return combined.Bytes(), nil
|
||||
}
|
||||
|
||||
// AllCertificates returns all the certificates in the config.
|
||||
func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
||||
optionCertificates, err := cfg.Options.GetCertificates()
|
||||
|
@ -70,6 +109,7 @@ func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
|||
var certs []tls.Certificate
|
||||
certs = append(certs, optionCertificates...)
|
||||
certs = append(certs, cfg.AutoCertificates...)
|
||||
certs = append(certs, cfg.DerivedCertificates...)
|
||||
return certs, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -51,22 +51,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
}
|
||||
}
|
||||
|
||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
|
||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
||||
}
|
||||
|
||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -113,7 +113,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
||||
}
|
||||
if len(policy.To) > 0 {
|
||||
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
|
||||
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
||||
}
|
||||
|
@ -131,16 +131,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
|||
|
||||
func (b *Builder) buildInternalCluster(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
name string,
|
||||
dsts []*url.URL,
|
||||
upstreamProtocol upstreamProtocolConfig,
|
||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := newDefaultEnvoyClusterConfig()
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range dsts {
|
||||
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
|
||||
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -153,10 +153,12 @@ func (b *Builder) buildInternalCluster(
|
|||
return cluster, nil
|
||||
}
|
||||
|
||||
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||
cluster := new(envoy_config_cluster_v3.Cluster)
|
||||
proto.Merge(cluster, policy.EnvoyOpts)
|
||||
|
||||
options := cfg.Options
|
||||
|
||||
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
|
||||
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
||||
if options.EnvoyBindConfigFreebind.IsSet() {
|
||||
|
@ -183,7 +185,7 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
|||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||
|
||||
name := getClusterID(policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -205,12 +207,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
|||
|
||||
func (b *Builder) buildPolicyEndpoints(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
) ([]Endpoint, error) {
|
||||
var endpoints []Endpoint
|
||||
for _, dst := range policy.To {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -221,7 +223,7 @@ func (b *Builder) buildPolicyEndpoints(
|
|||
|
||||
func (b *Builder) buildInternalTransportSocket(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
endpoint *url.URL,
|
||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
if endpoint.Scheme != "https" {
|
||||
|
@ -230,10 +232,10 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
|
||||
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
||||
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
||||
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
|
||||
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
||||
},
|
||||
}
|
||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
||||
bs, err := getCombinedCertificateAuthority(cfg)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
|
@ -246,7 +248,7 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
ValidationContext: validationContext,
|
||||
},
|
||||
},
|
||||
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
|
||||
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
|
||||
}
|
||||
tlsConfig := marshalAny(tlsContext)
|
||||
return &envoy_config_core_v3.TransportSocket{
|
||||
|
@ -259,7 +261,7 @@ func (b *Builder) buildInternalTransportSocket(
|
|||
|
||||
func (b *Builder) buildPolicyTransportSocket(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
dst url.URL,
|
||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||
|
@ -269,7 +271,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
|||
|
||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||
|
||||
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
|
||||
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -331,7 +333,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
|||
|
||||
func (b *Builder) buildPolicyValidationContext(
|
||||
ctx context.Context,
|
||||
options *config.Options,
|
||||
cfg *config.Config,
|
||||
policy *config.Policy,
|
||||
dst url.URL,
|
||||
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
||||
|
@ -356,7 +358,7 @@ func (b *Builder) buildPolicyValidationContext(
|
|||
}
|
||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||
} else {
|
||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
||||
bs, err := getCombinedCertificateAuthority(cfg)
|
||||
if err != nil {
|
||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||
} else {
|
||||
|
|
|
@ -26,25 +26,25 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
|
||||
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
rootCABytes, _ := getCombinedCertificateAuthority("", "")
|
||||
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||
|
||||
o1 := config.NewDefaultOptions()
|
||||
o2 := config.NewDefaultOptions()
|
||||
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
|
||||
|
||||
combinedCABytes, _ := getCombinedCertificateAuthority(o2.CA, "")
|
||||
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}})
|
||||
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
||||
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
}, *mustParseURL(t, "http://example.com"))
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, ts)
|
||||
})
|
||||
t.Run("host as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
require.NoError(t, err)
|
||||
|
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_server_name as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSServerName: "use-this-name.example.com",
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSUpstreamServerName: "use-this-name.example.com",
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("tls_skip_verify", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSSkipVerify: true,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("custom ca", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("options custom ca", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o2}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
require.NoError(t, err)
|
||||
|
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
})
|
||||
t.Run("client certificate", func(t *testing.T) {
|
||||
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
ClientCertificate: clientCert,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -430,7 +430,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
`, ts)
|
||||
})
|
||||
t.Run("allow renegotiation", func(t *testing.T) {
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
||||
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||
TLSUpstreamAllowRenegotiation: true,
|
||||
}, *mustParseURL(t, "https://example.com"))
|
||||
|
@ -489,11 +489,11 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
|||
func Test_buildCluster(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
rootCABytes, _ := getCombinedCertificateAuthority("", "")
|
||||
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||
o1 := config.NewDefaultOptions()
|
||||
t.Run("insecure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -550,7 +550,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("secure", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t,
|
||||
"https://example.com",
|
||||
"https://example.com",
|
||||
|
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("ip addresses", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -773,7 +773,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("weights", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -830,7 +830,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("localhost", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://localhost"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -876,7 +876,7 @@ func Test_buildCluster(t *testing.T) {
|
|||
`, cluster)
|
||||
})
|
||||
t.Run("outlier", func(t *testing.T) {
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
||||
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
@ -959,7 +959,7 @@ func Test_bindConfig(t *testing.T) {
|
|||
|
||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||
t.Run("no bind config", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
@ -967,9 +967,9 @@ func Test_bindConfig(t *testing.T) {
|
|||
assert.Nil(t, cluster.UpstreamBindConfig)
|
||||
})
|
||||
t.Run("freebind", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
|
||||
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
||||
}, &config.Policy{
|
||||
}}, &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
@ -985,9 +985,9 @@ func Test_bindConfig(t *testing.T) {
|
|||
`, cluster.UpstreamBindConfig)
|
||||
})
|
||||
t.Run("source address", func(t *testing.T) {
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
||||
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
|
||||
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
||||
}, &config.Policy{
|
||||
}}, &config.Policy{
|
||||
From: "https://from.example.com",
|
||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||
})
|
||||
|
|
|
@ -6,7 +6,6 @@ import (
|
|||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
@ -32,6 +31,7 @@ import (
|
|||
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/fileutil"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
)
|
||||
|
@ -211,36 +211,25 @@ func getRootCertificateAuthority() (string, error) {
|
|||
return rootCABundle.value, nil
|
||||
}
|
||||
|
||||
func getCombinedCertificateAuthority(customCA, customCAFile string) ([]byte, error) {
|
||||
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) {
|
||||
rootFile, err := getRootCertificateAuthority()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
combined, err := os.ReadFile(rootFile)
|
||||
if err != nil {
|
||||
var buf bytes.Buffer
|
||||
if err := fileutil.CopyFileUpTo(&buf, rootFile, 5<<20); err != nil {
|
||||
return nil, fmt.Errorf("error reading root certificates: %w", err)
|
||||
}
|
||||
buf.WriteRune('\n')
|
||||
|
||||
if customCA != "" {
|
||||
bs, err := base64.StdEncoding.DecodeString(customCA)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
combined = append(combined, '\n')
|
||||
combined = append(combined, bs...)
|
||||
all, err := cfg.AllCertificateAuthoritiesPEM()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get all CA: %w", err)
|
||||
}
|
||||
buf.Write(all)
|
||||
|
||||
if customCAFile != "" {
|
||||
bs, err := os.ReadFile(customCAFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
combined = append(combined, '\n')
|
||||
combined = append(combined, bs...)
|
||||
}
|
||||
|
||||
return combined, nil
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
func marshalAny(msg proto.Message) *anypb.Any {
|
||||
|
|
68
config/layered.go
Normal file
68
config/layered.go
Normal file
|
@ -0,0 +1,68 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
)
|
||||
|
||||
// LayeredSource is an abstraction for a ConfigSource that depends on an underlying config,
|
||||
// and uses a builder to build the relevant part of the configuration
|
||||
type LayeredSource struct {
|
||||
mx sync.Mutex
|
||||
|
||||
cfg *Config
|
||||
underlying Source
|
||||
builder func(*Config) error
|
||||
|
||||
ChangeDispatcher
|
||||
}
|
||||
|
||||
var (
|
||||
_ = Source(&LayeredSource{})
|
||||
)
|
||||
|
||||
// NewLayeredSource creates a new config source that is watching the underlying source for changes
|
||||
func NewLayeredSource(ctx context.Context, underlying Source, builder func(*Config) error) (*LayeredSource, error) {
|
||||
cfg := underlying.GetConfig().Clone()
|
||||
src := LayeredSource{
|
||||
cfg: cfg,
|
||||
underlying: underlying,
|
||||
builder: builder,
|
||||
}
|
||||
|
||||
if err := builder(cfg); err != nil {
|
||||
return nil, fmt.Errorf("build initial config: %w", err)
|
||||
}
|
||||
|
||||
underlying.OnConfigChange(ctx, src.onUnderlyingConfigChange)
|
||||
|
||||
return &src, nil
|
||||
}
|
||||
|
||||
func (src *LayeredSource) onUnderlyingConfigChange(ctx context.Context, next *Config) {
|
||||
cfg := src.rebuild(ctx, next)
|
||||
src.Trigger(ctx, cfg)
|
||||
}
|
||||
|
||||
func (src *LayeredSource) rebuild(ctx context.Context, next *Config) *Config {
|
||||
src.mx.Lock()
|
||||
defer src.mx.Unlock()
|
||||
|
||||
cfg := next.Clone()
|
||||
if err := src.builder(cfg); err != nil {
|
||||
log.Error(ctx).Err(err).Msg("building config")
|
||||
cfg = next
|
||||
}
|
||||
src.cfg = cfg
|
||||
return cfg
|
||||
}
|
||||
|
||||
// GetConfig returns currently stored config
|
||||
func (src *LayeredSource) GetConfig() *Config {
|
||||
src.mx.Lock()
|
||||
defer src.mx.Unlock()
|
||||
return src.cfg
|
||||
}
|
44
config/layered_test.go
Normal file
44
config/layered_test.go
Normal file
|
@ -0,0 +1,44 @@
|
|||
package config_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
)
|
||||
|
||||
func TestLayeredConfig(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("error on initial build", func(t *testing.T) {
|
||||
underlying := config.NewStaticSource(&config.Config{})
|
||||
_, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
|
||||
return errors.New("error")
|
||||
})
|
||||
require.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("propagate new config on error", func(t *testing.T) {
|
||||
underlying := config.NewStaticSource(&config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("a.com")}})
|
||||
layered, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
|
||||
if c.Options.GetDeriveInternalDomain() == "b.com" {
|
||||
return errors.New("reject update")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
var dst *config.Config
|
||||
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
|
||||
dst = c
|
||||
})
|
||||
|
||||
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
|
||||
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
|
||||
})
|
||||
}
|
|
@ -162,6 +162,10 @@ type Options struct {
|
|||
CA string `mapstructure:"certificate_authority" yaml:"certificate_authority,omitempty"`
|
||||
CAFile string `mapstructure:"certificate_authority_file" yaml:"certificate_authority_file,omitempty"`
|
||||
|
||||
// DeriveInternalDomainCert is an option that would derive certificate authority
|
||||
// and domain certificates from the shared key and use them for internal communication
|
||||
DeriveInternalDomainCert *string `mapstructure:"derive_tls" yaml:"derive_tls,omitempty"`
|
||||
|
||||
// SigningKey is the private key used to add a JWT-signature to upstream requests.
|
||||
// https://www.pomerium.com/docs/topics/getting-users-identity.html
|
||||
SigningKey string `mapstructure:"signing_key" yaml:"signing_key,omitempty"`
|
||||
|
@ -728,6 +732,14 @@ func (o *Options) Validate() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// GetDeriveInternalDomain returns an optional internal domain name to use for gRPC endpoint
|
||||
func (o *Options) GetDeriveInternalDomain() string {
|
||||
if o.DeriveInternalDomainCert == nil {
|
||||
return ""
|
||||
}
|
||||
return *o.DeriveInternalDomainCert
|
||||
}
|
||||
|
||||
// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
|
||||
func (o *Options) GetAuthenticateURL() (*url.URL, error) {
|
||||
rawurl := o.AuthenticateURLString
|
||||
|
|
|
@ -3,7 +3,10 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
)
|
||||
|
||||
|
@ -46,3 +49,37 @@ func Getwd() string {
|
|||
}
|
||||
return p
|
||||
}
|
||||
|
||||
// ReadFileUpTo reads file up to given size
|
||||
// it returns an error if file is larger than allowed maximum
|
||||
func ReadFileUpTo(fname string, maxSize int64) ([]byte, error) {
|
||||
var buf bytes.Buffer
|
||||
if err := CopyFileUpTo(&buf, fname, maxSize); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf.Bytes(), nil
|
||||
}
|
||||
|
||||
// CopyFileUpTo copies content of the file up to maxBytes
|
||||
// it returns an error if file is larger than allowed maximum
|
||||
func CopyFileUpTo(dst io.Writer, fname string, maxBytes int64) error {
|
||||
fd, err := os.Open(fname)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open %s: %w", fname, err)
|
||||
}
|
||||
defer func() { _ = fd.Close() }()
|
||||
|
||||
fi, err := fd.Stat()
|
||||
if err != nil {
|
||||
return fmt.Errorf("stat %s: %w", fname, err)
|
||||
}
|
||||
if fi.Size() > maxBytes {
|
||||
return fmt.Errorf("file %s size %d > max %d", fname, fi.Size(), maxBytes)
|
||||
}
|
||||
|
||||
if _, err := io.Copy(dst, fd); err != nil {
|
||||
return fmt.Errorf("read %s: %w", fname, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
@ -1,8 +1,15 @@
|
|||
package fileutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestIsReadableFile(t *testing.T) {
|
||||
|
@ -45,3 +52,29 @@ func TestGetwd(t *testing.T) {
|
|||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestReadFileUpTo(t *testing.T) {
|
||||
d := t.TempDir()
|
||||
input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||
fname := path.Join(d, "test")
|
||||
require.NoError(t, os.WriteFile(fname, input, 0600))
|
||||
|
||||
for _, tc := range []struct {
|
||||
size int
|
||||
expectError bool
|
||||
}{
|
||||
{len(input) - 1, true},
|
||||
{len(input), false},
|
||||
{len(input) + 1, false},
|
||||
} {
|
||||
t.Run(fmt.Sprint(tc), func(t *testing.T) {
|
||||
out, err := ReadFileUpTo(fname, int64(tc.size))
|
||||
if tc.expectError {
|
||||
require.Error(t, err)
|
||||
return
|
||||
}
|
||||
require.NoError(t, err)
|
||||
assert.True(t, bytes.Equal(input, out))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
@ -23,6 +23,7 @@ import (
|
|||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/registry"
|
||||
"github.com/pomerium/pomerium/internal/version"
|
||||
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
|
||||
"github.com/pomerium/pomerium/pkg/envoy"
|
||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||
"github.com/pomerium/pomerium/proxy"
|
||||
|
@ -35,6 +36,10 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
Str("version", version.FullVersion()).
|
||||
Msg("cmd/pomerium")
|
||||
|
||||
src, err := config.NewLayeredSource(ctx, src, derivecert_config.NewBuilder())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
src = databroker.NewConfigSource(ctx, src)
|
||||
logMgr := config.NewLogManager(ctx, src)
|
||||
defer logMgr.Close()
|
||||
|
@ -42,7 +47,7 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
// trigger changes when underlying files are changed
|
||||
src = config.NewFileWatcherSource(src)
|
||||
|
||||
src, err := autocert.New(src)
|
||||
src, err = autocert.New(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -57,8 +62,10 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
|
||||
eventsMgr := events.New()
|
||||
|
||||
cfg := src.GetConfig()
|
||||
|
||||
// setup the control plane
|
||||
controlPlane, err := controlplane.NewServer(src.GetConfig(), metricsMgr, eventsMgr)
|
||||
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating control plane: %w", err)
|
||||
}
|
||||
|
|
91
pkg/derivecert/config/builder.go
Normal file
91
pkg/derivecert/config/builder.go
Normal file
|
@ -0,0 +1,91 @@
|
|||
// Package config implements derived certs in the Pomerium Configuration
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/derivecert"
|
||||
)
|
||||
|
||||
type builder struct {
|
||||
psk []byte
|
||||
ca *derivecert.CA
|
||||
caCertPEM []byte
|
||||
|
||||
domain string
|
||||
certs []tls.Certificate
|
||||
}
|
||||
|
||||
// NewBuilder returns a new derived certs config builder with caching
|
||||
func NewBuilder() func(*config.Config) error {
|
||||
return new(builder).Build
|
||||
}
|
||||
|
||||
func (x *builder) Build(cfg *config.Config) error {
|
||||
if cfg.Options.DeriveInternalDomainCert == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
psk, err := cfg.Options.GetSharedKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("shared key: %w", err)
|
||||
}
|
||||
|
||||
if err = x.buildCA(psk); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = x.buildCert(*cfg.Options.DeriveInternalDomainCert); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cfg.DerivedCAPEM = x.caCertPEM
|
||||
cfg.DerivedCertificates = x.certs
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *builder) buildCA(psk []byte) error {
|
||||
if bytes.Equal(x.psk, psk) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ca, err := derivecert.NewCA(psk)
|
||||
if err != nil {
|
||||
return fmt.Errorf("building certificate authority from shared key: %w", err)
|
||||
}
|
||||
|
||||
pem, err := ca.PEM()
|
||||
if err != nil {
|
||||
return fmt.Errorf("encode derived CA to PEM: %w", err)
|
||||
}
|
||||
|
||||
x.psk = psk
|
||||
x.ca = ca
|
||||
x.caCertPEM = pem.Cert
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (x *builder) buildCert(domain string) error {
|
||||
if x.domain == domain {
|
||||
return nil
|
||||
}
|
||||
|
||||
certPEM, err := x.ca.NewServerCert([]string{domain})
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate cert: %w", err)
|
||||
}
|
||||
|
||||
cert, err := certPEM.TLS()
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse TLS cert: %w", err)
|
||||
}
|
||||
|
||||
x.domain = domain
|
||||
x.certs = []tls.Certificate{cert}
|
||||
|
||||
return nil
|
||||
}
|
57
pkg/derivecert/config/builder_test.go
Normal file
57
pkg/derivecert/config/builder_test.go
Normal file
|
@ -0,0 +1,57 @@
|
|||
package config_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||
dcfg "github.com/pomerium/pomerium/pkg/derivecert/config"
|
||||
)
|
||||
|
||||
func TestBuild(t *testing.T) {
|
||||
build := dcfg.NewBuilder()
|
||||
|
||||
key := cryptutil.NewBase64Key()
|
||||
|
||||
cfgA := config.Config{Options: &config.Options{SharedKey: key}}
|
||||
t.Run("no domain requested", func(t *testing.T) {
|
||||
require.NoError(t, build(&cfgA))
|
||||
assert.Empty(t, cfgA.DerivedCAPEM)
|
||||
assert.Empty(t, cfgA.DerivedCertificates)
|
||||
})
|
||||
|
||||
cfgA.Options.DeriveInternalDomainCert = proto.String("example.com")
|
||||
t.Run("generate server cert", func(t *testing.T) {
|
||||
require.NoError(t, build(&cfgA))
|
||||
assert.NotEmpty(t, cfgA.DerivedCAPEM)
|
||||
assert.Len(t, cfgA.DerivedCertificates, 1)
|
||||
})
|
||||
|
||||
cfgB := config.Config{Options: &config.Options{
|
||||
SharedKey: key,
|
||||
DeriveInternalDomainCert: proto.String("example.com"),
|
||||
}}
|
||||
t.Run("caching", func(t *testing.T) {
|
||||
require.NoError(t, build(&cfgB))
|
||||
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
|
||||
assert.Equal(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
|
||||
})
|
||||
|
||||
t.Run("no domain requested after run", func(t *testing.T) {
|
||||
cfg := config.Config{Options: &config.Options{SharedKey: key}}
|
||||
require.NoError(t, build(&cfg))
|
||||
assert.Empty(t, cfg.DerivedCAPEM)
|
||||
assert.Empty(t, cfg.DerivedCertificates)
|
||||
})
|
||||
|
||||
cfgB.Options.DeriveInternalDomainCert = proto.String("example2.com")
|
||||
t.Run("ca caching", func(t *testing.T) {
|
||||
require.NoError(t, build(&cfgB))
|
||||
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
|
||||
assert.NotEqual(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
|
||||
})
|
||||
}
|
Loading…
Add table
Reference in a new issue