This commit is contained in:
Denis Mishin 2023-01-05 16:35:58 -05:00 committed by GitHub
parent 78fc4853db
commit 488bcd6f72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 447 additions and 67 deletions

View file

@ -1,8 +1,11 @@
package config
import (
"bytes"
"crypto/tls"
"encoding/base64"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/telemetry/metrics"
"github.com/pomerium/pomerium/pkg/cryptutil"
@ -17,6 +20,12 @@ type Config struct {
AutoCertificates []tls.Certificate
EnvoyVersion string
// DerivedCertificates are TLS certificates derived from the shared secret
DerivedCertificates []tls.Certificate
// DerivedCAPEM is a PEM-encoded certificate authority
// derived from the shared secret
DerivedCAPEM []byte
// GRPCPort is the port the gRPC server is running on.
GRPCPort string
// HTTPPort is the port the HTTP server is running on.
@ -57,9 +66,39 @@ func (cfg *Config) Clone() *Config {
ACMETLSALPNPort: cfg.ACMETLSALPNPort,
MetricsScrapeEndpoints: endpoints,
DerivedCertificates: cfg.DerivedCertificates,
DerivedCAPEM: cfg.DerivedCAPEM,
}
}
// AllCertificateAuthoritiesPEM returns all CAs as PEM bundle bytes
func (cfg *Config) AllCertificateAuthoritiesPEM() ([]byte, error) {
var combined bytes.Buffer
if cfg.Options.CA != "" {
bs, err := base64.StdEncoding.DecodeString(cfg.Options.CA)
if err != nil {
return nil, err
}
_, _ = combined.Write(bs)
_, _ = combined.WriteRune('\n')
}
if cfg.Options.CAFile != "" {
if err := fileutil.CopyFileUpTo(&combined, cfg.Options.CAFile, 5<<20); err != nil {
return nil, err
}
_, _ = combined.WriteRune('\n')
}
if cfg.DerivedCAPEM != nil {
_, _ = combined.Write(cfg.DerivedCAPEM)
_, _ = combined.WriteRune('\n')
}
return combined.Bytes(), nil
}
// AllCertificates returns all the certificates in the config.
func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
optionCertificates, err := cfg.Options.GetCertificates()
@ -70,6 +109,7 @@ func (cfg *Config) AllCertificates() ([]tls.Certificate, error) {
var certs []tls.Certificate
certs = append(certs, optionCertificates...)
certs = append(certs, cfg.AutoCertificates...)
certs = append(certs, cfg.DerivedCertificates...)
return certs, nil
}

View file

@ -51,22 +51,22 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
}
}
controlGRPC, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
controlGRPC, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-grpc", grpcURLs, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
controlHTTP, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
controlHTTP, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-http", []*url.URL{httpURL}, upstreamProtocolAuto)
if err != nil {
return nil, err
}
controlMetrics, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
controlMetrics, err := b.buildInternalCluster(ctx, cfg, "pomerium-control-plane-metrics", []*url.URL{metricsURL}, upstreamProtocolAuto)
if err != nil {
return nil, err
}
authorizeCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
authorizeCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-authorize", authorizeURLs, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
@ -75,7 +75,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
authorizeCluster.OutlierDetection = grpcOutlierDetection()
}
databrokerCluster, err := b.buildInternalCluster(ctx, cfg.Options, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
databrokerCluster, err := b.buildInternalCluster(ctx, cfg, "pomerium-databroker", databrokerURLs, upstreamProtocolHTTP2)
if err != nil {
return nil, err
}
@ -113,7 +113,7 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
policy.EnvoyOpts = newDefaultEnvoyClusterConfig()
}
if len(policy.To) > 0 {
cluster, err := b.buildPolicyCluster(ctx, cfg.Options, &policy)
cluster, err := b.buildPolicyCluster(ctx, cfg, &policy)
if err != nil {
return nil, fmt.Errorf("policy #%d: %w", i, err)
}
@ -131,16 +131,16 @@ func (b *Builder) BuildClusters(ctx context.Context, cfg *config.Config) ([]*env
func (b *Builder) buildInternalCluster(
ctx context.Context,
options *config.Options,
cfg *config.Config,
name string,
dsts []*url.URL,
upstreamProtocol upstreamProtocolConfig,
) (*envoy_config_cluster_v3.Cluster, error) {
cluster := newDefaultEnvoyClusterConfig()
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(options.DNSLookupFamily)
cluster.DnsLookupFamily = config.GetEnvoyDNSLookupFamily(cfg.Options.DNSLookupFamily)
var endpoints []Endpoint
for _, dst := range dsts {
ts, err := b.buildInternalTransportSocket(ctx, options, dst)
ts, err := b.buildInternalTransportSocket(ctx, cfg, dst)
if err != nil {
return nil, err
}
@ -153,10 +153,12 @@ func (b *Builder) buildInternalCluster(
return cluster, nil
}
func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Options, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
func (b *Builder) buildPolicyCluster(ctx context.Context, cfg *config.Config, policy *config.Policy) (*envoy_config_cluster_v3.Cluster, error) {
cluster := new(envoy_config_cluster_v3.Cluster)
proto.Merge(cluster, policy.EnvoyOpts)
options := cfg.Options
if options.EnvoyBindConfigFreebind.IsSet() || options.EnvoyBindConfigSourceAddress != "" {
cluster.UpstreamBindConfig = new(envoy_config_core_v3.BindConfig)
if options.EnvoyBindConfigFreebind.IsSet() {
@ -183,7 +185,7 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
name := getClusterID(policy)
endpoints, err := b.buildPolicyEndpoints(ctx, options, policy)
endpoints, err := b.buildPolicyEndpoints(ctx, cfg, policy)
if err != nil {
return nil, err
}
@ -205,12 +207,12 @@ func (b *Builder) buildPolicyCluster(ctx context.Context, options *config.Option
func (b *Builder) buildPolicyEndpoints(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
) ([]Endpoint, error) {
var endpoints []Endpoint
for _, dst := range policy.To {
ts, err := b.buildPolicyTransportSocket(ctx, options, policy, dst.URL)
ts, err := b.buildPolicyTransportSocket(ctx, cfg, policy, dst.URL)
if err != nil {
return nil, err
}
@ -221,7 +223,7 @@ func (b *Builder) buildPolicyEndpoints(
func (b *Builder) buildInternalTransportSocket(
ctx context.Context,
options *config.Options,
cfg *config.Config,
endpoint *url.URL,
) (*envoy_config_core_v3.TransportSocket, error) {
if endpoint.Scheme != "https" {
@ -230,10 +232,10 @@ func (b *Builder) buildInternalTransportSocket(
validationContext := &envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext{
MatchTypedSubjectAltNames: []*envoy_extensions_transport_sockets_tls_v3.SubjectAltNameMatcher{
b.buildSubjectAltNameMatcher(endpoint, options.OverrideCertificateName),
b.buildSubjectAltNameMatcher(endpoint, cfg.Options.OverrideCertificateName),
},
}
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
bs, err := getCombinedCertificateAuthority(cfg)
if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {
@ -246,7 +248,7 @@ func (b *Builder) buildInternalTransportSocket(
ValidationContext: validationContext,
},
},
Sni: b.buildSubjectNameIndication(endpoint, options.OverrideCertificateName),
Sni: b.buildSubjectNameIndication(endpoint, cfg.Options.OverrideCertificateName),
}
tlsConfig := marshalAny(tlsContext)
return &envoy_config_core_v3.TransportSocket{
@ -259,7 +261,7 @@ func (b *Builder) buildInternalTransportSocket(
func (b *Builder) buildPolicyTransportSocket(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
dst url.URL,
) (*envoy_config_core_v3.TransportSocket, error) {
@ -269,7 +271,7 @@ func (b *Builder) buildPolicyTransportSocket(
upstreamProtocol := getUpstreamProtocolForPolicy(ctx, policy)
vc, err := b.buildPolicyValidationContext(ctx, options, policy, dst)
vc, err := b.buildPolicyValidationContext(ctx, cfg, policy, dst)
if err != nil {
return nil, err
}
@ -331,7 +333,7 @@ func (b *Builder) buildPolicyTransportSocket(
func (b *Builder) buildPolicyValidationContext(
ctx context.Context,
options *config.Options,
cfg *config.Config,
policy *config.Policy,
dst url.URL,
) (*envoy_extensions_transport_sockets_tls_v3.CertificateValidationContext, error) {
@ -356,7 +358,7 @@ func (b *Builder) buildPolicyValidationContext(
}
validationContext.TrustedCa = b.filemgr.BytesDataSource("custom-ca.pem", bs)
} else {
bs, err := getCombinedCertificateAuthority(options.CA, options.CAFile)
bs, err := getCombinedCertificateAuthority(cfg)
if err != nil {
log.Error(ctx).Err(err).Msg("unable to enable certificate verification because no root CAs were found")
} else {

View file

@ -26,25 +26,25 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
customCA := filepath.Join(cacheDir, "pomerium", "envoy", "files", "custom-ca-32484c314b584447463735303142374c31414145374650305a525539554938594d524855353757313942494d473847535231.pem")
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority("", "")
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions()
o2 := config.NewDefaultOptions()
o2.CA = base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0})
combinedCABytes, _ := getCombinedCertificateAuthority(o2.CA, "")
combinedCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{CA: o2.CA}})
combinedCA := b.filemgr.BytesDataSource("ca.pem", combinedCABytes).GetFilename()
t.Run("insecure", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
}, *mustParseURL(t, "http://example.com"))
require.NoError(t, err)
assert.Nil(t, ts)
})
t.Run("host as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
@ -97,7 +97,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
@ -151,7 +151,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_upstream_server_name as sni", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSUpstreamServerName: "use-this-name.example.com",
}, *mustParseURL(t, "https://example.com"))
@ -205,7 +205,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("tls_skip_verify", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSSkipVerify: true,
}, *mustParseURL(t, "https://example.com"))
@ -260,7 +260,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSCustomCA: base64.StdEncoding.EncodeToString([]byte{0, 0, 0, 0}),
}, *mustParseURL(t, "https://example.com"))
@ -314,7 +314,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("options custom ca", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o2, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o2}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
}, *mustParseURL(t, "https://example.com"))
require.NoError(t, err)
@ -368,7 +368,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
})
t.Run("client certificate", func(t *testing.T) {
clientCert, _ := cryptutil.CertificateFromBase64(aExampleComCert, aExampleComKey)
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
ClientCertificate: clientCert,
}, *mustParseURL(t, "https://example.com"))
@ -430,7 +430,7 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
`, ts)
})
t.Run("allow renegotiation", func(t *testing.T) {
ts, err := b.buildPolicyTransportSocket(ctx, o1, &config.Policy{
ts, err := b.buildPolicyTransportSocket(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "https://example.com"),
TLSUpstreamAllowRenegotiation: true,
}, *mustParseURL(t, "https://example.com"))
@ -489,11 +489,11 @@ func Test_buildPolicyTransportSocket(t *testing.T) {
func Test_buildCluster(t *testing.T) {
ctx := context.Background()
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
rootCABytes, _ := getCombinedCertificateAuthority("", "")
rootCABytes, _ := getCombinedCertificateAuthority(&config.Config{Options: &config.Options{}})
rootCA := b.filemgr.BytesDataSource("ca.pem", rootCABytes).GetFilename()
o1 := config.NewDefaultOptions()
t.Run("insecure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com", "http://1.2.3.4"),
})
require.NoError(t, err)
@ -550,7 +550,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("secure", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t,
"https://example.com",
"https://example.com",
@ -718,7 +718,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("ip addresses", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1", "http://127.0.0.2"),
})
require.NoError(t, err)
@ -773,7 +773,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("weights", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://127.0.0.1:8080,1", "http://127.0.0.2,2"),
})
require.NoError(t, err)
@ -830,7 +830,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("localhost", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://localhost"),
})
require.NoError(t, err)
@ -876,7 +876,7 @@ func Test_buildCluster(t *testing.T) {
`, cluster)
})
t.Run("outlier", func(t *testing.T) {
endpoints, err := b.buildPolicyEndpoints(ctx, o1, &config.Policy{
endpoints, err := b.buildPolicyEndpoints(ctx, &config.Config{Options: o1}, &config.Policy{
To: mustParseWeightedURLs(t, "http://example.com"),
})
require.NoError(t, err)
@ -959,7 +959,7 @@ func Test_bindConfig(t *testing.T) {
b := New("local-grpc", "local-http", "local-metrics", filemgr.NewManager(), nil)
t.Run("no bind config", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{}, &config.Policy{
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{}}, &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})
@ -967,9 +967,9 @@ func Test_bindConfig(t *testing.T) {
assert.Nil(t, cluster.UpstreamBindConfig)
})
t.Run("freebind", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
EnvoyBindConfigFreebind: null.BoolFrom(true),
}, &config.Policy{
}}, &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})
@ -985,9 +985,9 @@ func Test_bindConfig(t *testing.T) {
`, cluster.UpstreamBindConfig)
})
t.Run("source address", func(t *testing.T) {
cluster, err := b.buildPolicyCluster(ctx, &config.Options{
cluster, err := b.buildPolicyCluster(ctx, &config.Config{Options: &config.Options{
EnvoyBindConfigSourceAddress: "192.168.0.1",
}, &config.Policy{
}}, &config.Policy{
From: "https://from.example.com",
To: mustParseWeightedURLs(t, "https://to.example.com"),
})

View file

@ -6,7 +6,6 @@ import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
@ -32,6 +31,7 @@ import (
"google.golang.org/protobuf/types/known/wrapperspb"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/fileutil"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
)
@ -211,36 +211,25 @@ func getRootCertificateAuthority() (string, error) {
return rootCABundle.value, nil
}
func getCombinedCertificateAuthority(customCA, customCAFile string) ([]byte, error) {
func getCombinedCertificateAuthority(cfg *config.Config) ([]byte, error) {
rootFile, err := getRootCertificateAuthority()
if err != nil {
return nil, err
}
combined, err := os.ReadFile(rootFile)
if err != nil {
var buf bytes.Buffer
if err := fileutil.CopyFileUpTo(&buf, rootFile, 5<<20); err != nil {
return nil, fmt.Errorf("error reading root certificates: %w", err)
}
buf.WriteRune('\n')
if customCA != "" {
bs, err := base64.StdEncoding.DecodeString(customCA)
if err != nil {
return nil, err
}
combined = append(combined, '\n')
combined = append(combined, bs...)
all, err := cfg.AllCertificateAuthoritiesPEM()
if err != nil {
return nil, fmt.Errorf("get all CA: %w", err)
}
buf.Write(all)
if customCAFile != "" {
bs, err := os.ReadFile(customCAFile)
if err != nil {
return nil, err
}
combined = append(combined, '\n')
combined = append(combined, bs...)
}
return combined, nil
return buf.Bytes(), nil
}
func marshalAny(msg proto.Message) *anypb.Any {

68
config/layered.go Normal file
View file

@ -0,0 +1,68 @@
package config
import (
"context"
"fmt"
"sync"
"github.com/pomerium/pomerium/internal/log"
)
// LayeredSource is an abstraction for a ConfigSource that depends on an underlying config,
// and uses a builder to build the relevant part of the configuration
type LayeredSource struct {
mx sync.Mutex
cfg *Config
underlying Source
builder func(*Config) error
ChangeDispatcher
}
var (
_ = Source(&LayeredSource{})
)
// NewLayeredSource creates a new config source that is watching the underlying source for changes
func NewLayeredSource(ctx context.Context, underlying Source, builder func(*Config) error) (*LayeredSource, error) {
cfg := underlying.GetConfig().Clone()
src := LayeredSource{
cfg: cfg,
underlying: underlying,
builder: builder,
}
if err := builder(cfg); err != nil {
return nil, fmt.Errorf("build initial config: %w", err)
}
underlying.OnConfigChange(ctx, src.onUnderlyingConfigChange)
return &src, nil
}
func (src *LayeredSource) onUnderlyingConfigChange(ctx context.Context, next *Config) {
cfg := src.rebuild(ctx, next)
src.Trigger(ctx, cfg)
}
func (src *LayeredSource) rebuild(ctx context.Context, next *Config) *Config {
src.mx.Lock()
defer src.mx.Unlock()
cfg := next.Clone()
if err := src.builder(cfg); err != nil {
log.Error(ctx).Err(err).Msg("building config")
cfg = next
}
src.cfg = cfg
return cfg
}
// GetConfig returns currently stored config
func (src *LayeredSource) GetConfig() *Config {
src.mx.Lock()
defer src.mx.Unlock()
return src.cfg
}

44
config/layered_test.go Normal file
View file

@ -0,0 +1,44 @@
package config_test
import (
"context"
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/config"
)
func TestLayeredConfig(t *testing.T) {
ctx := context.Background()
t.Run("error on initial build", func(t *testing.T) {
underlying := config.NewStaticSource(&config.Config{})
_, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
return errors.New("error")
})
require.Error(t, err)
})
t.Run("propagate new config on error", func(t *testing.T) {
underlying := config.NewStaticSource(&config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("a.com")}})
layered, err := config.NewLayeredSource(ctx, underlying, func(c *config.Config) error {
if c.Options.GetDeriveInternalDomain() == "b.com" {
return errors.New("reject update")
}
return nil
})
require.NoError(t, err)
var dst *config.Config
layered.OnConfigChange(ctx, func(ctx context.Context, c *config.Config) {
dst = c
})
underlying.SetConfig(ctx, &config.Config{Options: &config.Options{DeriveInternalDomainCert: proto.String("b.com")}})
assert.Equal(t, "b.com", dst.Options.GetDeriveInternalDomain())
})
}

View file

@ -162,6 +162,10 @@ type Options struct {
CA string `mapstructure:"certificate_authority" yaml:"certificate_authority,omitempty"`
CAFile string `mapstructure:"certificate_authority_file" yaml:"certificate_authority_file,omitempty"`
// DeriveInternalDomainCert is an option that would derive certificate authority
// and domain certificates from the shared key and use them for internal communication
DeriveInternalDomainCert *string `mapstructure:"derive_tls" yaml:"derive_tls,omitempty"`
// SigningKey is the private key used to add a JWT-signature to upstream requests.
// https://www.pomerium.com/docs/topics/getting-users-identity.html
SigningKey string `mapstructure:"signing_key" yaml:"signing_key,omitempty"`
@ -728,6 +732,14 @@ func (o *Options) Validate() error {
return nil
}
// GetDeriveInternalDomain returns an optional internal domain name to use for gRPC endpoint
func (o *Options) GetDeriveInternalDomain() string {
if o.DeriveInternalDomainCert == nil {
return ""
}
return *o.DeriveInternalDomainCert
}
// GetAuthenticateURL returns the AuthenticateURL in the options or 127.0.0.1.
func (o *Options) GetAuthenticateURL() (*url.URL, error) {
rawurl := o.AuthenticateURLString

View file

@ -3,7 +3,10 @@
package fileutil
import (
"bytes"
"errors"
"fmt"
"io"
"os"
)
@ -46,3 +49,37 @@ func Getwd() string {
}
return p
}
// ReadFileUpTo reads file up to given size
// it returns an error if file is larger than allowed maximum
func ReadFileUpTo(fname string, maxSize int64) ([]byte, error) {
var buf bytes.Buffer
if err := CopyFileUpTo(&buf, fname, maxSize); err != nil {
return nil, err
}
return buf.Bytes(), nil
}
// CopyFileUpTo copies content of the file up to maxBytes
// it returns an error if file is larger than allowed maximum
func CopyFileUpTo(dst io.Writer, fname string, maxBytes int64) error {
fd, err := os.Open(fname)
if err != nil {
return fmt.Errorf("open %s: %w", fname, err)
}
defer func() { _ = fd.Close() }()
fi, err := fd.Stat()
if err != nil {
return fmt.Errorf("stat %s: %w", fname, err)
}
if fi.Size() > maxBytes {
return fmt.Errorf("file %s size %d > max %d", fname, fi.Size(), maxBytes)
}
if _, err := io.Copy(dst, fd); err != nil {
return fmt.Errorf("read %s: %w", fname, err)
}
return nil
}

View file

@ -1,8 +1,15 @@
package fileutil
import (
"bytes"
"fmt"
"os"
"path"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestIsReadableFile(t *testing.T) {
@ -45,3 +52,29 @@ func TestGetwd(t *testing.T) {
})
}
}
func TestReadFileUpTo(t *testing.T) {
d := t.TempDir()
input := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
fname := path.Join(d, "test")
require.NoError(t, os.WriteFile(fname, input, 0600))
for _, tc := range []struct {
size int
expectError bool
}{
{len(input) - 1, true},
{len(input), false},
{len(input) + 1, false},
} {
t.Run(fmt.Sprint(tc), func(t *testing.T) {
out, err := ReadFileUpTo(fname, int64(tc.size))
if tc.expectError {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.True(t, bytes.Equal(input, out))
})
}
}

View file

@ -23,6 +23,7 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/registry"
"github.com/pomerium/pomerium/internal/version"
derivecert_config "github.com/pomerium/pomerium/pkg/derivecert/config"
"github.com/pomerium/pomerium/pkg/envoy"
"github.com/pomerium/pomerium/pkg/envoy/files"
"github.com/pomerium/pomerium/proxy"
@ -35,6 +36,10 @@ func Run(ctx context.Context, src config.Source) error {
Str("version", version.FullVersion()).
Msg("cmd/pomerium")
src, err := config.NewLayeredSource(ctx, src, derivecert_config.NewBuilder())
if err != nil {
return err
}
src = databroker.NewConfigSource(ctx, src)
logMgr := config.NewLogManager(ctx, src)
defer logMgr.Close()
@ -42,7 +47,7 @@ func Run(ctx context.Context, src config.Source) error {
// trigger changes when underlying files are changed
src = config.NewFileWatcherSource(src)
src, err := autocert.New(src)
src, err = autocert.New(src)
if err != nil {
return err
}
@ -57,8 +62,10 @@ func Run(ctx context.Context, src config.Source) error {
eventsMgr := events.New()
cfg := src.GetConfig()
// setup the control plane
controlPlane, err := controlplane.NewServer(src.GetConfig(), metricsMgr, eventsMgr)
controlPlane, err := controlplane.NewServer(cfg, metricsMgr, eventsMgr)
if err != nil {
return fmt.Errorf("error creating control plane: %w", err)
}

View file

@ -0,0 +1,91 @@
// Package config implements derived certs in the Pomerium Configuration
package config
import (
"bytes"
"crypto/tls"
"fmt"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/derivecert"
)
type builder struct {
psk []byte
ca *derivecert.CA
caCertPEM []byte
domain string
certs []tls.Certificate
}
// NewBuilder returns a new derived certs config builder with caching
func NewBuilder() func(*config.Config) error {
return new(builder).Build
}
func (x *builder) Build(cfg *config.Config) error {
if cfg.Options.DeriveInternalDomainCert == nil {
return nil
}
psk, err := cfg.Options.GetSharedKey()
if err != nil {
return fmt.Errorf("shared key: %w", err)
}
if err = x.buildCA(psk); err != nil {
return err
}
if err = x.buildCert(*cfg.Options.DeriveInternalDomainCert); err != nil {
return err
}
cfg.DerivedCAPEM = x.caCertPEM
cfg.DerivedCertificates = x.certs
return nil
}
func (x *builder) buildCA(psk []byte) error {
if bytes.Equal(x.psk, psk) {
return nil
}
ca, err := derivecert.NewCA(psk)
if err != nil {
return fmt.Errorf("building certificate authority from shared key: %w", err)
}
pem, err := ca.PEM()
if err != nil {
return fmt.Errorf("encode derived CA to PEM: %w", err)
}
x.psk = psk
x.ca = ca
x.caCertPEM = pem.Cert
return nil
}
func (x *builder) buildCert(domain string) error {
if x.domain == domain {
return nil
}
certPEM, err := x.ca.NewServerCert([]string{domain})
if err != nil {
return fmt.Errorf("generate cert: %w", err)
}
cert, err := certPEM.TLS()
if err != nil {
return fmt.Errorf("parse TLS cert: %w", err)
}
x.domain = domain
x.certs = []tls.Certificate{cert}
return nil
}

View file

@ -0,0 +1,57 @@
package config_test
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/pkg/cryptutil"
dcfg "github.com/pomerium/pomerium/pkg/derivecert/config"
)
func TestBuild(t *testing.T) {
build := dcfg.NewBuilder()
key := cryptutil.NewBase64Key()
cfgA := config.Config{Options: &config.Options{SharedKey: key}}
t.Run("no domain requested", func(t *testing.T) {
require.NoError(t, build(&cfgA))
assert.Empty(t, cfgA.DerivedCAPEM)
assert.Empty(t, cfgA.DerivedCertificates)
})
cfgA.Options.DeriveInternalDomainCert = proto.String("example.com")
t.Run("generate server cert", func(t *testing.T) {
require.NoError(t, build(&cfgA))
assert.NotEmpty(t, cfgA.DerivedCAPEM)
assert.Len(t, cfgA.DerivedCertificates, 1)
})
cfgB := config.Config{Options: &config.Options{
SharedKey: key,
DeriveInternalDomainCert: proto.String("example.com"),
}}
t.Run("caching", func(t *testing.T) {
require.NoError(t, build(&cfgB))
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
assert.Equal(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
})
t.Run("no domain requested after run", func(t *testing.T) {
cfg := config.Config{Options: &config.Options{SharedKey: key}}
require.NoError(t, build(&cfg))
assert.Empty(t, cfg.DerivedCAPEM)
assert.Empty(t, cfg.DerivedCertificates)
})
cfgB.Options.DeriveInternalDomainCert = proto.String("example2.com")
t.Run("ca caching", func(t *testing.T) {
require.NoError(t, build(&cfgB))
assert.Equal(t, cfgA.DerivedCAPEM, cfgB.DerivedCAPEM)
assert.NotEqual(t, cfgA.DerivedCertificates[0].Certificate, cfgB.DerivedCertificates[0].Certificate)
})
}