mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-29 10:26:29 +02:00
auto tls (#3856)
This commit is contained in:
parent
78fc4853db
commit
488bcd6f72
12 changed files with 447 additions and 67 deletions
|
@ -1,8 +1,11 @@
|
||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"encoding/base64"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/fileutil"
|
||||||
"github.com/pomerium/pomerium/internal/hashutil"
|
"github.com/pomerium/pomerium/internal/hashutil"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
"github.com/pomerium/pomerium/internal/telemetry/metrics"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
@ -17,6 +20,12 @@ type Config struct {
|
||||||
AutoCertificates []tls.Certificate
|
AutoCertificates []tls.Certificate
|
||||||
EnvoyVersion string
|
EnvoyVersion string
|
||||||
|
|
||||||
|
// DerivedCertificates are TLS certificates derived from the shared secret
|
||||||
|
DerivedCertificates []tls.Certificate
|
||||||
|
// DerivedCAPEM is a PEM-encoded certificate authority
|
||||||
|
// derived from the shared secret
|
||||||
|
DerivedCAPEM []byte
|
||||||
|
|
||||||
// GRPCPort is the port the gRPC server is running on.
|
// GRPCPort is the port the gRPC server is running on.
|
||||||
GRPCPort string
|
GRPCPort string
|
||||||
// HTTPPort is the port the HTTP server is running on.
|
// HTTPPort is the port the HTTP server is running on.
|
||||||
|
@ -57,9 +66,39 @@ func (cfg *Config) Clone() *Config {
|
||||||
ACMETLSALPNPort: cfg.ACMETLSALPNPort,
|
ACMETLSALPNPort: cfg.ACMETLSALPNPort,
|
||||||
|
|
||||||
MetricsScrapeEndpoints: endpoints,
|
MetricsScrapeEndpoints: endpoints,
|
||||||
|
|
||||||
|
DerivedCertificates: cfg.DerivedCertificates,
|
||||||
|
DerivedCAPEM: cfg.DerivedCAPEM,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AllCertificateAuthoritiesPEM returns all CAs as PEM bundle bytes
|
||||||
|
func (cfg *Config) AllCertificateAuthoritiesPEM() ([]byte, error) {
|
||||||
|
var combined bytes.Buffer
|
||||||
|
if cfg.Options.CA != "" {
|
||||||
|
bs, err := base64.StdEncoding.DecodeString(cfg.Options.CA)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, _ = combined.Write(bs)
|
||||||
|
_, _ = combined.WriteRune('\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Options.CAFile != "" {
|
||||||
|
if err := fileutil.CopyFileUpTo(&combined, cfg.Options.CAFile, 5<<20); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
_, _ = combined.WriteRune('\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.DerivedCAPEM != nil {
|
||||||
|
_, _ = combined.Write(cfg.DerivedCAPEM)
|
||||||
|
_, _ = combined.WriteRune('\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
return combined.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
// AllCertificates returns all the certificates in the config.
|
// AllCertificates returns all the certificates in the config.
|
||||||
func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
||||||
optionCertificates, err := cfg.Options.GetCertificates()
|
optionCertificates, err := cfg.Options.GetCertificates()
|
||||||
|
@ -70,6 +109,7 @@ func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
|
||||||
var certs []tls.Certificate
|
var certs []tls.Certificate
|
||||||
certs = append(certs, optionCertificates...)
|
certs = append(certs, optionCertificates...)
|
||||||
certs = append(certs, cfg.AutoCertificates...)
|
certs = append(certs, cfg.AutoCertificates...)
|
||||||
|
certs = append(certs, cfg.DerivedCertificates...)
|
||||||
return certs, nil
|
return certs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -51,22 +51,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
|
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -75,7 +75,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
authorizeCluster.OutlierDetection = grpcOutlierDetection()
|
||||||
}
|
}
|
||||||
|
|
||||||
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -113,7 +113,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
|
||||||
}
|
}
|
||||||
if len(policy.To) > 0 {
|
if len(policy.To) > 0 {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
|
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
return nil, fmt.Errorf("policy #%d: %w", i, err)
|
||||||
}
|
}
|
||||||
|
@ -131,16 +131,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
|
||||||
|
|
||||||
func (b *Builder) buildInternalCluster(
|
func (b *Builder) buildInternalCluster(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
name string,
|
name string,
|
||||||
dsts []*url.URL,
|
dsts []*url.URL,
|
||||||
upstreamProtocol upstreamProtocolConfig,
|
upstreamProtocol upstreamProtocolConfig,
|
||||||
) (*envoy_config_cluster_v3.Cluster, error) {
|
) (*envoy_config_cluster_v3.Cluster, error) {
|
||||||
cluster := newDefaultEnvoyClusterConfig()
|
cluster := newDefaultEnvoyClusterConfig()
|
||||||
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
|
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
|
||||||
var endpoints []Endpoint
|
var endpoints []Endpoint
|
||||||
for _, dst := range dsts {
|
for _, dst := range dsts {
|
||||||
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
|
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -153,10 +153,12 @@ func (b *Builder) buildInternalCluster(
|
||||||
return cluster, nil
|
return cluster, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
|
||||||
cluster := new(envoy_config_cluster_v3.Cluster)
|
cluster := new(envoy_config_cluster_v3.Cluster)
|
||||||
proto.Merge(cluster, policy.EnvoyOpts)
|
proto.Merge(cluster, policy.EnvoyOpts)
|
||||||
|
|
||||||
|
options := cfg.Options
|
||||||
|
|
||||||
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
|
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
|
||||||
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
|
||||||
if options.EnvoyBindConfigFreebind.IsSet() {
|
if options.EnvoyBindConfigFreebind.IsSet() {
|
||||||
|
@ -183,7 +185,7 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
||||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||||
|
|
||||||
name := getClusterID(policy)
|
name := getClusterID(policy)
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
|
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -205,12 +207,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
|
||||||
|
|
||||||
func (b *Builder) buildPolicyEndpoints(
|
func (b *Builder) buildPolicyEndpoints(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
) ([]Endpoint, error) {
|
) ([]Endpoint, error) {
|
||||||
var endpoints []Endpoint
|
var endpoints []Endpoint
|
||||||
for _, dst := range policy.To {
|
for _, dst := range policy.To {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
|
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -221,7 +223,7 @@ func (b *Builder) buildPolicyEndpoints(
|
||||||
|
|
||||||
func (b *Builder) buildInternalTransportSocket(
|
func (b *Builder) buildInternalTransportSocket(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
endpoint *url.URL,
|
endpoint *url.URL,
|
||||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||||
if endpoint.Scheme != "https" {
|
if endpoint.Scheme != "https" {
|
||||||
|
@ -230,10 +232,10 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
|
|
||||||
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
|
||||||
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
|
||||||
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
|
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
bs, err := getCombinedCertificateAuthority(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
@ -246,7 +248,7 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
ValidationContext: validationContext,
|
ValidationContext: validationContext,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
|
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
|
||||||
}
|
}
|
||||||
tlsConfig := marshalAny(tlsContext)
|
tlsConfig := marshalAny(tlsContext)
|
||||||
return &envoy_config_core_v3.TransportSocket{
|
return &envoy_config_core_v3.TransportSocket{
|
||||||
|
@ -259,7 +261,7 @@ func (b *Builder) buildInternalTransportSocket(
|
||||||
|
|
||||||
func (b *Builder) buildPolicyTransportSocket(
|
func (b *Builder) buildPolicyTransportSocket(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
dst url.URL,
|
dst url.URL,
|
||||||
) (*envoy_config_core_v3.TransportSocket, error) {
|
) (*envoy_config_core_v3.TransportSocket, error) {
|
||||||
|
@ -269,7 +271,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
||||||
|
|
||||||
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
|
||||||
|
|
||||||
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
|
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -331,7 +333,7 @@ func (b *Builder) buildPolicyTransportSocket(
|
||||||
|
|
||||||
func (b *Builder) buildPolicyValidationContext(
|
func (b *Builder) buildPolicyValidationContext(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
options *config.Options,
|
cfg *config.Config,
|
||||||
policy *config.Policy,
|
policy *config.Policy,
|
||||||
dst url.URL,
|
dst url.URL,
|
||||||
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
|
||||||
|
@ -356,7 +358,7 @@ func (b *Builder) buildPolicyValidationContext(
|
||||||
}
|
}
|
||||||
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
|
||||||
} else {
|
} else {
|
||||||
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
|
bs, err := getCombinedCertificateAuthority(cfg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -26,25 +26,25 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
|
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority("", "")
|
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
|
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
o2 := config.NewDefaultOptions()
|
o2 := config.NewDefaultOptions()
|
||||||
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
|
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
|
||||||
|
|
||||||
combinedCABytes, _ := getCombinedCertificateAuthority(o2.CA, "")
|
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}})
|
||||||
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
|
||||||
|
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||||
}, *mustParseURL(t, "http://example.com"))
|
}, *mustParseURL(t, "http://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Nil(t, ts)
|
assert.Nil(t, ts)
|
||||||
})
|
})
|
||||||
t.Run("host as sni", func(t *testing.T) {
|
t.Run("host as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_server_name as sni", func(t *testing.T) {
|
t.Run("tls_server_name as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSServerName: "use-this-name.example.com",
|
TLSServerName: "use-this-name.example.com",
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSUpstreamServerName: "use-this-name.example.com",
|
TLSUpstreamServerName: "use-this-name.example.com",
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("tls_skip_verify", func(t *testing.T) {
|
t.Run("tls_skip_verify", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSSkipVerify: true,
|
TLSSkipVerify: true,
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("custom ca", func(t *testing.T) {
|
t.Run("custom ca", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("options custom ca", func(t *testing.T) {
|
t.Run("options custom ca", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o2}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
})
|
})
|
||||||
t.Run("client certificate", func(t *testing.T) {
|
t.Run("client certificate", func(t *testing.T) {
|
||||||
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
ClientCertificate: clientCert,
|
ClientCertificate: clientCert,
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -430,7 +430,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
`, ts)
|
`, ts)
|
||||||
})
|
})
|
||||||
t.Run("allow renegotiation", func(t *testing.T) {
|
t.Run("allow renegotiation", func(t *testing.T) {
|
||||||
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
|
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "https://example.com"),
|
To: mustParseWeightedURLs(t, "https://example.com"),
|
||||||
TLSUpstreamAllowRenegotiation: true,
|
TLSUpstreamAllowRenegotiation: true,
|
||||||
}, *mustParseURL(t, "https://example.com"))
|
}, *mustParseURL(t, "https://example.com"))
|
||||||
|
@ -489,11 +489,11 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
|
||||||
func Test_buildCluster(t *testing.T) {
|
func Test_buildCluster(t *testing.T) {
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
rootCABytes, _ := getCombinedCertificateAuthority("", "")
|
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
|
||||||
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
|
||||||
o1 := config.NewDefaultOptions()
|
o1 := config.NewDefaultOptions()
|
||||||
t.Run("insecure", func(t *testing.T) {
|
t.Run("insecure", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -550,7 +550,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("secure", func(t *testing.T) {
|
t.Run("secure", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t,
|
To: mustParseWeightedURLs(t,
|
||||||
"https://example.com",
|
"https://example.com",
|
||||||
"https://example.com",
|
"https://example.com",
|
||||||
|
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("ip addresses", func(t *testing.T) {
|
t.Run("ip addresses", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -773,7 +773,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("weights", func(t *testing.T) {
|
t.Run("weights", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -830,7 +830,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("localhost", func(t *testing.T) {
|
t.Run("localhost", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://localhost"),
|
To: mustParseWeightedURLs(t, "http://localhost"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -876,7 +876,7 @@ func Test_buildCluster(t *testing.T) {
|
||||||
`, cluster)
|
`, cluster)
|
||||||
})
|
})
|
||||||
t.Run("outlier", func(t *testing.T) {
|
t.Run("outlier", func(t *testing.T) {
|
||||||
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
|
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
|
||||||
To: mustParseWeightedURLs(t, "http://example.com"),
|
To: mustParseWeightedURLs(t, "http://example.com"),
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
@ -959,7 +959,7 @@ func Test_bindConfig(t *testing.T) {
|
||||||
|
|
||||||
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
|
||||||
t.Run("no bind config", func(t *testing.T) {
|
t.Run("no bind config", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
|
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
@ -967,9 +967,9 @@ func Test_bindConfig(t *testing.T) {
|
||||||
assert.Nil(t, cluster.UpstreamBindConfig)
|
assert.Nil(t, cluster.UpstreamBindConfig)
|
||||||
})
|
})
|
||||||
t.Run("freebind", func(t *testing.T) {
|
t.Run("freebind", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
|
||||||
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
EnvoyBindConfigFreebind: null.BoolFrom(true),
|
||||||
}, &config.Policy{
|
}}, &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
@ -985,9 +985,9 @@ func Test_bindConfig(t *testing.T) {
|
||||||
`, cluster.UpstreamBindConfig)
|
`, cluster.UpstreamBindConfig)
|
||||||
})
|
})
|
||||||
t.Run("source address", func(t *testing.T) {
|
t.Run("source address", func(t *testing.T) {
|
||||||
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
|
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
|
||||||
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
EnvoyBindConfigSourceAddress: "192.168.0.1",
|
||||||
}, &config.Policy{
|
}}, &config.Policy{
|
||||||
From: "https://from.example.com",
|
From: "https://from.example.com",
|
||||||
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
To: mustParseWeightedURLs(t, "https://to.example.com"),
|
||||||
})
|
})
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/base64"
|
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
@ -32,6 +31,7 @@ import (
|
||||||
"google.golang.org/protobuf/types/known/wrapperspb"
|
"google.golang.org/protobuf/types/known/wrapperspb"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/fileutil"
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/pkg/cryptutil"
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
)
|
)
|
||||||
|
@ -211,36 +211,25 @@ func getRootCertificateAuthority() (string, error) {
|
||||||
return rootCABundle.value, nil
|
return rootCABundle.value, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getCombinedCertificateAuthority(customCA, customCAFile string) ([]byte, error) {
|
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) {
|
||||||
rootFile, err := getRootCertificateAuthority()
|
rootFile, err := getRootCertificateAuthority()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
combined, err := os.ReadFile(rootFile)
|
var buf bytes.Buffer
|
||||||
if err != nil {
|
if err := fileutil.CopyFileUpTo(&buf, rootFile, 5<<20); err != nil {
|
||||||
return nil, fmt.Errorf("error reading root certificates: %w", err)
|
return nil, fmt.Errorf("error reading root certificates: %w", err)
|
||||||
}
|
}
|
||||||
|
buf.WriteRune('\n')
|
||||||
|
|
||||||
if customCA != "" {
|
all, err := cfg.AllCertificateAuthoritiesPEM()
|
||||||
bs, err := base64.StdEncoding.DecodeString(customCA)
|
if err != nil {
|
||||||
if err != nil {
|
return nil, fmt.Errorf("get all CA: %w", err)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
combined = append(combined, '\n')
|
|
||||||
combined = append(combined, bs...)
|
|
||||||
}
|
}
|
||||||
|
buf.Write(all)
|
||||||
|
|
||||||
if customCAFile != "" {
|
return buf.Bytes(), nil
|
||||||
bs, err := os.ReadFile(customCAFile)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
combined = append(combined, '\n')
|
|
||||||
combined = append(combined, bs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return combined, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func marshalAny(msg proto.Message) *anypb.Any {
|
func marshalAny(msg proto.Message) *anypb.Any {
|
||||||
|
|
68
config/layered.go
Normal file
68
config/layered.go
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// LayeredSource is an abstraction for a ConfigSource that depends on an underlying config,
|
||||||
|
// and uses a builder to build the relevant part of the configuration
|
||||||
|
type LayeredSource struct {
|
||||||
|
mx sync.Mutex
|
||||||
|
|
||||||
|
cfg *Config
|
||||||
|
underlying Source
|
||||||
|
builder func(*Config) error
|
||||||
|
|
||||||
|
ChangeDispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
_ = Source(&LayeredSource{})
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewLayeredSource creates a new config source that is watching the underlying source for changes
|
||||||
|
func NewLayeredSource(ctx context.Context, underlying Source, builder func(*Config) error) (*LayeredSource, error) {
|
||||||
|
cfg := underlying.GetConfig().Clone()
|
||||||
|
src := LayeredSource{
|
||||||
|
cfg: cfg,
|
||||||
|
underlying: underlying,
|
||||||
|
builder: builder,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := builder(cfg); err != nil {
|
||||||
|
return nil, fmt.Errorf("build initial config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
underlying.OnConfigChange(ctx, src.onUnderlyingConfigChange)
|
||||||
|
|
||||||
|
return &src, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *LayeredSource) onUnderlyingConfigChange(ctx context.Context, next *Config) {
|
||||||
|
cfg := src.rebuild(ctx, next)
|
||||||
|
src.Trigger(ctx, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (src *LayeredSource) rebuild(ctx context.Context, next *Config) *Config {
|
||||||
|
src.mx.Lock()
|
||||||
|
defer src.mx.Unlock()
|
||||||
|
|
||||||
|
cfg := next.Clone()
|
||||||
|
if err := src.builder(cfg); err != nil {
|
||||||
|
log.Error(ctx).Err(err).Msg("building config")
|
||||||
|
cfg = next
|
||||||
|
}
|
||||||
|
src.cfg = cfg
|
||||||
|
return cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConfig returns currently stored config
|
||||||
|
func (src *LayeredSource) GetConfig() *Config {
|
||||||
|
src.mx.Lock()
|
||||||
|
defer src.mx.Unlock()
|
||||||
|
return src.cfg
|
||||||
|
}
|
44
config/layered_test.go
Normal file
44
config/layered_test.go
Normal file
|
@ -0,0 +1,44 @@
|
||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLayeredConfig(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
|
||||||
|
t.Run("error on initial build", func(t *testing.T) {
|
||||||
|
underlying := config.NewStaticSource(&config.Config{})
|
||||||
|
_, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
|
||||||
|
return errors.New("error")
|
||||||
|
})
|
||||||
|
require.Error(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("propagate new config on error", func(t *testing.T) {
|
||||||
|
underlying := config.NewStaticSource(&config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("a.com")}})
|
||||||
|
layered, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
|
||||||
|
if c.Options.GetDeriveInternalDomain() == "b.com" {
|
||||||
|
return errors.New("reject update")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var dst *config.Config
|
||||||
|
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
|
||||||
|
dst = c
|
||||||
|
})
|
||||||
|
|
||||||
|
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
|
||||||
|
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
|
||||||
|
})
|
||||||
|
}
|
|
@ -162,6 +162,10 @@ type Options struct {
|
||||||
CA string `mapstructure:"certificate_authority" yaml:"certificate_authority,omitempty"`
|
CA string `mapstructure:"certificate_authority" yaml:"certificate_authority,omitempty"`
|
||||||
CAFile string `mapstructure:"certificate_authority_file" yaml:"certificate_authority_file,omitempty"`
|
CAFile string `mapstructure:"certificate_authority_file" yaml:"certificate_authority_file,omitempty"`
|
||||||
|
|
||||||
|
// DeriveInternalDomainCert is an option that would derive certificate authority
|
||||||
|
// and domain certificates from the shared key and use them for internal communication
|
||||||
|
DeriveInternalDomainCert *string `mapstructure:"derive_tls" yaml:"derive_tls,omitempty"`
|
||||||
|
|
||||||
// SigningKey is the private key used to add a JWT-signature to upstream requests.
|
// SigningKey is the private key used to add a JWT-signature to upstream requests.
|
||||||
// https://www.pomerium.com/docs/topics/getting-users-identity.html
|
// https://www.pomerium.com/docs/topics/getting-users-identity.html
|
||||||
SigningKey string `mapstructure:"signing_key" yaml:"signing_key,omitempty"`
|
SigningKey string `mapstructure:"signing_key" yaml:"signing_key,omitempty"`
|
||||||
|
@ -728,6 +732,14 @@ func (o *Options) Validate() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetDeriveInternalDomain returns an optional internal domain name to use for gRPC endpoint
|
||||||
|
func (o *Options) GetDeriveInternalDomain() string {
|
||||||
|
if o.DeriveInternalDomainCert == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return *o.DeriveInternalDomainCert
|
||||||
|
}
|
||||||
|
|
||||||
// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
|
// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
|
||||||
func (o *Options) GetAuthenticateURL() (*url.URL, error) {
|
func (o *Options) GetAuthenticateURL() (*url.URL, error) {
|
||||||
rawurl := o.AuthenticateURLString
|
rawurl := o.AuthenticateURLString
|
||||||
|
|
|
@ -3,7 +3,10 @@
|
||||||
package fileutil
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"errors"
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -46,3 +49,37 @@ func Getwd() string {
|
||||||
}
|
}
|
||||||
return p
|
return p
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ReadFileUpTo reads file up to given size
|
||||||
|
// it returns an error if file is larger than allowed maximum
|
||||||
|
func ReadFileUpTo(fname string, maxSize int64) ([]byte, error) {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := CopyFileUpTo(&buf, fname, maxSize); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CopyFileUpTo copies content of the file up to maxBytes
|
||||||
|
// it returns an error if file is larger than allowed maximum
|
||||||
|
func CopyFileUpTo(dst io.Writer, fname string, maxBytes int64) error {
|
||||||
|
fd, err := os.Open(fname)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("open %s: %w", fname, err)
|
||||||
|
}
|
||||||
|
defer func() { _ = fd.Close() }()
|
||||||
|
|
||||||
|
fi, err := fd.Stat()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stat %s: %w", fname, err)
|
||||||
|
}
|
||||||
|
if fi.Size() > maxBytes {
|
||||||
|
return fmt.Errorf("file %s size %d > max %d", fname, fi.Size(), maxBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := io.Copy(dst, fd); err != nil {
|
||||||
|
return fmt.Errorf("read %s: %w", fname, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -1,8 +1,15 @@
|
||||||
package fileutil
|
package fileutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"os"
|
||||||
|
"path"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestIsReadableFile(t *testing.T) {
|
func TestIsReadableFile(t *testing.T) {
|
||||||
|
@ -45,3 +52,29 @@ func TestGetwd(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestReadFileUpTo(t *testing.T) {
|
||||||
|
d := t.TempDir()
|
||||||
|
input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
|
||||||
|
fname := path.Join(d, "test")
|
||||||
|
require.NoError(t, os.WriteFile(fname, input, 0600))
|
||||||
|
|
||||||
|
for _, tc := range []struct {
|
||||||
|
size int
|
||||||
|
expectError bool
|
||||||
|
}{
|
||||||
|
{len(input) - 1, true},
|
||||||
|
{len(input), false},
|
||||||
|
{len(input) + 1, false},
|
||||||
|
} {
|
||||||
|
t.Run(fmt.Sprint(tc), func(t *testing.T) {
|
||||||
|
out, err := ReadFileUpTo(fname, int64(tc.size))
|
||||||
|
if tc.expectError {
|
||||||
|
require.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, bytes.Equal(input, out))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -23,6 +23,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/log"
|
"github.com/pomerium/pomerium/internal/log"
|
||||||
"github.com/pomerium/pomerium/internal/registry"
|
"github.com/pomerium/pomerium/internal/registry"
|
||||||
"github.com/pomerium/pomerium/internal/version"
|
"github.com/pomerium/pomerium/internal/version"
|
||||||
|
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy"
|
"github.com/pomerium/pomerium/pkg/envoy"
|
||||||
"github.com/pomerium/pomerium/pkg/envoy/files"
|
"github.com/pomerium/pomerium/pkg/envoy/files"
|
||||||
"github.com/pomerium/pomerium/proxy"
|
"github.com/pomerium/pomerium/proxy"
|
||||||
|
@ -35,6 +36,10 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
Str("version", version.FullVersion()).
|
Str("version", version.FullVersion()).
|
||||||
Msg("cmd/pomerium")
|
Msg("cmd/pomerium")
|
||||||
|
|
||||||
|
src, err := config.NewLayeredSource(ctx, src, derivecert_config.NewBuilder())
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
src = databroker.NewConfigSource(ctx, src)
|
src = databroker.NewConfigSource(ctx, src)
|
||||||
logMgr := config.NewLogManager(ctx, src)
|
logMgr := config.NewLogManager(ctx, src)
|
||||||
defer logMgr.Close()
|
defer logMgr.Close()
|
||||||
|
@ -42,7 +47,7 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
// trigger changes when underlying files are changed
|
// trigger changes when underlying files are changed
|
||||||
src = config.NewFileWatcherSource(src)
|
src = config.NewFileWatcherSource(src)
|
||||||
|
|
||||||
src, err := autocert.New(src)
|
src, err = autocert.New(src)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -57,8 +62,10 @@ func Run(ctx context.Context, src config.Source) error {
|
||||||
|
|
||||||
eventsMgr := events.New()
|
eventsMgr := events.New()
|
||||||
|
|
||||||
|
cfg := src.GetConfig()
|
||||||
|
|
||||||
// setup the control plane
|
// setup the control plane
|
||||||
controlPlane, err := controlplane.NewServer(src.GetConfig(), metricsMgr, eventsMgr)
|
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error creating control plane: %w", err)
|
return fmt.Errorf("error creating control plane: %w", err)
|
||||||
}
|
}
|
||||||
|
|
91
pkg/derivecert/config/builder.go
Normal file
91
pkg/derivecert/config/builder.go
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
// Package config implements derived certs in the Pomerium Configuration
|
||||||
|
package config
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/pkg/derivecert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type builder struct {
|
||||||
|
psk []byte
|
||||||
|
ca *derivecert.CA
|
||||||
|
caCertPEM []byte
|
||||||
|
|
||||||
|
domain string
|
||||||
|
certs []tls.Certificate
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewBuilder returns a new derived certs config builder with caching
|
||||||
|
func NewBuilder() func(*config.Config) error {
|
||||||
|
return new(builder).Build
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *builder) Build(cfg *config.Config) error {
|
||||||
|
if cfg.Options.DeriveInternalDomainCert == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
psk, err := cfg.Options.GetSharedKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("shared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = x.buildCA(psk); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = x.buildCert(*cfg.Options.DeriveInternalDomainCert); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg.DerivedCAPEM = x.caCertPEM
|
||||||
|
cfg.DerivedCertificates = x.certs
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *builder) buildCA(psk []byte) error {
|
||||||
|
if bytes.Equal(x.psk, psk) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ca, err := derivecert.NewCA(psk)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("building certificate authority from shared key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
pem, err := ca.PEM()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("encode derived CA to PEM: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
x.psk = psk
|
||||||
|
x.ca = ca
|
||||||
|
x.caCertPEM = pem.Cert
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *builder) buildCert(domain string) error {
|
||||||
|
if x.domain == domain {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
certPEM, err := x.ca.NewServerCert([]string{domain})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
cert, err := certPEM.TLS()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse TLS cert: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
x.domain = domain
|
||||||
|
x.certs = []tls.Certificate{cert}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
57
pkg/derivecert/config/builder_test.go
Normal file
57
pkg/derivecert/config/builder_test.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package config_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"google.golang.org/protobuf/proto"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
||||||
|
dcfg "github.com/pomerium/pomerium/pkg/derivecert/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuild(t *testing.T) {
|
||||||
|
build := dcfg.NewBuilder()
|
||||||
|
|
||||||
|
key := cryptutil.NewBase64Key()
|
||||||
|
|
||||||
|
cfgA := config.Config{Options: &config.Options{SharedKey: key}}
|
||||||
|
t.Run("no domain requested", func(t *testing.T) {
|
||||||
|
require.NoError(t, build(&cfgA))
|
||||||
|
assert.Empty(t, cfgA.DerivedCAPEM)
|
||||||
|
assert.Empty(t, cfgA.DerivedCertificates)
|
||||||
|
})
|
||||||
|
|
||||||
|
cfgA.Options.DeriveInternalDomainCert = proto.String("example.com")
|
||||||
|
t.Run("generate server cert", func(t *testing.T) {
|
||||||
|
require.NoError(t, build(&cfgA))
|
||||||
|
assert.NotEmpty(t, cfgA.DerivedCAPEM)
|
||||||
|
assert.Len(t, cfgA.DerivedCertificates, 1)
|
||||||
|
})
|
||||||
|
|
||||||
|
cfgB := config.Config{Options: &config.Options{
|
||||||
|
SharedKey: key,
|
||||||
|
DeriveInternalDomainCert: proto.String("example.com"),
|
||||||
|
}}
|
||||||
|
t.Run("caching", func(t *testing.T) {
|
||||||
|
require.NoError(t, build(&cfgB))
|
||||||
|
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
|
||||||
|
assert.Equal(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no domain requested after run", func(t *testing.T) {
|
||||||
|
cfg := config.Config{Options: &config.Options{SharedKey: key}}
|
||||||
|
require.NoError(t, build(&cfg))
|
||||||
|
assert.Empty(t, cfg.DerivedCAPEM)
|
||||||
|
assert.Empty(t, cfg.DerivedCertificates)
|
||||||
|
})
|
||||||
|
|
||||||
|
cfgB.Options.DeriveInternalDomainCert = proto.String("example2.com")
|
||||||
|
t.Run("ca caching", func(t *testing.T) {
|
||||||
|
require.NoError(t, build(&cfgB))
|
||||||
|
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
|
||||||
|
assert.NotEqual(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
|
||||||
|
})
|
||||||
|
}
|
Loading…
Add table
Reference in a new issue