mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-25 05:58:16 +02:00
Add client certificate utilities
This commit is contained in:
parent
6ed5752fa5
commit
490a301aa4
10 changed files with 355 additions and 20 deletions
|
@ -21,6 +21,7 @@ import (
|
|||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
"github.com/pomerium/pomerium/internal/testenv/scenarios"
|
||||
"github.com/pomerium/pomerium/internal/testenv/upstreams"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
|
@ -124,6 +125,46 @@ func TestHTTP(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestClientCert(t *testing.T) {
|
||||
env := testenv.New(t)
|
||||
env.Add(scenarios.DownstreamMTLS(config.MTLSEnforcementRejectConnection))
|
||||
|
||||
up := upstreams.HTTP(nil)
|
||||
up.Handle("/foo", func(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprintln(w, "hello world")
|
||||
})
|
||||
|
||||
clientCert := env.NewClientCert()
|
||||
|
||||
route := up.Route().
|
||||
From(env.SubdomainURL("http")).
|
||||
PPL(fmt.Sprintf(`{"allow":{"and":["client_certificate":{"fingerprint":%q}]}}`, clientCert.Fingerprint()))
|
||||
|
||||
env.AddUpstream(up)
|
||||
env.Start()
|
||||
|
||||
recorder := env.NewLogRecorder()
|
||||
|
||||
resp, err := up.Get(route, upstreams.Path("/foo"), upstreams.ClientCert(clientCert))
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
data, err := io.ReadAll(resp.Body)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "hello world\n", string(data))
|
||||
|
||||
recorder.Match([]map[string]any{
|
||||
{
|
||||
"service": "envoy",
|
||||
"path": "/foo",
|
||||
"method": "GET",
|
||||
"message": "http-request",
|
||||
"response-code-details": "via_upstream",
|
||||
"client-certificate": clientCert,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func TestH2C(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.SkipNow()
|
||||
|
|
|
@ -69,10 +69,17 @@ type Server struct {
|
|||
}
|
||||
|
||||
// NewServer creates a new Server. Listener ports are chosen by the OS.
|
||||
func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.MetricsManager, eventsMgr *events.Manager) (*Server, error) {
|
||||
func NewServer(
|
||||
ctx context.Context,
|
||||
cfg *config.Config,
|
||||
metricsMgr *config.MetricsManager,
|
||||
eventsMgr *events.Manager,
|
||||
fileMgr *filemgr.Manager,
|
||||
) (*Server, error) {
|
||||
srv := &Server{
|
||||
metricsMgr: metricsMgr,
|
||||
EventsMgr: eventsMgr,
|
||||
filemgr: fileMgr,
|
||||
reproxy: reproxy.New(),
|
||||
haveSetCapacity: map[string]bool{},
|
||||
updateConfig: make(chan *config.Config, 1),
|
||||
|
@ -149,7 +156,6 @@ func NewServer(ctx context.Context, cfg *config.Config, metricsMgr *config.Metri
|
|||
// metrics
|
||||
srv.MetricsRouter.Handle("/metrics", srv.metricsMgr)
|
||||
|
||||
srv.filemgr = filemgr.NewManager()
|
||||
srv.filemgr.ClearCache()
|
||||
|
||||
srv.Builder = envoyconfig.New(
|
||||
|
|
|
@ -12,6 +12,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/events"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
)
|
||||
|
@ -38,7 +39,7 @@ func TestServerHTTP(t *testing.T) {
|
|||
cfg.Options.SharedKey = "JDNjY2ITDlARvNaQXjc2Djk+GA6xeCy4KiozmZfdbTs="
|
||||
|
||||
src := config.NewStaticSource(cfg)
|
||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New())
|
||||
srv, err := NewServer(ctx, cfg, config.NewMetricsManager(ctx, src), events.New(), filemgr.NewManager(filemgr.WithCacheDir(t.TempDir())))
|
||||
require.NoError(t, err)
|
||||
go srv.Run(ctx)
|
||||
|
||||
|
|
|
@ -2,10 +2,20 @@ package testenv
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/sha256"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"crypto/x509/pkix"
|
||||
"encoding/asn1"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
|
@ -13,13 +23,16 @@ import (
|
|||
"strconv"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
"github.com/pomerium/pomerium/internal/log"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/cmd/pomerium"
|
||||
"github.com/pomerium/pomerium/pkg/health"
|
||||
"github.com/pomerium/pomerium/pkg/netutil"
|
||||
"github.com/pomerium/pomerium/pkg/slices"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
@ -33,14 +46,33 @@ type Environment interface {
|
|||
// top-level logger scoped to this environment. It will be canceled when
|
||||
// Stop() is called, or during test cleanup.
|
||||
Context() context.Context
|
||||
|
||||
Assert() *assert.Assertions
|
||||
Require() *require.Assertions
|
||||
|
||||
// TempDir returns a unique temp directory for this context. Calling this
|
||||
// function multiple times returns the same path.
|
||||
TempDir() string
|
||||
// CACert returns the test environment's root CA certificate and private key.
|
||||
CACert() *tls.Certificate
|
||||
// ServerCAs returns a new [*x509.CertPool] containing the root CA certificate
|
||||
// used to sign the server cert and other test certificates.
|
||||
ServerCAs() *x509.CertPool
|
||||
// ServerCert returns the Pomerium server's certificate.
|
||||
// ServerCert returns the Pomerium server's certificate and private key.
|
||||
ServerCert() *tls.Certificate
|
||||
// NewClientCert generates a new client certificate signed by the root CA
|
||||
// certificate. One or more optional templates can be given, which can be
|
||||
// used to set or override certain parameters when creating a certificate,
|
||||
// including subject, SANs, or extensions. If more than one template is
|
||||
// provided, they will be applied in order from left to right.
|
||||
//
|
||||
// By default (unless overridden in a template), the certificate will have
|
||||
// its Common Name set to the file:line string of the call site. Calls to
|
||||
// NewClientCert() on different lines will have different subjects. If
|
||||
// multiple certs with the same subject are needed, wrap the call to this
|
||||
// function in another helper function, or separate calls with commas on the
|
||||
// same line.
|
||||
NewClientCert(templateOverrides ...*x509.Certificate) *Certificate
|
||||
|
||||
// Add adds the given [Modifier] to the environment. All modifiers will be
|
||||
// invoked upon calling Start() to apply individual modifications to the
|
||||
|
@ -80,8 +112,22 @@ type Environment interface {
|
|||
NewLogRecorder(opts ...LogRecorderOption) *LogRecorder
|
||||
}
|
||||
|
||||
type Certificate tls.Certificate
|
||||
|
||||
func (c *Certificate) Fingerprint() string {
|
||||
sum := sha256.Sum256(c.Leaf.Raw)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func (c *Certificate) SPKIHash() string {
|
||||
sum := sha256.Sum256(c.Leaf.RawSubjectPublicKeyInfo)
|
||||
return base64.StdEncoding.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type environment struct {
|
||||
t testing.TB
|
||||
assert *assert.Assertions
|
||||
require *require.Assertions
|
||||
tempDir string
|
||||
domain string
|
||||
ports Ports
|
||||
|
@ -125,6 +171,8 @@ func New(t testing.TB) Environment {
|
|||
|
||||
e := &environment{
|
||||
t: t,
|
||||
assert: assert.New(t),
|
||||
require: require.New(t),
|
||||
tempDir: t.TempDir(),
|
||||
ports: Ports{
|
||||
http: values.Deferred[int](),
|
||||
|
@ -141,13 +189,14 @@ func New(t testing.TB) Environment {
|
|||
copyFile := func(src, dstRel string) {
|
||||
data, err := os.ReadFile(src)
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o666))
|
||||
require.NoError(t, os.WriteFile(filepath.Join(e.tempDir, dstRel), data, 0o600))
|
||||
}
|
||||
|
||||
certsToCopy := []string{
|
||||
"trusted.pem",
|
||||
"trusted-key.pem",
|
||||
"ca.pem",
|
||||
"ca-key.pem",
|
||||
}
|
||||
for _, crt := range certsToCopy {
|
||||
copyFile(filepath.Join(workspaceFolder, "integration/tpl/files", crt), filepath.Join("certs/", filepath.Base(crt)))
|
||||
|
@ -174,12 +223,29 @@ func (e *environment) Context() context.Context {
|
|||
return ContextWithEnv(e.ctx, e)
|
||||
}
|
||||
|
||||
func (e *environment) Assert() *assert.Assertions {
|
||||
return e.assert
|
||||
}
|
||||
|
||||
func (e *environment) Require() *require.Assertions {
|
||||
return e.require
|
||||
}
|
||||
|
||||
func (e *environment) SubdomainURL(subdomain string) values.Value[string] {
|
||||
return values.Bind(e.ports.http, func(port int) string {
|
||||
return fmt.Sprintf("https://%s.%s:%d", subdomain, e.domain, port)
|
||||
})
|
||||
}
|
||||
|
||||
func (e *environment) CACert() *tls.Certificate {
|
||||
caCert, err := tls.LoadX509KeyPair(
|
||||
filepath.Join(e.tempDir, "certs", "ca.pem"),
|
||||
filepath.Join(e.tempDir, "certs", "ca-key.pem"),
|
||||
)
|
||||
require.NoError(e.t, err)
|
||||
return &caCert
|
||||
}
|
||||
|
||||
func (e *environment) ServerCAs() *x509.CertPool {
|
||||
pool := x509.NewCertPool()
|
||||
caCert, err := os.ReadFile(filepath.Join(e.tempDir, "certs", "ca.pem"))
|
||||
|
@ -213,20 +279,39 @@ func (e *environment) Start() {
|
|||
require.NoError(e.t, err)
|
||||
port0, _ := strconv.Atoi(ports[0])
|
||||
e.ports.http.Resolve(port0)
|
||||
cfg.Options.AutocertOptions = config.AutocertOptions{Enable: false}
|
||||
cfg.Options.LogLevel = config.LogLevelInfo
|
||||
cfg.Options.ProxyLogLevel = config.LogLevelInfo
|
||||
cfg.Options.Addr = fmt.Sprintf("127.0.0.1:%d", port0)
|
||||
cfg.Options.CertFile = filepath.Join(e.tempDir, "certs", "trusted.pem")
|
||||
cfg.Options.KeyFile = filepath.Join(e.tempDir, "certs", "trusted-key.pem")
|
||||
cfg.Options.AccessLogFields = []log.AccessLogField{
|
||||
log.AccessLogFieldAuthority,
|
||||
log.AccessLogFieldDuration,
|
||||
log.AccessLogFieldForwardedFor,
|
||||
log.AccessLogFieldIP,
|
||||
log.AccessLogFieldMethod,
|
||||
log.AccessLogFieldPath,
|
||||
log.AccessLogFieldQuery,
|
||||
log.AccessLogFieldReferer,
|
||||
log.AccessLogFieldRequestID,
|
||||
log.AccessLogFieldResponseCode,
|
||||
log.AccessLogFieldResponseCodeDetails,
|
||||
log.AccessLogFieldSize,
|
||||
log.AccessLogFieldUpstreamCluster,
|
||||
log.AccessLogFieldUserAgent,
|
||||
log.AccessLogFieldClientCertificate,
|
||||
}
|
||||
cfg.AllocatePorts(*(*[6]string)(ports[1:]))
|
||||
|
||||
e.AddTask(TaskFunc(func(ctx context.Context) error {
|
||||
fileMgr := filemgr.NewManager(filemgr.WithCacheDir(filepath.Join(e.TempDir(), "cache")))
|
||||
src := config.NewStaticSource(cfg)
|
||||
for _, mod := range e.mods {
|
||||
mod.Value.Modify(cfg)
|
||||
require.NoError(e.t, cfg.Options.Validate(), "invoking modifier resulted in an invalid configuration:\nadded by: "+mod.Caller)
|
||||
}
|
||||
return pomerium.Run(e.ctx, src)
|
||||
return pomerium.Run(e.ctx, src, pomerium.WithOverrideFileManager(fileMgr))
|
||||
}))
|
||||
|
||||
for i, task := range e.tasks {
|
||||
|
@ -238,6 +323,64 @@ func (e *environment) Start() {
|
|||
}
|
||||
}
|
||||
|
||||
func (e *environment) NewClientCert(templateOverrides ...*x509.Certificate) *Certificate {
|
||||
caCert := e.CACert()
|
||||
|
||||
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
sn, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
|
||||
require.NoError(e.t, err)
|
||||
now := time.Now()
|
||||
tmpl := &x509.Certificate{
|
||||
SerialNumber: sn,
|
||||
Subject: pkix.Name{
|
||||
CommonName: getCaller(),
|
||||
},
|
||||
NotBefore: now,
|
||||
NotAfter: now.Add(12 * time.Hour),
|
||||
KeyUsage: x509.KeyUsageDigitalSignature,
|
||||
ExtKeyUsage: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
BasicConstraintsValid: true,
|
||||
}
|
||||
for _, override := range templateOverrides {
|
||||
tmpl.CRLDistributionPoints = slices.Unique(append(tmpl.CRLDistributionPoints, override.CRLDistributionPoints...))
|
||||
tmpl.DNSNames = slices.Unique(append(tmpl.DNSNames, override.DNSNames...))
|
||||
tmpl.EmailAddresses = slices.Unique(append(tmpl.EmailAddresses, override.EmailAddresses...))
|
||||
tmpl.ExtraExtensions = append(tmpl.ExtraExtensions, override.ExtraExtensions...)
|
||||
tmpl.IPAddresses = slices.UniqueBy(append(tmpl.IPAddresses, override.IPAddresses...), net.IP.String)
|
||||
tmpl.URIs = slices.UniqueBy(append(tmpl.URIs, override.URIs...), (*url.URL).String)
|
||||
tmpl.UnknownExtKeyUsage = slices.UniqueBy(append(tmpl.UnknownExtKeyUsage, override.UnknownExtKeyUsage...), asn1.ObjectIdentifier.String)
|
||||
seq := override.Subject.ToRDNSequence()
|
||||
tmpl.Subject.FillFromRDNSequence(&seq)
|
||||
tmpl.KeyUsage |= override.KeyUsage
|
||||
tmpl.ExtKeyUsage = slices.Unique(append(tmpl.ExtKeyUsage, override.ExtKeyUsage...))
|
||||
}
|
||||
|
||||
clientCertDER, err := x509.CreateCertificate(rand.Reader, tmpl, caCert.Leaf, priv.Public(), caCert.PrivateKey)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
cert, err := x509.ParseCertificate(clientCertDER)
|
||||
require.NoError(e.t, err)
|
||||
|
||||
clientCert := &tls.Certificate{
|
||||
Certificate: [][]byte{cert.Raw, caCert.Leaf.Raw},
|
||||
PrivateKey: priv,
|
||||
Leaf: cert,
|
||||
}
|
||||
|
||||
_, err = clientCert.Leaf.Verify(x509.VerifyOptions{
|
||||
KeyUsages: []x509.ExtKeyUsage{
|
||||
x509.ExtKeyUsageClientAuth,
|
||||
},
|
||||
Roots: e.ServerCAs(),
|
||||
})
|
||||
require.NoError(e.t, err, "bug: generated client cert is not valid")
|
||||
return (*Certificate)(clientCert)
|
||||
}
|
||||
|
||||
func (e *environment) Stop() {
|
||||
e.cleanupOnce.Do(func() {
|
||||
e.cancel(ErrCauseManualStop)
|
||||
|
|
|
@ -4,8 +4,11 @@ import (
|
|||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
@ -115,9 +118,14 @@ func (lr *LogRecorder) Logs() []map[string]any {
|
|||
// Match stops the log recorder (if it is not already stopped), then asserts
|
||||
// that the given expected logs were captured. The expected logs may contain
|
||||
// partial or complete log entries. By default, logs must only match the fields
|
||||
// given, and may contain additional fields that will be ignored. For details,
|
||||
// see [OpenMap] and [ClosedMap]. As a special case, using [json.Number] as the
|
||||
// expected value will convert the actual value to a string before comparison.
|
||||
// given, and may contain additional fields that will be ignored.
|
||||
//
|
||||
// There are several special-case value types that can be used to customize the
|
||||
// matching behavior, and/or simplify some common use cases, as follows:
|
||||
// - [OpenMap] and [ClosedMap] can be used to control matching logic
|
||||
// - [json.Number] will convert the actual value to a string before comparison
|
||||
// - [*tls.Certificate] or [*x509.Certificate] will expand to the fields that
|
||||
// would be logged for this certificate
|
||||
func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
||||
lr.collectLogs()
|
||||
var match func(expected, actual map[string]any, open bool) (bool, int)
|
||||
|
@ -132,15 +140,50 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
|||
|
||||
switch actualValue := actualValue.(type) {
|
||||
case map[string]any:
|
||||
switch value := value.(type) {
|
||||
switch expectedValue := value.(type) {
|
||||
case ClosedMap:
|
||||
ok, s := match(value, actualValue, false)
|
||||
ok, s := match(expectedValue, actualValue, false)
|
||||
score += s * 2
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
case OpenMap:
|
||||
ok, s := match(value, actualValue, true)
|
||||
ok, s := match(expectedValue, actualValue, true)
|
||||
score += s
|
||||
if !ok {
|
||||
return false, score
|
||||
}
|
||||
case *tls.Certificate, *Certificate, *x509.Certificate:
|
||||
var leaf *x509.Certificate
|
||||
switch expectedValue := expectedValue.(type) {
|
||||
case *tls.Certificate:
|
||||
leaf = expectedValue.Leaf
|
||||
case *Certificate:
|
||||
leaf = expectedValue.Leaf
|
||||
case *x509.Certificate:
|
||||
leaf = expectedValue
|
||||
}
|
||||
|
||||
// keep logic consistent with controlplane.populateCertEventDict()
|
||||
expected := map[string]any{}
|
||||
if iss := leaf.Issuer.String(); iss != "" {
|
||||
expected["issuer"] = iss
|
||||
}
|
||||
if sub := leaf.Subject.String(); sub != "" {
|
||||
expected["subject"] = sub
|
||||
}
|
||||
sans := []string{}
|
||||
for _, dnsSAN := range leaf.DNSNames {
|
||||
sans = append(sans, "DNS:"+dnsSAN)
|
||||
}
|
||||
for _, uriSAN := range leaf.URIs {
|
||||
sans = append(sans, "URI:"+uriSAN.String())
|
||||
}
|
||||
if len(sans) > 0 {
|
||||
expected["subjectAltName"] = sans
|
||||
}
|
||||
|
||||
ok, s := match(expected, actualValue, false)
|
||||
score += s
|
||||
if !ok {
|
||||
return false, score
|
||||
|
@ -164,7 +207,23 @@ func (lr *LogRecorder) Match(expectedLogs []map[string]any) {
|
|||
}
|
||||
score++
|
||||
default:
|
||||
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||
// handle slices
|
||||
if reflect.TypeOf(actualValue).Kind() == reflect.Slice {
|
||||
if reflect.TypeOf(value) != reflect.TypeOf(actualValue) {
|
||||
return false, score
|
||||
}
|
||||
actualSlice := reflect.ValueOf(actualValue)
|
||||
expectedSlice := reflect.ValueOf(value)
|
||||
totalScore := 0
|
||||
for i := range min(actualSlice.Len(), expectedSlice.Len()) {
|
||||
if actualSlice.Index(i).Equal(expectedSlice.Index(i)) {
|
||||
totalScore++
|
||||
}
|
||||
}
|
||||
score += totalScore
|
||||
} else {
|
||||
panic(fmt.Sprintf("test bug: add check for type %T in assertMatchingLogs", actualValue))
|
||||
}
|
||||
}
|
||||
}
|
||||
if !open && len(expected) != len(actual) {
|
||||
|
|
|
@ -2,9 +2,11 @@ package testenv
|
|||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv/values"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
// PolicyRoute is a [Route] implementation suitable for most common use cases
|
||||
|
@ -54,6 +56,20 @@ func (b *PolicyRoute) Policy(edit func(*config.Policy)) Route {
|
|||
return b
|
||||
}
|
||||
|
||||
// PPL implements Route.
|
||||
func (b *PolicyRoute) PPL(ppl string) Route {
|
||||
pplPolicy, err := parser.ParseYAML(strings.NewReader(ppl))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b.edits = append(b.edits, func(p *config.Policy) {
|
||||
p.Policy = &config.PPLPolicy{
|
||||
Policy: pplPolicy,
|
||||
}
|
||||
})
|
||||
return b
|
||||
}
|
||||
|
||||
// To implements Route.
|
||||
func (b *PolicyRoute) URL() values.Value[string] {
|
||||
return b.from
|
||||
|
|
24
internal/testenv/scenarios/mtls.go
Normal file
24
internal/testenv/scenarios/mtls.go
Normal file
|
@ -0,0 +1,24 @@
|
|||
package scenarios
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/internal/testenv"
|
||||
)
|
||||
|
||||
func DownstreamMTLS(mode config.MTLSEnforcement) testenv.Modifier {
|
||||
return testenv.ModifierFunc(func(ctx context.Context, cfg *config.Config) {
|
||||
env := testenv.EnvFromContext(ctx)
|
||||
block := pem.Block{
|
||||
Type: "CERTIFICATE",
|
||||
Bytes: env.CACert().Leaf.Raw,
|
||||
}
|
||||
cfg.Options.DownstreamMTLS = config.DownstreamMTLSSettings{
|
||||
CA: base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&block)),
|
||||
Enforcement: mode,
|
||||
}
|
||||
})
|
||||
}
|
|
@ -77,10 +77,24 @@ func (m Modifiers) Modify(cfg *config.Config) {
|
|||
}
|
||||
}
|
||||
|
||||
type ModifierFunc func(cfg *config.Config)
|
||||
type modifierFunc struct {
|
||||
fn func(ctx context.Context, cfg *config.Config)
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func (f ModifierFunc) Modify(cfg *config.Config) {
|
||||
f(cfg)
|
||||
// Attach implements Modifier.
|
||||
func (f *modifierFunc) Attach(ctx context.Context) {
|
||||
f.ctx = ctx
|
||||
}
|
||||
|
||||
func (f *modifierFunc) Modify(cfg *config.Config) {
|
||||
f.fn(f.ctx, cfg)
|
||||
}
|
||||
|
||||
var _ Modifier = (*modifierFunc)(nil)
|
||||
|
||||
func ModifierFunc(fn func(ctx context.Context, cfg *config.Config)) Modifier {
|
||||
return &modifierFunc{fn: fn}
|
||||
}
|
||||
|
||||
// Task represents a background task that can be added to an [Environment] to
|
||||
|
@ -116,6 +130,7 @@ type Route interface {
|
|||
URL() values.Value[string]
|
||||
To(toUrl values.Value[string]) Route
|
||||
Policy(edit func(*config.Policy)) Route
|
||||
PPL(ppl string) Route
|
||||
// add more methods here as they become needed
|
||||
}
|
||||
|
||||
|
|
|
@ -77,9 +77,11 @@ func Body(body any) RequestOption {
|
|||
}
|
||||
|
||||
// ClientCert adds a client certificate to the request.
|
||||
func ClientCert(cert tls.Certificate) RequestOption {
|
||||
func ClientCert[T interface {
|
||||
*testenv.Certificate | *tls.Certificate
|
||||
}](cert T) RequestOption {
|
||||
return func(o *RequestOptions) {
|
||||
o.clientCerts = append(o.clientCerts, cert)
|
||||
o.clientCerts = append(o.clientCerts, *(*tls.Certificate)(cert))
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -250,6 +252,7 @@ func (h *httpUpstream) Do(method string, r testenv.Route, opts ...RequestOption)
|
|||
}
|
||||
return retry.NewTerminalError(err)
|
||||
}
|
||||
resp.Body.Close()
|
||||
return nil
|
||||
}, retry.WithMaxInterval(1*time.Second)); err != nil {
|
||||
return nil, err
|
||||
|
|
|
@ -16,6 +16,7 @@ import (
|
|||
"github.com/pomerium/pomerium/authenticate"
|
||||
"github.com/pomerium/pomerium/authorize"
|
||||
"github.com/pomerium/pomerium/config"
|
||||
"github.com/pomerium/pomerium/config/envoyconfig/filemgr"
|
||||
databroker_service "github.com/pomerium/pomerium/databroker"
|
||||
"github.com/pomerium/pomerium/internal/autocert"
|
||||
"github.com/pomerium/pomerium/internal/controlplane"
|
||||
|
@ -30,8 +31,29 @@ import (
|
|||
"github.com/pomerium/pomerium/proxy"
|
||||
)
|
||||
|
||||
type RunOptions struct {
|
||||
fileMgr *filemgr.Manager
|
||||
}
|
||||
|
||||
type RunOption func(*RunOptions)
|
||||
|
||||
func (o *RunOptions) apply(opts ...RunOption) {
|
||||
for _, op := range opts {
|
||||
op(o)
|
||||
}
|
||||
}
|
||||
|
||||
func WithOverrideFileManager(fileMgr *filemgr.Manager) RunOption {
|
||||
return func(o *RunOptions) {
|
||||
o.fileMgr = fileMgr
|
||||
}
|
||||
}
|
||||
|
||||
// Run runs the main pomerium application.
|
||||
func Run(ctx context.Context, src config.Source) error {
|
||||
func Run(ctx context.Context, src config.Source, opts ...RunOption) error {
|
||||
options := RunOptions{}
|
||||
options.apply(opts...)
|
||||
|
||||
_, _ = maxprocs.Set(maxprocs.Logger(func(s string, i ...any) { log.Debug(context.Background()).Msgf(s, i...) }))
|
||||
|
||||
evt := log.Info(ctx).
|
||||
|
@ -68,10 +90,15 @@ func Run(ctx context.Context, src config.Source) error {
|
|||
|
||||
eventsMgr := events.New()
|
||||
|
||||
fileMgr := options.fileMgr
|
||||
if fileMgr == nil {
|
||||
fileMgr = filemgr.NewManager()
|
||||
}
|
||||
|
||||
cfg := src.GetConfig()
|
||||
|
||||
// setup the control plane
|
||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr)
|
||||
controlPlane, err := controlplane.NewServer(ctx, cfg, metricsMgr, eventsMgr, fileMgr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating control plane: %w", err)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue