Core-Zero Import (#5288)

* initial core-zero import implementation

* Update /config/import openapi description and use PUT instead of POST

* update import ui tests

* Add 413 as a possible response for /config/import

* Options/Settings type conversion tests and related bugfixes

* Fixes for proto type conversion and tests

* Update core-zero import client

* Update core-zero import client

* Update import api and environment detection

* update go.mod

* remove old testdata

* Remove usage of deleted setting after merge

* remove extra newline from --version output
This commit is contained in:
Joe Kralicky 2024-10-09 18:51:56 -04:00 committed by GitHub
parent 5b4fe8969d
commit 0e13248685
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
22 changed files with 3193 additions and 700 deletions

View file

@ -3,11 +3,12 @@ package main
import (
"context"
"errors"
"flag"
"fmt"
"os"
"strings"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
@ -19,43 +20,65 @@ import (
"github.com/pomerium/pomerium/pkg/envoy/files"
)
var (
versionFlag = flag.Bool("version", false, "prints the version")
configFile = flag.String("config", "", "Specify configuration file location")
)
func main() {
flag.Parse()
if *versionFlag {
fmt.Println("pomerium:", version.FullVersion())
fmt.Println("envoy:", files.FullVersion())
return
convertOldStyleFlags()
var configFile string
root := &cobra.Command{
Use: "pomerium",
Version: fmt.Sprintf("pomerium: %s\nenvoy: %s", version.FullVersion(), files.FullVersion()),
SilenceUsage: true,
}
root.AddCommand(zero_cmd.BuildRootCmd())
root.PersistentFlags().StringVar(&configFile, "config", "", "Specify configuration file location")
ctx := context.Background()
log.SetLevel(zerolog.InfoLevel)
runFn := run
if zero_cmd.IsManagedMode(*configFile) {
runFn = func(ctx context.Context) error { return zero_cmd.Run(ctx, *configFile) }
if zero_cmd.IsManagedMode(configFile) {
runFn = zero_cmd.Run
}
root.RunE = func(_ *cobra.Command, _ []string) error {
defer log.Info(ctx).Msg("cmd/pomerium: exiting")
return runFn(ctx, configFile)
}
if err := runFn(ctx); err != nil && !errors.Is(err, context.Canceled) {
if err := root.ExecuteContext(ctx); err != nil {
log.Fatal().Err(err).Msg("cmd/pomerium")
}
log.Info(ctx).Msg("cmd/pomerium: exiting")
}
func run(ctx context.Context) error {
func run(ctx context.Context, configFile string) error {
ctx = log.WithContext(ctx, func(c zerolog.Context) zerolog.Context {
return c.Str("config_file_source", *configFile).Bool("bootstrap", true)
return c.Str("config_file_source", configFile).Bool("bootstrap", true)
})
var src config.Source
src, err := config.NewFileOrEnvironmentSource(*configFile, files.FullVersion())
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
if err != nil {
return err
}
return pomerium.Run(ctx, src)
}
// Converts the "-config" and "-version" single-dash style flags to the
// equivalent "--config" and "--version" flags compatible with cobra. These
// are the only two flags that existed previously, so we don't need to check
// for any others.
func convertOldStyleFlags() {
for i, arg := range os.Args {
var found bool
if arg == "-config" || strings.HasPrefix(arg, "-config=") {
found = true
fmt.Fprintln(os.Stderr, "Warning: syntax '-config' is deprecated, use '--config' instead")
} else if arg == "-version" {
found = true
// don't log a warning here, since it could interfere with tools that
// parse the -version output
}
if found {
os.Args[i] = "-" + arg
}
}
}

View file

@ -4,6 +4,7 @@ import (
"encoding/base64"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/crypt"
)
// A PublicKeyEncryptionKeyOptions represents options for a public key encryption key.
@ -24,3 +25,17 @@ func (o *Options) GetAuditKey() (*cryptutil.PublicKeyEncryptionKey, error) {
}
return cryptutil.NewPublicKeyEncryptionKeyWithID(o.AuditKey.ID, raw)
}
func (o *PublicKeyEncryptionKeyOptions) ToProto() *crypt.PublicKeyEncryptionKey {
if o == nil {
return nil
}
decoded, err := base64.StdEncoding.DecodeString(o.Data)
if err != nil {
return nil
}
return &crypt.PublicKeyEncryptionKey{
Id: o.ID,
Data: decoded,
}
}

View file

@ -190,6 +190,87 @@ func (s *DownstreamMTLSSettings) applySettingsProto(
set(&s.CA, p.Ca)
set(&s.CRL, p.Crl)
s.Enforcement = mtlsEnforcementFromProtoEnum(ctx, p.Enforcement)
s.MatchSubjectAltNames = make([]SANMatcher, 0, len(p.MatchSubjectAltNames))
for _, san := range p.MatchSubjectAltNames {
var sanType SANType
switch san.GetSanType() {
case config.SANMatcher_DNS:
sanType = SANTypeDNS
case config.SANMatcher_EMAIL:
sanType = SANTypeEmail
case config.SANMatcher_IP_ADDRESS:
sanType = SANTypeIPAddress
case config.SANMatcher_URI:
sanType = SANTypeURI
case config.SANMatcher_USER_PRINCIPAL_NAME:
sanType = SANTypeUserPrincipalName
}
s.MatchSubjectAltNames = append(s.MatchSubjectAltNames, SANMatcher{
Type: sanType,
Pattern: san.GetPattern(),
})
}
s.MaxVerifyDepth = p.MaxVerifyDepth
}
func (s *DownstreamMTLSSettings) ToProto() *config.DownstreamMtlsSettings {
if s == nil {
return nil
}
var settings config.DownstreamMtlsSettings
var hasAnyFields bool
if ca, err := s.GetCA(); err == nil && len(ca) > 0 {
hasAnyFields = true
caStr := base64.StdEncoding.EncodeToString(ca)
settings.Ca = &caStr
}
if crl, err := s.GetCRL(); err == nil && len(crl) > 0 {
hasAnyFields = true
crlStr := base64.StdEncoding.EncodeToString(crl)
settings.Crl = &crlStr
}
if s.Enforcement != "" {
hasAnyFields = true
switch s.Enforcement {
case MTLSEnforcementPolicy:
settings.Enforcement = config.MtlsEnforcementMode_POLICY.Enum()
case MTLSEnforcementPolicyWithDefaultDeny:
settings.Enforcement = config.MtlsEnforcementMode_POLICY_WITH_DEFAULT_DENY.Enum()
case MTLSEnforcementRejectConnection:
settings.Enforcement = config.MtlsEnforcementMode_REJECT_CONNECTION.Enum()
default:
settings.Enforcement = config.MtlsEnforcementMode_UNKNOWN.Enum()
}
}
for _, san := range s.MatchSubjectAltNames {
hasAnyFields = true
var sanType config.SANMatcher_SANType
switch san.Type {
case SANTypeDNS:
sanType = config.SANMatcher_DNS
case SANTypeEmail:
sanType = config.SANMatcher_EMAIL
case SANTypeIPAddress:
sanType = config.SANMatcher_IP_ADDRESS
case SANTypeURI:
sanType = config.SANMatcher_URI
case SANTypeUserPrincipalName:
sanType = config.SANMatcher_USER_PRINCIPAL_NAME
default:
sanType = config.SANMatcher_SAN_TYPE_UNSPECIFIED
}
settings.MatchSubjectAltNames = append(settings.MatchSubjectAltNames, &config.SANMatcher{
SanType: sanType,
Pattern: san.Pattern,
})
}
settings.MaxVerifyDepth = s.MaxVerifyDepth
hasAnyFields = hasAnyFields || s.MaxVerifyDepth != nil
if !hasAnyFields {
return nil
}
return &settings
}
func mtlsEnforcementFromProtoEnum(

View file

@ -40,6 +40,7 @@ import (
"github.com/pomerium/pomerium/pkg/hpke"
"github.com/pomerium/pomerium/pkg/identity/oauth"
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
// DisableHeaderKey is the key used to check whether to disable setting header
@ -1544,6 +1545,11 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
set(&o.AutocertOptions.TrustedCA, settings.AutocertTrustedCa)
set(&o.SkipXffAppend, settings.SkipXffAppend)
set(&o.XffNumTrustedHops, settings.XffNumTrustedHops)
set(&o.EnvoyAdminAccessLogPath, settings.EnvoyAdminAccessLogPath)
set(&o.EnvoyAdminProfilePath, settings.EnvoyAdminProfilePath)
set(&o.EnvoyAdminAddress, settings.EnvoyAdminAddress)
set(&o.EnvoyBindConfigSourceAddress, settings.EnvoyBindConfigSourceAddress)
o.EnvoyBindConfigFreebind = null.BoolFromPtr(settings.EnvoyBindConfigFreebind)
setSlice(&o.ProgrammaticRedirectDomainWhitelist, settings.ProgrammaticRedirectDomainWhitelist)
setAuditKey(&o.AuditKey, settings.AuditKey)
setCodecType(&o.CodecType, settings.CodecType)
@ -1554,6 +1560,250 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
})
}
func (o *Options) ToProto() *config.Config {
var settings config.Settings
copySrcToOptionalDest(&settings.InstallationId, &o.InstallationID)
copySrcToOptionalDest(&settings.LogLevel, (*string)(&o.LogLevel))
settings.AccessLogFields = toStringList(o.AccessLogFields)
settings.AuthorizeLogFields = toStringList(o.AuthorizeLogFields)
copySrcToOptionalDest(&settings.ProxyLogLevel, (*string)(&o.ProxyLogLevel))
copySrcToOptionalDest(&settings.SharedSecret, valueOrFromFileBase64(o.SharedKey, o.SharedSecretFile))
copySrcToOptionalDest(&settings.Services, &o.Services)
copySrcToOptionalDest(&settings.Address, &o.Addr)
copySrcToOptionalDest(&settings.InsecureServer, &o.InsecureServer)
copySrcToOptionalDest(&settings.DnsLookupFamily, &o.DNSLookupFamily)
settings.Certificates = getCertificates(o)
copySrcToOptionalDest(&settings.HttpRedirectAddr, &o.HTTPRedirectAddr)
copyOptionalDuration(&settings.TimeoutRead, o.ReadTimeout)
copyOptionalDuration(&settings.TimeoutWrite, o.WriteTimeout)
copyOptionalDuration(&settings.TimeoutIdle, o.IdleTimeout)
copySrcToOptionalDest(&settings.AuthenticateServiceUrl, &o.AuthenticateURLString)
copySrcToOptionalDest(&settings.AuthenticateInternalServiceUrl, &o.AuthenticateInternalURLString)
copySrcToOptionalDest(&settings.SignoutRedirectUrl, &o.SignOutRedirectURLString)
copySrcToOptionalDest(&settings.AuthenticateCallbackPath, &o.AuthenticateCallbackPath)
copySrcToOptionalDest(&settings.CookieName, &o.CookieName)
copySrcToOptionalDest(&settings.CookieSecret, valueOrFromFileBase64(o.CookieSecret, o.CookieSecretFile))
copySrcToOptionalDest(&settings.CookieDomain, &o.CookieDomain)
copySrcToOptionalDest(&settings.CookieHttpOnly, &o.CookieHTTPOnly)
copyOptionalDuration(&settings.CookieExpire, o.CookieExpire)
copySrcToOptionalDest(&settings.CookieSameSite, &o.CookieSameSite)
copySrcToOptionalDest(&settings.IdpClientId, &o.ClientID)
copySrcToOptionalDest(&settings.IdpClientSecret, valueOrFromFileBase64(o.ClientSecret, o.ClientSecretFile))
copySrcToOptionalDest(&settings.IdpProvider, &o.Provider)
copySrcToOptionalDest(&settings.IdpProviderUrl, &o.ProviderURL)
settings.Scopes = o.Scopes
settings.RequestParams = o.RequestParams
settings.AuthorizeServiceUrls = o.AuthorizeURLStrings
copySrcToOptionalDest(&settings.AuthorizeInternalServiceUrl, &o.AuthorizeInternalURLString)
copySrcToOptionalDest(&settings.OverrideCertificateName, &o.OverrideCertificateName)
copySrcToOptionalDest(&settings.CertificateAuthority, valueOrFromFileBase64(o.CA, o.CAFile))
settings.DeriveTls = o.DeriveInternalDomainCert
copySrcToOptionalDest(&settings.SigningKey, valueOrFromFileBase64(o.SigningKey, o.SigningKeyFile))
settings.SetResponseHeaders = o.SetResponseHeaders
settings.JwtClaimsHeaders = o.JWTClaimsHeaders
copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout)
copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr)
copySrcToOptionalDest(&settings.MetricsBasicAuth, &o.MetricsBasicAuth)
settings.MetricsCertificate = toCertificateOrFromFile(o.MetricsCertificate, o.MetricsCertificateKey, o.MetricsCertificateFile, o.MetricsCertificateKeyFile)
copySrcToOptionalDest(&settings.MetricsClientCa, valueOrFromFileBase64(o.MetricsClientCA, o.MetricsClientCAFile))
copySrcToOptionalDest(&settings.TracingProvider, &o.TracingProvider)
copySrcToOptionalDest(&settings.TracingSampleRate, &o.TracingSampleRate)
copySrcToOptionalDest(&settings.TracingDatadogAddress, &o.TracingDatadogAddress)
copySrcToOptionalDest(&settings.TracingJaegerCollectorEndpoint, &o.TracingJaegerCollectorEndpoint)
copySrcToOptionalDest(&settings.TracingJaegerAgentEndpoint, &o.TracingJaegerAgentEndpoint)
copySrcToOptionalDest(&settings.TracingZipkinEndpoint, &o.ZipkinEndpoint)
copySrcToOptionalDest(&settings.GrpcAddress, &o.GRPCAddr)
settings.GrpcInsecure = o.GRPCInsecure
copyOptionalDuration(&settings.GrpcClientTimeout, o.GRPCClientTimeout)
settings.DatabrokerServiceUrls = o.DataBrokerURLStrings
copySrcToOptionalDest(&settings.DatabrokerInternalServiceUrl, &o.DataBrokerInternalURLString)
copySrcToOptionalDest(&settings.DatabrokerStorageType, &o.DataBrokerStorageType)
copySrcToOptionalDest(&settings.DatabrokerStorageConnectionString, valueOrFromFileRaw(o.DataBrokerStorageConnectionString, o.DataBrokerStorageConnectionStringFile))
settings.DownstreamMtls = o.DownstreamMTLS.ToProto()
copySrcToOptionalDest(&settings.GoogleCloudServerlessAuthenticationServiceAccount, &o.GoogleCloudServerlessAuthenticationServiceAccount)
copySrcToOptionalDest(&settings.UseProxyProtocol, &o.UseProxyProtocol)
copySrcToOptionalDest(&settings.Autocert, &o.AutocertOptions.Enable)
copySrcToOptionalDest(&settings.AutocertCa, &o.AutocertOptions.CA)
copySrcToOptionalDest(&settings.AutocertEmail, &o.AutocertOptions.Email)
copySrcToOptionalDest(&settings.AutocertEabKeyId, &o.AutocertOptions.EABKeyID)
copySrcToOptionalDest(&settings.AutocertEabMacKey, &o.AutocertOptions.EABMACKey)
copySrcToOptionalDest(&settings.AutocertDir, &o.AutocertOptions.Folder)
copySrcToOptionalDest(&settings.AutocertTrustedCa, &o.AutocertOptions.TrustedCA)
copySrcToOptionalDest(&settings.AutocertUseStaging, &o.AutocertOptions.UseStaging)
copySrcToOptionalDest(&settings.AutocertMustStaple, &o.AutocertOptions.MustStaple)
copySrcToOptionalDest(&settings.SkipXffAppend, &o.SkipXffAppend)
copySrcToOptionalDest(&settings.XffNumTrustedHops, &o.XffNumTrustedHops)
copySrcToOptionalDest(&settings.EnvoyAdminAccessLogPath, &o.EnvoyAdminAccessLogPath)
copySrcToOptionalDest(&settings.EnvoyAdminProfilePath, &o.EnvoyAdminProfilePath)
copySrcToOptionalDest(&settings.EnvoyAdminAddress, &o.EnvoyAdminAddress)
copySrcToOptionalDest(&settings.EnvoyBindConfigSourceAddress, &o.EnvoyBindConfigSourceAddress)
settings.EnvoyBindConfigFreebind = o.EnvoyBindConfigFreebind.Ptr()
settings.ProgrammaticRedirectDomainWhitelist = o.ProgrammaticRedirectDomainWhitelist
settings.AuditKey = o.AuditKey.ToProto()
if o.CodecType != "" {
codecType := o.CodecType.ToEnvoy()
settings.CodecType = &codecType
}
settings.PassIdentityHeaders = o.PassIdentityHeaders
if o.BrandingOptions != nil {
primaryColor := o.BrandingOptions.GetPrimaryColor()
secondaryColor := o.BrandingOptions.GetSecondaryColor()
darkmodePrimaryColor := o.BrandingOptions.GetDarkmodePrimaryColor()
darkmodeSecondaryColor := o.BrandingOptions.GetDarkmodeSecondaryColor()
logoURL := o.BrandingOptions.GetLogoUrl()
faviconURL := o.BrandingOptions.GetFaviconUrl()
errorMessageFirstParagraph := o.BrandingOptions.GetErrorMessageFirstParagraph()
copySrcToOptionalDest(&settings.PrimaryColor, &primaryColor)
copySrcToOptionalDest(&settings.SecondaryColor, &secondaryColor)
copySrcToOptionalDest(&settings.DarkmodePrimaryColor, &darkmodePrimaryColor)
copySrcToOptionalDest(&settings.DarkmodeSecondaryColor, &darkmodeSecondaryColor)
copySrcToOptionalDest(&settings.LogoUrl, &logoURL)
copySrcToOptionalDest(&settings.FaviconUrl, &faviconURL)
copySrcToOptionalDest(&settings.ErrorMessageFirstParagraph, &errorMessageFirstParagraph)
}
copyMap(&settings.RuntimeFlags, o.RuntimeFlags, func(k RuntimeFlag, v bool) (string, bool) {
return string(k), v
})
routes := make([]*config.Route, 0, o.NumPolicies())
for p := range o.GetAllPolicies() {
routepb, err := p.ToProto()
if err != nil {
continue
}
ppl := p.ToPPL()
pplIsEmpty := true
for _, rule := range ppl.Rules {
if rule.Action == parser.ActionAllow &&
len(rule.And) > 0 ||
len(rule.Nor) > 0 ||
len(rule.Not) > 0 ||
len(rule.Or) > 0 {
pplIsEmpty = false
break
}
}
if !pplIsEmpty {
raw, err := ppl.MarshalJSON()
if err != nil {
continue
}
routepb.PplPolicies = append(routepb.PplPolicies, &config.PPLPolicy{
Raw: raw,
})
}
routes = append(routes, routepb)
}
return &config.Config{
Settings: &settings,
Routes: routes,
}
}
func copySrcToOptionalDest[T comparable](dst **T, src *T) {
var zero T
if *src == zero {
*dst = nil
} else {
if *dst == nil {
*dst = src
} else {
**dst = *src
}
}
}
func toStringList[T ~string](s []T) *config.Settings_StringList {
if len(s) == 0 {
return nil
}
strings := make([]string, len(s))
for i, v := range s {
strings[i] = string(v)
}
return &config.Settings_StringList{Values: strings}
}
func toCertificateOrFromFile(
cert string, key string,
certFile string, keyFile string,
) *config.Settings_Certificate {
var out config.Settings_Certificate
if cert != "" {
out.CertBytes, _ = base64.StdEncoding.DecodeString(cert)
} else if certFile != "" {
b, err := os.ReadFile(certFile)
if err == nil {
out.CertBytes = b
}
}
if key != "" {
out.KeyBytes, _ = base64.StdEncoding.DecodeString(key)
} else if keyFile != "" {
b, err := os.ReadFile(keyFile)
if err == nil {
out.KeyBytes = b
}
}
if out.CertBytes == nil && out.KeyBytes == nil {
return nil
}
return &out
}
func getCertificates(o *Options) []*config.Settings_Certificate {
certs, err := o.GetCertificates()
if err != nil {
return nil
}
out := make([]*config.Settings_Certificate, len(certs))
for i, crt := range certs {
certBytes, keyBytes, err := cryptutil.EncodeCertificate(&crt)
if err != nil {
return nil
}
out[i] = &config.Settings_Certificate{
CertBytes: certBytes,
KeyBytes: keyBytes,
}
}
return out
}
func copyOptionalDuration(dst **durationpb.Duration, src time.Duration) {
if src == 0 {
*dst = nil
} else {
*dst = durationpb.New(src)
}
}
func valueOrFromFileRaw(value string, valueFile string) *string {
if value != "" {
return &value
}
if valueFile == "" {
return &valueFile
}
data, _ := os.ReadFile(valueFile)
dataStr := string(data)
return &dataStr
}
func valueOrFromFileBase64(value string, valueFile string) *string {
if value != "" {
return &value
}
if valueFile == "" {
return &valueFile
}
data, _ := os.ReadFile(valueFile)
encoded := base64.StdEncoding.EncodeToString(data)
return &encoded
}
func dataDir() string {
homeDir, _ := os.UserHomeDir()
if homeDir == "" {

View file

@ -2,32 +2,45 @@ package config
import (
"context"
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"hash/fnv"
"math/big"
mathrand "math/rand/v2"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
"github.com/pomerium/protoutil/protorand"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
)
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}, Policy{})
@ -932,8 +945,8 @@ func TestOptions_ApplySettings(t *testing.T) {
xc1, _ := x509.ParseCertificate(cert1.Certificate[0])
certsIndex.Add(xc1)
settings := &config.Settings{
Certificates: []*config.Settings_Certificate{
settings := &configpb.Settings{
Certificates: []*configpb.Settings_Certificate{
{CertBytes: encodeCert(cert2)},
{CertBytes: encodeCert(cert3)},
},
@ -944,13 +957,24 @@ func TestOptions_ApplySettings(t *testing.T) {
t.Run("pass_identity_headers", func(t *testing.T) {
options := NewDefaultOptions()
options.ApplySettings(ctx, nil, &config.Settings{
options.ApplySettings(ctx, nil, &configpb.Settings{
PassIdentityHeaders: proto.Bool(true),
})
assert.Equal(t, proto.Bool(true), options.PassIdentityHeaders)
})
}
func TestXXX(t *testing.T) {
dir, _ := os.MkdirTemp("", "asdf")
t.Log(dir)
for i := 1; i <= 100; i++ {
crt, _ := cryptutil.GenerateCertificate(nil, fmt.Sprintf("route%d.localhost.pomerium.io", i))
crtBytes, keyBytes, _ := cryptutil.EncodeCertificate(crt)
os.WriteFile(fmt.Sprintf("%s/%d.crt", dir, i), crtBytes, 0o644)
os.WriteFile(fmt.Sprintf("%s/%d.key", dir, i), keyBytes, 0o600)
}
}
func TestOptions_GetSetResponseHeaders(t *testing.T) {
t.Run("lax", func(t *testing.T) {
options := NewDefaultOptions()
@ -1364,8 +1388,341 @@ func encodeCert(cert *tls.Certificate) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]})
}
func mustParseWeightedURLs(t *testing.T, urls ...string) []WeightedURL {
wu, err := ParseWeightedUrls(urls...)
require.NoError(t, err)
return wu
func TestRoute_FromToProto(t *testing.T) {
routeGen := protorand.New[*configpb.Route]()
routeGen.MaxCollectionElements = 2
routeGen.UseGoDurationLimits = true
routeGen.ExcludeMask(&fieldmaskpb.FieldMask{
Paths: []string{
"from", "to", "load_balancing_weights", "redirect", "response", // set below
"ppl_policies", "name", // no equivalent field
"envoy_opts",
},
})
redirectGen := protorand.New[*configpb.RouteRedirect]()
responseGen := protorand.New[*configpb.RouteDirectResponse]()
randomDomain := func() string {
numSegments := mathrand.IntN(5) + 1
segments := make([]string, numSegments)
for i := range segments {
b := make([]rune, mathrand.IntN(10)+10)
for j := range b {
b[j] = rune(mathrand.IntN(26) + 'a')
}
segments[i] = string(b)
}
return strings.Join(segments, ".")
}
newCompleteRoute := func() *configpb.Route {
pb, err := routeGen.Gen()
require.NoError(t, err)
pb.From = "https://" + randomDomain()
// EnvoyOpts is set to an empty non-nil message during conversion, if nil
pb.EnvoyOpts = &envoy_config_cluster_v3.Cluster{}
switch mathrand.IntN(3) {
case 0:
pb.To = make([]string, mathrand.IntN(3)+1)
for i := range pb.To {
pb.To[i] = "https://" + randomDomain()
}
pb.LoadBalancingWeights = make([]uint32, len(pb.To))
for i := range pb.LoadBalancingWeights {
pb.LoadBalancingWeights[i] = mathrand.Uint32N(10000) + 1
}
case 1:
pb.Redirect, err = redirectGen.Gen()
require.NoError(t, err)
case 2:
pb.Response, err = responseGen.Gen()
require.NoError(t, err)
}
return pb
}
t.Run("Round Trip", func(t *testing.T) {
for range 100 {
route := newCompleteRoute()
policy, err := NewPolicyFromProto(route)
require.NoError(t, err)
route2, err := policy.ToProto()
require.NoError(t, err)
route2.Name = ""
testutil.AssertProtoEqual(t, route, route2)
}
})
t.Run("Multiple routes", func(t *testing.T) {
for range 100 {
route1 := newCompleteRoute()
route2 := newCompleteRoute()
{
// create a new policy every time, since reusing the target will mutate
// the underlying route
policy1, err := NewPolicyFromProto(route1)
require.NoError(t, err)
target, err := policy1.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route1, target)
}
{
policy2, err := NewPolicyFromProto(route2)
require.NoError(t, err)
target, err := policy2.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route2, target)
}
{
policy1, err := NewPolicyFromProto(route1)
require.NoError(t, err)
target, err := policy1.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route1, target)
}
{
policy2, err := NewPolicyFromProto(route2)
require.NoError(t, err)
target, err := policy2.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route2, target)
}
}
})
}
func TestOptions_FromToProto(t *testing.T) {
generate := func(ratio float64) *configpb.Settings {
t.Helper()
gen := protorand.New[*configpb.Settings]()
gen.MaxCollectionElements = 2
gen.MaxDepth = 3
gen.UseGoDurationLimits = true
gen.ExcludeMask(&fieldmaskpb.FieldMask{
Paths: []string{
"tls_custom_ca_file",
"tls_client_cert_file",
"tls_client_key_file",
"tls_downstream_client_ca_file",
},
})
settings, err := gen.GenPartial(ratio)
require.NoError(t, err)
unsetFalseOptionalBoolFields(settings)
fixZeroValuedEnums(settings)
generateCertificates(t, settings)
return settings
}
t.Run("all fields", func(t *testing.T) {
t.Parallel()
for range 100 {
settings := generate(1)
var options Options
options.ApplySettings(context.Background(), nil, settings)
settings2 := options.ToProto()
testutil.AssertProtoEqual(t, settings, settings2.Settings)
}
})
t.Run("some fields", func(t *testing.T) {
t.Parallel()
for range 100 {
settings := generate(mathrand.Float64())
var options Options
options.ApplySettings(context.Background(), nil, settings)
settings2 := options.ToProto()
testutil.AssertProtoEqual(t, settings, settings2.Settings)
}
})
}
// unset any optional bool fields with a value of false, to match
func unsetFalseOptionalBoolFields(msg proto.Message) {
msg.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
if fd.Cardinality() == protoreflect.Optional && fd.Kind() == protoreflect.BoolKind {
if v.IsValid() && !v.Bool() {
msg.ProtoReflect().Clear(fd)
}
}
return true
})
}
func fixZeroValuedEnums(msg *configpb.Settings) {
if msg.DownstreamMtls != nil && msg.DownstreamMtls.Enforcement != nil {
// there is no "unknown" equivalent, so if the value is randomly set to
// unknown it would be a lossy conversion
if *msg.DownstreamMtls.Enforcement == configpb.MtlsEnforcementMode_UNKNOWN {
msg.DownstreamMtls.Enforcement = nil
// if this was the only present field in the message, don't leave it empty
if proto.Size(msg.DownstreamMtls) == 0 {
msg.DownstreamMtls = nil
}
}
}
}
func generateCertificates(t testing.TB, msg *configpb.Settings) {
if msg.AutocertCa != nil {
*msg.AutocertCa, _ = generateRandomCA(t, *msg.AutocertCa)
}
if msg.DownstreamMtls != nil {
var caKey string
if msg.DownstreamMtls.Ca != nil {
*msg.DownstreamMtls.Ca, caKey = generateRandomCA(t, *msg.DownstreamMtls.Ca)
}
if msg.DownstreamMtls.Crl != nil {
if caKey != "" {
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, *msg.DownstreamMtls.Ca, caKey)
} else {
randCa, randKey := generateRandomCA(t, *msg.DownstreamMtls.Crl+"_temp_ca")
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, randCa, randKey)
}
}
}
genCertInPlace := func(cert *configpb.Settings_Certificate, b64 bool) {
cert.Id = "" // no equivalent field
switch {
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) > 0:
crt, key := generateRandomCert(t, string(cert.CertBytes)+string(cert.KeyBytes), b64)
cert.CertBytes = []byte(crt)
cert.KeyBytes = []byte(key)
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) == 0:
crt, _ := generateRandomCert(t, string(cert.CertBytes), b64)
cert.CertBytes = []byte(crt)
case len(cert.CertBytes) == 0 && len(cert.KeyBytes) > 0:
// invalid, but convert anyway
crt, _ := generateRandomCert(t, string(cert.KeyBytes), b64)
cert.KeyBytes = []byte(crt)
}
}
for i, cert := range msg.Certificates {
genCertInPlace(cert, false)
if cert.CertBytes == nil && cert.KeyBytes == nil {
msg.Certificates = slices.Delete(msg.Certificates, i, i+1)
}
}
if msg.MetricsCertificate != nil {
genCertInPlace(msg.MetricsCertificate, false)
if msg.MetricsCertificate.CertBytes == nil && msg.MetricsCertificate.KeyBytes == nil {
msg.MetricsCertificate = nil
}
}
}
func generateRandomCA(t testing.TB, randomInput string) (string, string) {
seed := sha256.Sum256([]byte(randomInput))
priv := ed25519.NewKeyFromSeed(seed[:])
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
now := time.Now()
tmpl := &x509.Certificate{
IsCA: true,
SerialNumber: &sn,
Subject: pkix.Name{CommonName: randomInput},
Issuer: pkix.Name{CommonName: randomInput},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: der,
})), base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
}))
}
func generateCRL(t testing.TB, randomInput string, issuerCrt, issuerKey string) string {
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
issuer, err := cryptutil.CertificateFromBase64(issuerCrt, issuerKey)
require.NoError(t, err)
b, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{
Number: big.NewInt(0x2000),
RevokedCertificates: []pkix.RevokedCertificate{
{
SerialNumber: &sn,
RevocationTime: time.Now(),
},
},
}, issuer.Leaf, issuer.PrivateKey.(crypto.Signer))
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(b)
}
func generateRandomCert(t testing.TB, randomInput string, b64 bool) (string, string) {
seed := sha256.Sum256([]byte(randomInput))
priv := ed25519.NewKeyFromSeed(seed[:])
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: &sn,
Subject: pkix.Name{
CommonName: randomInput,
},
Issuer: pkix.Name{
CommonName: randomInput,
},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
BasicConstraintsValid: true,
}
certDer, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
require.NoError(t, err)
crtPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDer,
})
keyPem := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
})
if b64 {
return base64.StdEncoding.EncodeToString(crtPem), base64.StdEncoding.EncodeToString(keyPem)
}
return string(crtPem), string(keyPem)
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}

View file

@ -256,6 +256,7 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
TLSServerName: pb.GetTlsServerName(),
TLSDownstreamServerName: pb.GetTlsDownstreamServerName(),
TLSUpstreamServerName: pb.GetTlsUpstreamServerName(),
TLSUpstreamAllowRenegotiation: pb.GetTlsUpstreamAllowRenegotiation(),
TLSCustomCA: pb.GetTlsCustomCa(),
TLSCustomCAFile: pb.GetTlsCustomCaFile(),
TLSClientCert: pb.GetTlsClientCert(),
@ -296,12 +297,20 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
Body: pb.Response.GetBody(),
}
} else {
to, err := ParseWeightedUrls(pb.GetTo()...)
if err != nil {
return nil, err
p.To = make(WeightedURLs, len(pb.To))
for i, u := range pb.To {
u, err := urlutil.ParseAndValidateURL(u)
if err != nil {
return nil, err
}
w := WeightedURL{
URL: *u,
}
if len(pb.LoadBalancingWeights) == len(pb.To) {
w.LbWeight = pb.LoadBalancingWeights[i]
}
p.To[i] = w
}
p.To = to
}
p.EnvoyOpts = pb.EnvoyOpts
@ -333,7 +342,7 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
Remediation: sp.GetRemediation(),
})
}
return p, p.Validate()
return p, nil
}
// ToProto converts the policy to a protobuf type.
@ -356,12 +365,15 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
AllowedUsers: sp.AllowedUsers,
AllowedDomains: sp.AllowedDomains,
AllowedIdpClaims: sp.AllowedIDPClaims.ToPB(),
Explanation: sp.Explanation,
Remediation: sp.Remediation,
Rego: sp.Rego,
})
}
pb := &configpb.Route{
Name: fmt.Sprint(p.RouteID()),
Id: p.ID,
From: p.From,
AllowedUsers: p.AllowedUsers,
AllowedDomains: p.AllowedDomains,
@ -372,6 +384,7 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
PrefixRewrite: p.PrefixRewrite,
RegexRewritePattern: p.RegexRewritePattern,
RegexRewriteSubstitution: p.RegexRewriteSubstitution,
RegexPriorityOrder: p.RegexPriorityOrder,
CorsAllowPreflight: p.CORSAllowPreflight,
AllowPublicUnauthenticatedAccess: p.AllowPublicUnauthenticatedAccess,
AllowAnyAuthenticatedUser: p.AllowAnyAuthenticatedUser,
@ -391,13 +404,29 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
TlsClientKeyFile: p.TLSClientKeyFile,
TlsDownstreamClientCa: p.TLSDownstreamClientCA,
TlsDownstreamClientCaFile: p.TLSDownstreamClientCAFile,
TlsUpstreamAllowRenegotiation: p.TLSUpstreamAllowRenegotiation,
SetRequestHeaders: p.SetRequestHeaders,
RemoveRequestHeaders: p.RemoveRequestHeaders,
PreserveHostHeader: p.PreserveHostHeader,
PassIdentityHeaders: p.PassIdentityHeaders,
KubernetesServiceAccountToken: p.KubernetesServiceAccountToken,
Policies: sps,
SetResponseHeaders: p.SetResponseHeaders,
EnableGoogleCloudServerlessAuthentication: p.EnableGoogleCloudServerlessAuthentication,
Policies: sps,
EnvoyOpts: p.EnvoyOpts,
SetResponseHeaders: p.SetResponseHeaders,
ShowErrorDetails: p.ShowErrorDetails,
}
if p.HostPathRegexRewritePattern != "" {
pb.HostPathRegexRewritePattern = proto.String(p.HostPathRegexRewritePattern)
}
if p.HostPathRegexRewriteSubstitution != "" {
pb.HostPathRegexRewriteSubstitution = proto.String(p.HostPathRegexRewriteSubstitution)
}
if p.HostRewrite != "" {
pb.HostRewrite = proto.String(p.HostRewrite)
}
if p.HostRewriteHeader != "" {
pb.HostRewriteHeader = proto.String(p.HostRewriteHeader)
}
if p.IDPClientID != "" {
pb.IdpClientId = proto.String(p.IDPClientID)
@ -512,10 +541,11 @@ func (p *Policy) Validate() error {
return fmt.Errorf("config: couldn't decode custom ca: %w", err)
}
} else if p.TLSCustomCAFile != "" {
_, err := os.Stat(p.TLSCustomCAFile)
ca, err := os.ReadFile(p.TLSCustomCAFile)
if err != nil {
return fmt.Errorf("config: couldn't load client ca file: %w", err)
}
p.TLSCustomCA = base64.StdEncoding.EncodeToString(ca)
}
const clientCADeprecationMsg = "config: %s is deprecated, see https://www.pomerium.com/docs/" +

View file

@ -369,3 +369,9 @@ func TestPolicy_IsTCPUpstream(t *testing.T) {
}
assert.False(t, p3.IsTCPUpstream())
}
func mustParseWeightedURLs(t testing.TB, urls ...string) []WeightedURL {
wu, err := ParseWeightedUrls(urls...)
require.NoError(t, err)
return wu
}

7
go.mod
View file

@ -51,6 +51,7 @@ require (
github.com/peterbourgon/ff/v3 v3.4.0
github.com/pomerium/csrf v1.7.0
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172
github.com/prometheus/client_golang v1.20.4
github.com/prometheus/client_model v0.6.1
@ -59,6 +60,7 @@ require (
github.com/rs/cors v1.11.1
github.com/rs/zerolog v1.33.0
github.com/shirou/gopsutil/v3 v3.24.5
github.com/spf13/cobra v1.8.1
github.com/spf13/viper v1.19.0
github.com/stretchr/testify v1.9.0
github.com/tniswong/go.rfcx v0.0.0-20181019234604-07783c52761f
@ -159,10 +161,12 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect
github.com/hashicorp/errwrap v1.1.0 // indirect
github.com/hashicorp/hcl v1.0.0 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect
github.com/jackc/puddle/v2 v2.2.2 // indirect
github.com/klauspost/cpuid/v2 v2.2.8 // indirect
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c // indirect
github.com/libdns/libdns v0.2.2 // indirect
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae // indirect
github.com/magiconair/properties v1.8.7 // indirect
@ -175,7 +179,6 @@ require (
github.com/morikuni/aec v1.0.0 // indirect
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect
github.com/onsi/ginkgo v1.16.5 // indirect
github.com/onsi/gomega v1.30.0 // indirect
github.com/opencontainers/go-digest v1.0.0 // indirect
github.com/opencontainers/image-spec v1.1.0 // indirect
github.com/opencontainers/runc v1.1.14 // indirect
@ -196,6 +199,7 @@ require (
github.com/spf13/afero v1.11.0 // indirect
github.com/spf13/cast v1.6.0 // indirect
github.com/spf13/pflag v1.0.5 // indirect
github.com/sryoya/protorand v0.0.0-20240429201223-e7440656b2a4 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/subosito/gotenv v1.6.0 // indirect
github.com/tchap/go-patricia/v2 v2.3.1 // indirect
@ -211,6 +215,7 @@ require (
github.com/yusufpapurcu/wmi v1.2.4 // indirect
github.com/zeebo/assert v1.3.1 // indirect
github.com/zeebo/blake3 v0.2.3 // indirect
github.com/zeebo/xxh3 v1.0.2 // indirect
go.opentelemetry.io/contrib/instrumentation/google.golang.org/grpc/otelgrpc v0.54.0 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.54.0 // indirect
go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.28.0 // indirect

25
go.sum
View file

@ -171,6 +171,7 @@ github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3
github.com/coreos/go-oidc/v3 v3.11.0 h1:Ia3MxdwpSw702YW0xgfmP1GVCMA9aEFWu12XUZ3/OtI=
github.com/coreos/go-oidc/v3 v3.11.0/go.mod h1:gE3LgjOgFoHi9a4ce4/tJczr0Ai2/BoDhf0r5lltWI0=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o=
github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY=
github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
@ -257,6 +258,9 @@ github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpv
github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0/go.mod h1:fyg7847qk6SyHyPtNmDHnmrv/HOrqktSC+C9fM+CJOE=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/go-task/slim-sprig/v3 v3.0.0 h1:sUs3vkvUymDpBKi3qH1YSqBQk9+9D/8M2mN1vB6EwHI=
github.com/go-task/slim-sprig/v3 v3.0.0/go.mod h1:W848ghGpv3Qj3dhTPRyJypKRiqCdHZiAzKg9hl15HA8=
github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y=
github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
@ -342,6 +346,8 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf
github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM=
github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6 h1:k7nVchz72niMH6YLQNvHSdIE7iqsQxK1P41mySCvssg=
github.com/google/pprof v0.0.0-20240424215950-a892ee059fd6/go.mod h1:kf6iHlnVGwgKolg33glAes7Yg/8iWP8ukqeldJSO7jw=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/s2a-go v0.1.8 h1:zZDs9gcbt9ZPLV0ndSyQk6Kacx2g/X+SKYovpnz3SMM=
github.com/google/s2a-go v0.1.8/go.mod h1:6iNWHTpQ+nfNRN5E00MSdfDwVesa8hhS32PhPO8deJA=
@ -385,6 +391,8 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4=
github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
@ -425,6 +433,8 @@ github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY=
github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE=
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c h1:TRkEV8M5PhQU55WI49FKTszEIpFlwZ1wfxcACCRT7SE=
github.com/kralicky/go-adaptive-radix-tree v0.0.0-20240624235931-330eb762e74c/go.mod h1:oJwexVSshEat0E3evyKOH6QzN8GFWrhLvEoh8GiJzss=
github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc=
github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
@ -484,11 +494,13 @@ github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+W
github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk=
github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE=
github.com/onsi/ginkgo v1.16.5/go.mod h1:+E8gABHa3K6zRBolWtd+ROzc/U5bkGt0FwiG042wbpU=
github.com/onsi/ginkgo/v2 v2.19.1 h1:QXgq3Z8Crl5EL1WBAC98A5sEBHARrAJNzAmMxzLcRF0=
github.com/onsi/ginkgo/v2 v2.19.1/go.mod h1:O3DtEWQkPa/F7fBMgmZQKKsluAy8pd3rEQdrjkPb9zA=
github.com/onsi/gomega v1.4.3/go.mod h1:ex+gbHU/CVuBBDIJjb2X0qEXbFg53c61hWP/1CpauHY=
github.com/onsi/gomega v1.7.1/go.mod h1:XdKZgCCFLUoM/7CFJVPcG8C1xQ1AJ0vpAezJrB7JYyY=
github.com/onsi/gomega v1.10.1/go.mod h1:iN09h71vgCQne3DLsj+A5owkum+a2tYe+TOCB1ybHNo=
github.com/onsi/gomega v1.30.0 h1:hvMK7xYz4D3HapigLTeGdId/NcfQx1VHMJc60ew99+8=
github.com/onsi/gomega v1.30.0/go.mod h1:9sxs+SwGrKI0+PWe4Fxa9tFQQBG5xSsSbMXOI8PPpoQ=
github.com/onsi/gomega v1.34.1 h1:EUMJIKUjM8sKjYbtxQI9A4z2o+rruxnzNvpknOXie6k=
github.com/onsi/gomega v1.34.1/go.mod h1:kU1QgUvBDLXBJq618Xvm2LUX6rSAfRaFRTcdOeDLwwY=
github.com/open-policy-agent/opa v0.69.0 h1:s2igLw2Z6IvGWGuXSfugWkVultDMsM9pXiDuMp7ckWw=
github.com/open-policy-agent/opa v0.69.0/go.mod h1:+qyXJGkpEJ6kpB1kGo8JSwHtVXbTdsGdQYPWWNYNj+4=
github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U=
@ -528,6 +540,8 @@ github.com/pomerium/csrf v1.7.0 h1:Qp4t6oyEod3svQtKfJZs589mdUTWKVf7q0PgCKYCshY=
github.com/pomerium/csrf v1.7.0/go.mod h1:hAPZV47mEj2T9xFs+ysbum4l7SF1IdrryYaY6PdoIqw=
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524 h1:3YQY1sb54tEEbr0L73rjHkpLB0IB6qh3zl1+XQbMLis=
github.com/pomerium/datasource v0.18.2-0.20221108160055-c6134b5ed524/go.mod h1:7fGbUYJnU8RcxZJvUvhukOIBv1G7LWDAHMfDxAf5+Y0=
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46 h1:NRTg8JOXCxcIA1lAgD74iYud0rbshbWOB3Ou4+Huil8=
github.com/pomerium/protoutil v0.0.0-20240813175624-47b7ac43ff46/go.mod h1:QqZmx6ZgPxz18va7kqoT4t/0yJtP7YFIDiT/W2n2fZ4=
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172 h1:TqoPqRgXSHpn+tEJq6H72iCS5pv66j3rPprThUEZg0E=
github.com/pomerium/webauthn v0.0.0-20240603205124-0428df511172/go.mod h1:kBQ45E9LluzW7FP1Scn3esaiS2WVbvNRLMOTHareZNQ=
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU=
@ -580,6 +594,7 @@ github.com/rs/xid v1.6.0 h1:fV591PaemRlL6JfRxGDEPl69wICngIQ3shQtzfy2gxU=
github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0=
github.com/rs/zerolog v1.33.0 h1:1cU2KZkvPxNyfgEmhHAz/1A9Bz+llsdYzklWFzgp0r8=
github.com/rs/zerolog v1.33.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ=
github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4=
github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE=
@ -603,11 +618,15 @@ github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8=
github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY=
github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0=
github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo=
github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM=
github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y=
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
github.com/spf13/viper v1.19.0 h1:RWq5SEjt8o25SROyN3z2OrDB9l7RPd3lwTWU8EcEdcI=
github.com/spf13/viper v1.19.0/go.mod h1:GQUN9bilAbhU/jgc1bKs99f/suXKeUMct8Adx5+Ntkg=
github.com/spkg/bom v0.0.0-20160624110644-59b7046e48ad/go.mod h1:qLr4V1qq6nMqFKkMo8ZTx3f+BZEkzsRUY10Xsm2mwU0=
github.com/sryoya/protorand v0.0.0-20240429201223-e7440656b2a4 h1:/jKH9ivHOUkahZs3zPfJfOmkXDFB6OdsHZ4W8gyDb/c=
github.com/sryoya/protorand v0.0.0-20240429201223-e7440656b2a4/go.mod h1:9a23nlv6vzBeVlQq6JQCjljZ6sfzsB6aha1m5Ly1W2Y=
github.com/streadway/amqp v0.0.0-20190404075320-75d898a42a94/go.mod h1:AZpEONHx3DKn8O/DFsRAY58/XVQiIPMTMB1SddzLXVw=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@ -675,6 +694,8 @@ github.com/zeebo/blake3 v0.2.3 h1:TFoLXsjeXqRNFxSbk35Dk4YtszE/MQQGK10BH4ptoTg=
github.com/zeebo/blake3 v0.2.3/go.mod h1:mjJjZpnsyIVtVgTOSpJ9vmRE4wgDeyt2HU3qXvvKCaQ=
github.com/zeebo/pcg v1.0.1 h1:lyqfGeWiv4ahac6ttHs+I5hwtH/+1mrhlCtVNQM2kHo=
github.com/zeebo/pcg v1.0.1/go.mod h1:09F0S9iiKrwn9rlI5yjLkmrug154/YRW6KnnXVDM/l4=
github.com/zeebo/xxh3 v1.0.2 h1:xZmwmqxHZA8AI603jOQ0tMqmBr9lPeFwGg6d+xy9DC0=
github.com/zeebo/xxh3 v1.0.2/go.mod h1:5NWz9Sef7zIDm2JHfFlcQvNekmcEl9ekUZQQKCYaDcA=
go.opencensus.io v0.21.0/go.mod h1:mSImk1erAIZhrmZN+AvHh14ztQfjbGwt4TtuofqLduU=
go.opencensus.io v0.22.0/go.mod h1:+kGneAE2xo2IficOXnaByMWTGM9T73dGwxeWcUqIpI8=
go.opencensus.io v0.22.2/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw=

View file

@ -2,18 +2,23 @@
package zero
import (
"bytes"
"context"
"fmt"
"io"
"time"
"github.com/klauspost/compress/zstd"
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"google.golang.org/protobuf/proto"
"github.com/pomerium/pomerium/internal/zero/apierror"
connect_mux "github.com/pomerium/pomerium/internal/zero/connect-mux"
"github.com/pomerium/pomerium/internal/zero/grpcconn"
token_api "github.com/pomerium/pomerium/internal/zero/token"
"github.com/pomerium/pomerium/pkg/fanout"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
cluster_api "github.com/pomerium/pomerium/pkg/zero/cluster"
connect_api "github.com/pomerium/pomerium/pkg/zero/connect"
)
@ -116,6 +121,30 @@ func (api *API) GetClusterResourceBundles(ctx context.Context) (*cluster_api.Get
)
}
func (api *API) ImportConfig(ctx context.Context, cfg *configpb.Config, params *cluster_api.ImportConfigurationParams) (*cluster_api.ImportResponse, error) {
data, err := proto.Marshal(cfg)
if err != nil {
return nil, err
}
var compressedData bytes.Buffer
w, err := zstd.NewWriter(&compressedData, zstd.WithEncoderLevel(zstd.SpeedBestCompression))
if err != nil {
panic(fmt.Sprintf("bug: %v", err))
}
_, err = io.Copy(w, bytes.NewReader(data))
if err != nil {
return nil, err
}
if err := w.Close(); err != nil {
return nil, err
}
return apierror.CheckResponse(api.cluster.ImportConfigurationWithBodyWithResponse(ctx,
params,
"application/octet-stream",
&compressedData,
))
}
func (api *API) GetTelemetryConn() *grpc.ClientConn {
return api.telemetryConn
}

View file

@ -0,0 +1,176 @@
package cmd
import (
"bytes"
"encoding/json"
"fmt"
"os"
"path"
"strconv"
"strings"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/envoy/files"
"github.com/pomerium/pomerium/pkg/zero/cluster"
"github.com/pomerium/pomerium/pkg/zero/importutil"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
)
func BuildImportCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "import",
Short: "Import an existing configuration to a Zero cluster",
RunE: func(cmd *cobra.Command, _ []string) error {
configFlag := cmd.InheritedFlags().Lookup("config")
var configFile string
if configFlag != nil {
configFile = configFlag.Value.String()
}
envInfo := findEnvironmentInfo()
if configFile == "" {
configFile = envInfo.ConfigArg
}
if configFile == "" {
return fmt.Errorf("no config file provided")
}
log.SetLevel(zerolog.ErrorLevel)
src, err := config.NewFileOrEnvironmentSource(configFile, files.FullVersion())
if err != nil {
return err
}
cfg := src.GetConfig()
client := zeroClientFromContext(cmd.Context())
converted := cfg.Options.ToProto()
for i, name := range importutil.GenerateRouteNames(converted.Routes) {
converted.Routes[i].Name = name
}
var params cluster.ImportConfigurationParams
if data, err := json.Marshal(envInfo); err == nil {
hints := make(map[string]string)
if err := json.Unmarshal(data, &hints); err == nil {
pairs := []string{}
for key, value := range hints {
pairs = append(pairs, fmt.Sprintf("%s=%s", key, value))
}
if len(pairs) > 0 {
params.XImportHints = &pairs
}
}
}
resp, err := client.ImportConfig(cmd.Context(), converted, &params)
if err != nil {
return fmt.Errorf("error importing config: %w", err)
}
if resp.Warnings != nil {
for _, warn := range *resp.Warnings {
cmd.Printf("warning: %s\n", warn)
}
}
if resp.Messages != nil {
for _, msg := range *resp.Messages {
cmd.Printf("✔ %s\n", msg)
}
}
cmd.Println("\nImport successful, return to your browser to continue setup.")
return nil
},
}
return cmd
}
type environmentInfo struct {
SystemType string `json:"systemType,omitempty"`
Hostname string `json:"hostname,omitempty"`
KubernetesNamespace string `json:"kubernetesNamespace,omitempty"`
Argv0 string `json:"argv0,omitempty"`
ConfigArg string `json:"configArg,omitempty"`
}
func findEnvironmentInfo() environmentInfo {
var info environmentInfo
if isKubernetes() {
info.SystemType = "kubernetes"
// search for downward api environment variables to see if we can determine
// the current namespace (adds '-n <namespace>' to the command given in the
// zero ui for users to copy/paste)
for _, env := range []string{
"POMERIUM_NAMESPACE", // the name we use in our official manifests
"POD_NAMESPACE", // very common alternative name
} {
if v, ok := os.LookupEnv(env); ok {
info.KubernetesNamespace = v
break
}
}
} else if isDocker() {
info.SystemType = "docker"
} else {
info.SystemType = "linux"
info.Argv0 = os.Args[0]
return info
}
info.Hostname, _ = os.Hostname()
pid, ok := findPomeriumPid()
if !ok {
return info
}
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
return info
}
args := bytes.Split(cmdline, []byte{0})
if len(args) > 0 {
info.Argv0 = string(args[0])
}
for i, arg := range args {
if strings.Contains(string(arg), "-config") {
if strings.Contains(string(arg), "-config=") {
info.ConfigArg = strings.Split(string(arg), "=")[1]
} else if len(args) > i+1 {
info.ConfigArg = string(args[i+1])
}
}
}
return info
}
func isKubernetes() bool {
return os.Getenv("KUBERNETES_SERVICE_HOST") != "" && os.Getenv("KUBERNETES_SERVICE_PORT") != ""
}
func isDocker() bool {
_, err := os.Stat("/.dockerenv")
return err == nil
}
func findPomeriumPid() (int, bool) {
pid1Argv0 := getProcessArgv0(1)
if path.Base(pid1Argv0) == "pomerium" {
return 1, true
}
pidList, err := os.ReadFile("/proc/1/task/1/children")
if err != nil {
return 0, false
}
for _, pidStr := range strings.Fields(string(pidList)) {
pid, _ := strconv.Atoi(pidStr)
if path.Base(getProcessArgv0(pid)) == "pomerium" {
return pid, true
}
}
return 0, false
}
func getProcessArgv0(pid int) string {
cmdline, err := os.ReadFile(fmt.Sprintf("/proc/%d/cmdline", pid))
if err != nil {
return ""
}
argv0, _, _ := bytes.Cut(cmdline, []byte{0})
return string(argv0)
}

View file

@ -0,0 +1,71 @@
package cmd
import (
"context"
"errors"
zero "github.com/pomerium/pomerium/internal/zero/api"
"github.com/spf13/cobra"
)
type zeroClientContextKeyType struct{}
var zeroClientContextKey zeroClientContextKeyType
func zeroClientFromContext(ctx context.Context) *zero.API {
return ctx.Value(zeroClientContextKey).(*zero.API)
}
func BuildRootCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "zero",
Short: "Interact with the Pomerium Zero cloud service",
PersistentPreRunE: func(cmd *cobra.Command, _ []string) error {
configFlag := cmd.InheritedFlags().Lookup("config")
var configFile string
if configFlag != nil {
configFile = configFlag.Value.String()
}
if err := setupLogger(); err != nil {
return err
}
var token string
if tokenFlag := cmd.InheritedFlags().Lookup("token"); tokenFlag != nil && tokenFlag.Changed {
token = tokenFlag.Value.String()
} else {
token = getToken(configFile)
}
if token == "" {
return errors.New("no token provided")
}
var clusterAPIEndpoint string
if endpointFlag := cmd.InheritedFlags().Lookup("cluster-api-endpoint"); endpointFlag != nil && endpointFlag.Changed {
clusterAPIEndpoint = endpointFlag.Value.String()
} else {
clusterAPIEndpoint = getClusterAPIEndpoint()
}
client, err := zero.NewAPI(cmd.Context(),
zero.WithAPIToken(token),
zero.WithClusterAPIEndpoint(clusterAPIEndpoint),
zero.WithConnectAPIEndpoint(getConnectAPIEndpoint()),
zero.WithOTELEndpoint(getOTELAPIEndpoint()),
)
if err != nil {
return err
}
cmd.SetContext(context.WithValue(cmd.Context(), zeroClientContextKey, client))
return nil
},
}
cmd.AddCommand(BuildImportCmd())
cmd.PersistentFlags().String("config", "", "Specify configuration file location")
cmd.PersistentFlags().String("token", "", "Pomerium Zero Token (default: $POMERIUM_ZERO_TOKEN)")
cmd.PersistentFlags().String("cluster-api-endpoint", "", "Pomerium Zero Cluster API Endpoint (default: $CLUSTER_API_ENDPOINT)")
cmd.PersistentFlags().Lookup("cluster-api-endpoint").Hidden = true
return cmd
}

File diff suppressed because it is too large Load diff

View file

@ -108,6 +108,7 @@ message Route {
envoy.config.cluster.v3.Cluster envoy_opts = 36;
repeated Policy policies = 27;
repeated PPLPolicy ppl_policies = 63;
string id = 28;
optional string host_rewrite = 50;
@ -120,6 +121,10 @@ message Route {
bool show_error_details = 59;
}
message PPLPolicy {
bytes raw = 1;
}
message Policy {
string id = 1;
string name = 2;
@ -227,7 +232,7 @@ message Settings {
optional string envoy_admin_profile_path = 109;
optional string envoy_admin_address = 110;
optional string envoy_bind_config_source_address = 111;
optional string envoy_bind_config_freebind = 112;
optional bool envoy_bind_config_freebind = 112;
repeated string programmatic_redirect_domain_whitelist = 68;
optional envoy.extensions.filters.network.http_connection_manager.v3
.HttpConnectionManager.CodecType codec_type = 73;
@ -247,6 +252,8 @@ message DownstreamMtlsSettings {
optional string ca = 1;
optional string crl = 2;
optional MtlsEnforcementMode enforcement = 3;
repeated SANMatcher match_subject_alt_names = 4;
optional uint32 max_verify_depth = 5;
}
enum MtlsEnforcementMode {
@ -255,3 +262,16 @@ enum MtlsEnforcementMode {
POLICY_WITH_DEFAULT_DENY = 2;
REJECT_CONNECTION = 3;
}
message SANMatcher {
enum SANType {
SAN_TYPE_UNSPECIFIED = 0;
EMAIL = 1;
DNS = 2;
URI = 3;
IP_ADDRESS = 4;
USER_PRINCIPAL_NAME = 5;
}
SANType san_type = 1;
string pattern = 2;
}

View file

@ -103,6 +103,9 @@ type ClientInterface interface {
ReportClusterResourceBundleStatus(ctx context.Context, bundleId BundleId, body ReportClusterResourceBundleStatusJSONRequestBody, reqEditors ...RequestEditorFn) (*http.Response, error)
// ImportConfigurationWithBody request with any body
ImportConfigurationWithBody(ctx context.Context, params *ImportConfigurationParams, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
// ExchangeClusterIdentityTokenWithBody request with any body
ExchangeClusterIdentityTokenWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error)
@ -174,6 +177,18 @@ func (c *Client) ReportClusterResourceBundleStatus(ctx context.Context, bundleId
return c.Client.Do(req)
}
func (c *Client) ImportConfigurationWithBody(ctx context.Context, params *ImportConfigurationParams, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
req, err := NewImportConfigurationRequestWithBody(c.Server, params, contentType, body)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
if err := c.applyEditors(ctx, req, reqEditors); err != nil {
return nil, err
}
return c.Client.Do(req)
}
func (c *Client) ExchangeClusterIdentityTokenWithBody(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*http.Response, error) {
req, err := NewExchangeClusterIdentityTokenRequestWithBody(c.Server, contentType, body)
if err != nil {
@ -357,6 +372,50 @@ func NewReportClusterResourceBundleStatusRequestWithBody(server string, bundleId
return req, nil
}
// NewImportConfigurationRequestWithBody generates requests for ImportConfiguration with any type of body
func NewImportConfigurationRequestWithBody(server string, params *ImportConfigurationParams, contentType string, body io.Reader) (*http.Request, error) {
var err error
serverURL, err := url.Parse(server)
if err != nil {
return nil, err
}
operationPath := fmt.Sprintf("/config/import")
if operationPath[0] == '/' {
operationPath = "." + operationPath
}
queryURL, err := serverURL.Parse(operationPath)
if err != nil {
return nil, err
}
req, err := http.NewRequest("PUT", queryURL.String(), body)
if err != nil {
return nil, err
}
req.Header.Add("Content-Type", contentType)
if params != nil {
if params.XImportHints != nil {
var headerParam0 string
headerParam0, err = runtime.StyleParamWithLocation("simple", true, "X-Import-Hints", runtime.ParamLocationHeader, *params.XImportHints)
if err != nil {
return nil, err
}
req.Header.Set("X-Import-Hints", headerParam0)
}
}
return req, nil
}
// NewExchangeClusterIdentityTokenRequest calls the generic ExchangeClusterIdentityToken builder with application/json body
func NewExchangeClusterIdentityTokenRequest(server string, body ExchangeClusterIdentityTokenJSONRequestBody) (*http.Request, error) {
var bodyReader io.Reader
@ -494,6 +553,9 @@ type ClientWithResponsesInterface interface {
ReportClusterResourceBundleStatusWithResponse(ctx context.Context, bundleId BundleId, body ReportClusterResourceBundleStatusJSONRequestBody, reqEditors ...RequestEditorFn) (*ReportClusterResourceBundleStatusResp, error)
// ImportConfigurationWithBodyWithResponse request with any body
ImportConfigurationWithBodyWithResponse(ctx context.Context, params *ImportConfigurationParams, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImportConfigurationResp, error)
// ExchangeClusterIdentityTokenWithBodyWithResponse request with any body
ExchangeClusterIdentityTokenWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ExchangeClusterIdentityTokenResp, error)
@ -601,6 +663,32 @@ func (r ReportClusterResourceBundleStatusResp) StatusCode() int {
return 0
}
type ImportConfigurationResp struct {
Body []byte
HTTPResponse *http.Response
JSON200 *ImportResponse
JSON400 *ErrorResponse
JSON403 *ErrorResponse
JSON413 *ErrorResponse
JSON500 *ErrorResponse
}
// Status returns HTTPResponse.Status
func (r ImportConfigurationResp) Status() string {
if r.HTTPResponse != nil {
return r.HTTPResponse.Status
}
return http.StatusText(0)
}
// StatusCode returns HTTPResponse.StatusCode
func (r ImportConfigurationResp) StatusCode() int {
if r.HTTPResponse != nil {
return r.HTTPResponse.StatusCode
}
return 0
}
type ExchangeClusterIdentityTokenResp struct {
Body []byte
HTTPResponse *http.Response
@ -692,6 +780,15 @@ func (c *ClientWithResponses) ReportClusterResourceBundleStatusWithResponse(ctx
return ParseReportClusterResourceBundleStatusResp(rsp)
}
// ImportConfigurationWithBodyWithResponse request with arbitrary body returning *ImportConfigurationResp
func (c *ClientWithResponses) ImportConfigurationWithBodyWithResponse(ctx context.Context, params *ImportConfigurationParams, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ImportConfigurationResp, error) {
rsp, err := c.ImportConfigurationWithBody(ctx, params, contentType, body, reqEditors...)
if err != nil {
return nil, err
}
return ParseImportConfigurationResp(rsp)
}
// ExchangeClusterIdentityTokenWithBodyWithResponse request with arbitrary body returning *ExchangeClusterIdentityTokenResp
func (c *ClientWithResponses) ExchangeClusterIdentityTokenWithBodyWithResponse(ctx context.Context, contentType string, body io.Reader, reqEditors ...RequestEditorFn) (*ExchangeClusterIdentityTokenResp, error) {
rsp, err := c.ExchangeClusterIdentityTokenWithBody(ctx, contentType, body, reqEditors...)
@ -886,6 +983,60 @@ func ParseReportClusterResourceBundleStatusResp(rsp *http.Response) (*ReportClus
return response, nil
}
// ParseImportConfigurationResp parses an HTTP response from a ImportConfigurationWithResponse call
func ParseImportConfigurationResp(rsp *http.Response) (*ImportConfigurationResp, error) {
bodyBytes, err := io.ReadAll(rsp.Body)
defer func() { _ = rsp.Body.Close() }()
if err != nil {
return nil, err
}
response := &ImportConfigurationResp{
Body: bodyBytes,
HTTPResponse: rsp,
}
switch {
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 200:
var dest ImportResponse
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err
}
response.JSON200 = &dest
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 400:
var dest ErrorResponse
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err
}
response.JSON400 = &dest
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 403:
var dest ErrorResponse
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err
}
response.JSON403 = &dest
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 413:
var dest ErrorResponse
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err
}
response.JSON413 = &dest
case strings.Contains(rsp.Header.Get("Content-Type"), "json") && rsp.StatusCode == 500:
var dest ErrorResponse
if err := json.Unmarshal(bodyBytes, &dest); err != nil {
return nil, err
}
response.JSON500 = &dest
}
return response, nil
}
// ParseExchangeClusterIdentityTokenResp parses an HTTP response from a ExchangeClusterIdentityTokenWithResponse call
func ParseExchangeClusterIdentityTokenResp(rsp *http.Response) (*ExchangeClusterIdentityTokenResp, error) {
bodyBytes, err := io.ReadAll(rsp.Body)
@ -1073,6 +1224,39 @@ func (r *ReportClusterResourceBundleStatusResp) GetValue() *EmptyResponse {
return &EmptyResponse{}
}
// GetHTTPResponse implements apierror.APIResponse
func (r *ImportConfigurationResp) GetHTTPResponse() *http.Response {
return r.HTTPResponse
}
// GetValue implements apierror.APIResponse
func (r *ImportConfigurationResp) GetValue() *ImportResponse {
return r.JSON200
}
// GetBadRequestError implements apierror.APIResponse
func (r *ImportConfigurationResp) GetBadRequestError() (string, bool) {
if r.JSON400 == nil {
return "", false
}
return r.JSON400.Error, true
}
func (r *ImportConfigurationResp) GetForbiddenError() (string, bool) {
if r.JSON403 == nil {
return "", false
}
return r.JSON403.Error, true
}
// GetInternalServerError implements apierror.APIResponse
func (r *ImportConfigurationResp) GetInternalServerError() (string, bool) {
if r.JSON500 == nil {
return "", false
}
return r.JSON500.Error, true
}
// GetHTTPResponse implements apierror.APIResponse
func (r *ExchangeClusterIdentityTokenResp) GetHTTPResponse() *http.Response {
return r.HTTPResponse

View file

@ -13,4 +13,5 @@ var (
_ apierror.APIResponse[GetBundlesResponse] = (*GetClusterResourceBundlesResp)(nil)
_ apierror.APIResponse[DownloadBundleResponse] = (*DownloadClusterResourceBundleResp)(nil)
_ apierror.APIResponse[EmptyResponse] = (*ReportClusterResourceBundleStatusResp)(nil)
_ apierror.APIResponse[ImportResponse] = (*ImportConfigurationResp)(nil)
)

View file

@ -100,6 +100,12 @@ type GetBundlesResponse struct {
Bundles []Bundle `json:"bundles"`
}
// ImportResponse defines model for ImportResponse.
type ImportResponse struct {
Messages *[]string `json:"messages,omitempty"`
Warnings *[]string `json:"warnings,omitempty"`
}
// ReportUsageRequest defines model for ReportUsageRequest.
type ReportUsageRequest struct {
Users []ReportUsageUser `json:"users"`
@ -115,6 +121,11 @@ type ReportUsageUser struct {
// BundleId defines model for bundleId.
type BundleId = string
// ImportConfigurationParams defines parameters for ImportConfiguration.
type ImportConfigurationParams struct {
XImportHints *[]string `json:"X-Import-Hints,omitempty"`
}
// ReportClusterResourceBundleStatusJSONRequestBody defines body for ReportClusterResourceBundleStatus for application/json ContentType.
type ReportClusterResourceBundleStatusJSONRequestBody = BundleStatus

View file

@ -148,7 +148,6 @@ paths:
application/json:
schema:
$ref: "#/components/schemas/ErrorResponse"
/reportUsage:
post:
description: Report usage for the cluster
@ -176,6 +175,64 @@ paths:
schema:
$ref: "#/components/schemas/ErrorResponse"
/config/import:
put:
description: |
Apply the raw configuration directly to the cluster, replacing any
existing user-defined routes, policies, and certificates.
Only available before a Pomerium instance has connected to the cluster
for the first time.
operationId: importConfiguration
tags: [cluster]
parameters:
- in: header
name: X-Import-Hints
schema:
type: array
items:
type: string
style: simple
explode: true
requestBody:
required: true
content:
application/octet-stream:
schema:
type: string
contentMediaType: application/octet-stream
contentEncoding: zstd
description: type.googleapis.com/pomerium.config.Config
responses:
"200":
description: OK
content:
application/json:
schema:
$ref: "#/components/schemas/ImportResponse"
"400":
description: Bad Request
content:
application/json:
schema:
$ref: "#/components/schemas/ErrorResponse"
"403":
description: Forbidden
content:
application/json:
schema:
$ref: "#/components/schemas/ErrorResponse"
"413":
description: Content Too Large
content:
application/json:
schema:
$ref: "#/components/schemas/ErrorResponse"
"500":
description: Internal Server Error
content:
application/json:
schema:
$ref: "#/components/schemas/ErrorResponse"
components:
parameters:
bundleId:
@ -279,6 +336,17 @@ components:
description: Error message
required:
- error
ImportResponse:
type: object
properties:
messages:
type: array
items:
type: string
warnings:
type: array
items:
type: string
ExchangeTokenRequest:
type: object
properties:

View file

@ -7,6 +7,7 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"github.com/go-chi/chi/v5"
@ -29,6 +30,9 @@ type ServerInterface interface {
// (POST /bundles/{bundleId}/status)
ReportClusterResourceBundleStatus(w http.ResponseWriter, r *http.Request, bundleId BundleId)
// (PUT /config/import)
ImportConfiguration(w http.ResponseWriter, r *http.Request, params ImportConfigurationParams)
// (POST /exchangeToken)
ExchangeClusterIdentityToken(w http.ResponseWriter, r *http.Request)
@ -60,6 +64,11 @@ func (_ Unimplemented) ReportClusterResourceBundleStatus(w http.ResponseWriter,
w.WriteHeader(http.StatusNotImplemented)
}
// (PUT /config/import)
func (_ Unimplemented) ImportConfiguration(w http.ResponseWriter, r *http.Request, params ImportConfigurationParams) {
w.WriteHeader(http.StatusNotImplemented)
}
// (POST /exchangeToken)
func (_ Unimplemented) ExchangeClusterIdentityToken(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusNotImplemented)
@ -169,6 +178,49 @@ func (siw *ServerInterfaceWrapper) ReportClusterResourceBundleStatus(w http.Resp
handler.ServeHTTP(w, r.WithContext(ctx))
}
// ImportConfiguration operation middleware
func (siw *ServerInterfaceWrapper) ImportConfiguration(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
var err error
ctx = context.WithValue(ctx, BearerAuthScopes, []string{})
// Parameter object where we will unmarshal all parameters from the context
var params ImportConfigurationParams
headers := r.Header
// ------------- Optional header parameter "X-Import-Hints" -------------
if valueList, found := headers[http.CanonicalHeaderKey("X-Import-Hints")]; found {
var XImportHints []string
n := len(valueList)
if n != 1 {
siw.ErrorHandlerFunc(w, r, &TooManyValuesForParamError{ParamName: "X-Import-Hints", Count: n})
return
}
err = runtime.BindStyledParameterWithOptions("simple", "X-Import-Hints", valueList[0], &XImportHints, runtime.BindStyledParameterOptions{ParamLocation: runtime.ParamLocationHeader, Explode: true, Required: false})
if err != nil {
siw.ErrorHandlerFunc(w, r, &InvalidParamFormatError{ParamName: "X-Import-Hints", Err: err})
return
}
params.XImportHints = &XImportHints
}
handler := http.Handler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
siw.Handler.ImportConfiguration(w, r, params)
}))
for i := len(siw.HandlerMiddlewares) - 1; i >= 0; i-- {
handler = siw.HandlerMiddlewares[i](handler)
}
handler.ServeHTTP(w, r.WithContext(ctx))
}
// ExchangeClusterIdentityToken operation middleware
func (siw *ServerInterfaceWrapper) ExchangeClusterIdentityToken(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
@ -326,6 +378,9 @@ func HandlerWithOptions(si ServerInterface, options ChiServerOptions) http.Handl
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/bundles/{bundleId}/status", wrapper.ReportClusterResourceBundleStatus)
})
r.Group(func(r chi.Router) {
r.Put(options.BaseURL+"/config/import", wrapper.ImportConfiguration)
})
r.Group(func(r chi.Router) {
r.Post(options.BaseURL+"/exchangeToken", wrapper.ExchangeClusterIdentityToken)
})
@ -483,6 +538,60 @@ func (response ReportClusterResourceBundleStatus500JSONResponse) VisitReportClus
return json.NewEncoder(w).Encode(response)
}
type ImportConfigurationRequestObject struct {
Params ImportConfigurationParams
Body io.Reader
}
type ImportConfigurationResponseObject interface {
VisitImportConfigurationResponse(w http.ResponseWriter) error
}
type ImportConfiguration200JSONResponse ImportResponse
func (response ImportConfiguration200JSONResponse) VisitImportConfigurationResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(200)
return json.NewEncoder(w).Encode(response)
}
type ImportConfiguration400JSONResponse ErrorResponse
func (response ImportConfiguration400JSONResponse) VisitImportConfigurationResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(400)
return json.NewEncoder(w).Encode(response)
}
type ImportConfiguration403JSONResponse ErrorResponse
func (response ImportConfiguration403JSONResponse) VisitImportConfigurationResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(403)
return json.NewEncoder(w).Encode(response)
}
type ImportConfiguration413JSONResponse ErrorResponse
func (response ImportConfiguration413JSONResponse) VisitImportConfigurationResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(413)
return json.NewEncoder(w).Encode(response)
}
type ImportConfiguration500JSONResponse ErrorResponse
func (response ImportConfiguration500JSONResponse) VisitImportConfigurationResponse(w http.ResponseWriter) error {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(500)
return json.NewEncoder(w).Encode(response)
}
type ExchangeClusterIdentityTokenRequestObject struct {
Body *ExchangeClusterIdentityTokenJSONRequestBody
}
@ -567,6 +676,9 @@ type StrictServerInterface interface {
// (POST /bundles/{bundleId}/status)
ReportClusterResourceBundleStatus(ctx context.Context, request ReportClusterResourceBundleStatusRequestObject) (ReportClusterResourceBundleStatusResponseObject, error)
// (PUT /config/import)
ImportConfiguration(ctx context.Context, request ImportConfigurationRequestObject) (ImportConfigurationResponseObject, error)
// (POST /exchangeToken)
ExchangeClusterIdentityToken(ctx context.Context, request ExchangeClusterIdentityTokenRequestObject) (ExchangeClusterIdentityTokenResponseObject, error)
@ -710,6 +822,34 @@ func (sh *strictHandler) ReportClusterResourceBundleStatus(w http.ResponseWriter
}
}
// ImportConfiguration operation middleware
func (sh *strictHandler) ImportConfiguration(w http.ResponseWriter, r *http.Request, params ImportConfigurationParams) {
var request ImportConfigurationRequestObject
request.Params = params
request.Body = r.Body
handler := func(ctx context.Context, w http.ResponseWriter, r *http.Request, request interface{}) (interface{}, error) {
return sh.ssi.ImportConfiguration(ctx, request.(ImportConfigurationRequestObject))
}
for _, middleware := range sh.middlewares {
handler = middleware(handler, "ImportConfiguration")
}
response, err := handler(r.Context(), w, r, request)
if err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
} else if validResponse, ok := response.(ImportConfigurationResponseObject); ok {
if err := validResponse.VisitImportConfigurationResponse(w); err != nil {
sh.options.ResponseErrorHandlerFunc(w, r, err)
}
} else if response != nil {
sh.options.ResponseErrorHandlerFunc(w, r, fmt.Errorf("unexpected response type: %T", response))
}
}
// ExchangeClusterIdentityToken operation middleware
func (sh *strictHandler) ExchangeClusterIdentityToken(w http.ResponseWriter, r *http.Request) {
var request ExchangeClusterIdentityTokenRequestObject

View file

@ -0,0 +1,368 @@
package importutil
import (
"crypto/x509"
"fmt"
"iter"
"net/url"
"regexp"
"slices"
"strconv"
"strings"
"github.com/cespare/xxhash/v2"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
)
func GenerateCertName(cert *x509.Certificate) *string {
var out string
if cert.IsCA {
if cert.Subject.CommonName != "" {
out = cert.Subject.CommonName
} else {
out = cert.Subject.String()
}
} else {
if cert.Subject.CommonName != "" {
out = cert.Subject.CommonName
} else if len(cert.DNSNames) > 0 {
out = pickDNSName(cert.DNSNames)
} else {
out = "leaf"
}
}
if strings.Contains(out, "-") {
out = strings.ReplaceAll(out, " ", "_")
} else {
out = strings.ReplaceAll(out, " ", "-")
}
suffix := fmt.Sprintf("@%d", cert.NotBefore.Unix())
if !strings.Contains(out, suffix) {
out += suffix
}
return &out
}
func pickDNSName(names []string) string {
if len(names) == 1 {
return names[0]
}
// prefer wildcard names
for _, name := range names {
if strings.HasPrefix(name, "*.") {
return name
}
}
return names[0]
}
func GenerateRouteNames(routes []*configpb.Route) []string {
out := make([]string, len(routes))
prefixes := make([][]string, len(routes))
indexes := map[*configpb.Route]int{}
trie := newDomainTrie()
for i, route := range routes {
trie.Insert(route)
indexes[route] = i
}
trie.Compact()
trie.Walk(func(parents []string, node *domainTreeNode) {
for subdomain, child := range node.children {
for route, name := range differentiateRoutes(subdomain, child.routes) {
idx := indexes[route]
out[idx] = name
prefixes[idx] = parents
}
}
})
seen := map[string]int{}
for idx, name := range out {
prevIdx, ok := seen[name]
if !ok {
out[idx] = name
seen[name] = idx
continue
}
delete(seen, name)
var b strings.Builder
b.WriteString(name)
var prevNameB strings.Builder
prevNameB.WriteString(out[prevIdx])
var nameB strings.Builder
nameB.WriteString(name)
minLen := min(len(prefixes[prevIdx]), len(prefixes[idx]))
maxLen := max(len(prefixes[prevIdx]), len(prefixes[idx]))
for j := range maxLen {
if j >= minLen {
if j < len(prefixes[prevIdx]) {
prevNameB.WriteRune('-')
prevNameB.WriteString(strings.ReplaceAll(prefixes[prevIdx][j], ".", "-"))
} else {
nameB.WriteRune('-')
nameB.WriteString(strings.ReplaceAll(prefixes[idx][j], ".", "-"))
}
continue
}
prevPrefix, prefix := trimCommonSubdomains(prefixes[prevIdx][j], prefixes[idx][j])
if prevPrefix != prefix {
prevNameB.WriteRune('-')
prevNameB.WriteString(prevPrefix)
nameB.WriteRune('-')
nameB.WriteString(prefix)
}
}
out[prevIdx] = prevNameB.String()
out[idx] = nameB.String()
seen[out[prevIdx]] = prevIdx
seen[out[idx]] = idx
}
for i, name := range out {
if name == "" {
out[i] = fmt.Sprintf("route-%d", i)
}
}
return out
}
func trimCommonSubdomains(a, b string) (string, string) {
aParts := strings.Split(a, ".")
bParts := strings.Split(b, ".")
for len(aParts) > 1 && len(bParts) > 1 && aParts[0] == bParts[0] {
aParts = aParts[1:]
bParts = bParts[1:]
}
for len(aParts) > 1 && len(bParts) > 1 && aParts[len(aParts)-1] == bParts[len(bParts)-1] {
aParts = aParts[:len(aParts)-1]
bParts = bParts[:len(bParts)-1]
}
return strings.Join(aParts, "-"), strings.Join(bParts, "-")
}
func differentiateRoutes(subdomain string, routes []*configpb.Route) iter.Seq2[*configpb.Route, string] {
return func(yield func(*configpb.Route, string) bool) {
if len(routes) == 1 {
yield(routes[0], subdomain)
return
}
names := map[string][]*configpb.Route{}
replacer := strings.NewReplacer(
" ", "_",
"/", "-",
"*", "",
)
simplePathName := func(pathOrPrefix string) string {
if p, err := url.PathUnescape(pathOrPrefix); err == nil {
pathOrPrefix = strings.ToLower(p)
}
return replacer.Replace(strings.Trim(pathOrPrefix, "/ "))
}
genericRegexCounter := 0
regexName := func(regex string) string {
if path, pattern, ok := commonRegexPattern(regex); ok {
name := simplePathName(path)
if name == "" && pattern != "" {
return "re-any"
}
return fmt.Sprintf("re-%s-prefix", name)
}
genericRegexCounter++
return fmt.Sprintf("re-%d", genericRegexCounter)
}
var prefixCount, pathCount int
for _, route := range routes {
// each route will have the same domain, but a unique prefix/path/regex.
var name string
switch {
case route.Prefix != "":
name = simplePathName(route.Prefix)
prefixCount++
case route.Path != "":
name = simplePathName(route.Path)
pathCount++
case route.Regex != "":
name = regexName(route.Regex)
}
names[name] = append(names[name], route)
}
nameCounts := map[uint64]int{}
for name, routes := range names {
if len(routes) == 1 {
var b strings.Builder
b.WriteString(subdomain)
if name != "" {
b.WriteRune('-')
b.WriteString(name)
}
if !yield(routes[0], b.String()) {
return
}
} else {
// assign a "-prefix" or "-path" suffix to routes with the same name
// but different configurations
prefixSuffix := "-prefix"
pathSuffix := "-path"
switch {
case prefixCount == 1 && pathCount == 1:
pathSuffix = ""
case prefixCount > 1 && pathCount == 1:
prefixSuffix = ""
case prefixCount == 1 && pathCount > 1:
pathSuffix = ""
case prefixCount == 0:
pathSuffix = ""
case pathCount == 0:
prefixSuffix = ""
}
var b strings.Builder
for _, route := range routes {
b.Reset()
b.WriteString(subdomain)
if name != "" {
b.WriteRune('-')
b.WriteString(name)
}
if route.Prefix != "" {
b.WriteString(prefixSuffix)
} else if route.Path != "" {
b.WriteString(pathSuffix)
}
sum := xxhash.Sum64String(b.String())
nameCounts[sum]++
if c := nameCounts[sum]; c > 1 {
b.WriteString(" (")
b.WriteString(strconv.Itoa(c))
b.WriteString(")")
}
if !yield(route, b.String()) {
return
}
}
}
}
}
}
type domainTreeNode struct {
parent *domainTreeNode
children map[string]*domainTreeNode
routes []*configpb.Route
}
func (n *domainTreeNode) insert(key string, route *configpb.Route) *domainTreeNode {
if existing, ok := n.children[key]; ok {
if route != nil {
existing.routes = append(existing.routes, route)
}
return existing
}
node := &domainTreeNode{
parent: n,
children: map[string]*domainTreeNode{},
}
if route != nil {
node.routes = append(node.routes, route)
}
n.children[key] = node
return node
}
type domainTrie struct {
root *domainTreeNode
}
func newDomainTrie() *domainTrie {
t := &domainTrie{
root: &domainTreeNode{
children: map[string]*domainTreeNode{},
},
}
return t
}
type walkFn = func(parents []string, node *domainTreeNode)
func (t *domainTrie) Walk(fn walkFn) {
t.root.walk(nil, fn)
}
func (n *domainTreeNode) walk(prefix []string, fn walkFn) {
for key, child := range n.children {
fn(append(prefix, key), child)
child.walk(append(prefix, key), fn)
}
}
func (t *domainTrie) Insert(route *configpb.Route) {
u, _ := url.Parse(route.From)
if u == nil {
// ignore invalid urls, they will be assigned generic fallback names
return
}
parts := strings.Split(u.Hostname(), ".")
slices.Reverse(parts)
cur := t.root
for _, part := range parts[:len(parts)-1] {
cur = cur.insert(part, nil)
}
cur.insert(parts[len(parts)-1], route)
}
func (t *domainTrie) Compact() {
t.root.compact()
}
func (n *domainTreeNode) compact() {
for _, child := range n.children {
child.compact()
}
if n.parent == nil {
return
}
var firstKey string
var firstChild *domainTreeNode
for key, child := range n.children {
firstKey, firstChild = key, child
break
}
// compact intermediate nodes, not leaves
if len(n.children) == 1 && len(firstChild.routes) == 0 {
firstChild.parent = n.parent
for key, child := range n.parent.children {
if child == n {
delete(n.parent.children, key)
n.parent.children[fmt.Sprintf("%s.%s", key, firstKey)] = firstChild
*n = domainTreeNode{}
break
}
}
}
}
// Matches an optional leading slash, then zero or more path segments separated
// by '/' characters, where the final path segment contains one of the following
// commonly used regex patterns used to match path segments:
// - '.*' or '.+'
// - '[^/]*', '[^/]+', '[^\/]*', or '[^\/]+'
// - '\w*' or '\w+'
// - any of the above patterns, enclosed by parentheses
// The first capture group contains the path leading up to the wildcard segment
// and can be empty or have leading/trailing slashes. The second capture group
// contains the wildcard segment with no leading or trailing slashes.
var pathPrefixMatchRegex = regexp.MustCompile(`^(\/?(?:\w+\/)*)(\(?(?:\.\+|\.\*|\[\^\\?\/\][\+\*]|\\w[\+\*])\)?)$`)
func commonRegexPattern(re string) (path string, pattern string, found bool) {
re = strings.TrimSuffix(strings.TrimPrefix(re, "^"), "$")
if match := pathPrefixMatchRegex.FindStringSubmatch(re); match != nil {
return match[1], match[2], true
}
return "", "", false
}

View file

@ -0,0 +1,394 @@
package importutil_test
import (
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"slices"
"testing"
"time"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/zero/importutil"
"github.com/stretchr/testify/assert"
)
func TestGenerateCertName(t *testing.T) {
cases := []struct {
name string
input x509.Certificate
expected string
}{
{
name: "cert with common name",
input: x509.Certificate{
IsCA: true,
Subject: pkix.Name{CommonName: "sample"},
},
expected: "sample",
},
{
name: "cert with common name and other subject fields",
input: x509.Certificate{
IsCA: true,
Subject: pkix.Name{
CommonName: "sample",
Organization: []string{"foo"},
OrganizationalUnit: []string{"bar"},
},
},
expected: "sample",
},
{
name: "common name with spaces",
input: x509.Certificate{
IsCA: true,
Subject: pkix.Name{CommonName: "sample name"},
},
expected: "sample-name",
},
{
name: "common name with special characters",
input: x509.Certificate{
IsCA: true,
Subject: pkix.Name{CommonName: "sample common-name"},
},
expected: "sample_common-name",
},
{
name: "cert with other subject fields but no common name",
input: x509.Certificate{
IsCA: true,
Subject: pkix.Name{
Organization: []string{"foo"},
OrganizationalUnit: []string{"bar"},
},
},
expected: "OU=bar,O=foo",
},
{
name: "leaf cert with common name",
input: x509.Certificate{
IsCA: false,
Subject: pkix.Name{CommonName: "sample"},
},
expected: "sample",
},
{
name: "leaf cert with dns name",
input: x509.Certificate{
IsCA: false,
DNSNames: []string{"example.com"},
},
expected: "example.com",
},
{
name: "leaf cert with dns names",
input: x509.Certificate{
IsCA: false,
DNSNames: []string{"example.com", "*.example.com"},
},
expected: "*.example.com",
},
{
name: "leaf cert with neither common name nor dns names",
input: x509.Certificate{
IsCA: false,
},
expected: "leaf",
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
nbf := time.Now()
tc.input.NotBefore = nbf
tc.expected += fmt.Sprintf("@%d", nbf.Unix())
out := importutil.GenerateCertName(&tc.input)
assert.Equal(t, tc.expected, *out)
})
}
}
func TestGenerateRouteNames(t *testing.T) {
const testExample = "https://test.example.com"
cases := []struct {
name string
input []*configpb.Route
expected []string
}{
{
name: "single domain name",
input: []*configpb.Route{
{From: "https://foo.example.com"},
{From: "https://bar.example.com"},
{From: "https://baz.example.com"},
},
expected: []string{"foo", "bar", "baz"},
},
{
name: "multiple domain names, unique subdomains",
input: []*configpb.Route{
{From: "https://a.domain1.example.com"},
{From: "https://b.domain1.example.com"},
{From: "https://c.domain1.example.com"},
{From: "https://d.domain2.example.com"},
{From: "https://e.domain2.example.com"},
{From: "https://f.domain2.example.com"},
},
expected: []string{"a", "b", "c", "d", "e", "f"},
},
{
name: "multiple domain names, conflicting subdomains",
input: []*configpb.Route{
{From: "https://a.domain1.example.com"},
{From: "https://b.domain1.example.com"},
{From: "https://c.domain1.example.com"},
{From: "https://a.domain2.example.com"},
{From: "https://b.domain2.example.com"},
{From: "https://c.domain2.example.com"},
},
expected: []string{
"a-domain1",
"b-domain1",
"c-domain1",
"a-domain2",
"b-domain2",
"c-domain2",
},
},
{
name: "multiple nested domain names, conflicting subdomains",
input: []*configpb.Route{
{From: "https://a.domain1.domain2.domain3.example.com"},
{From: "https://b.domain1.domain2.domain3.example.com"},
{From: "https://c.domain1.domain2.domain3.example.com"},
{From: "https://a.domain1.domain2.domain4.example.com"},
{From: "https://b.domain1.domain2.domain4.example.com"},
{From: "https://c.domain1.domain2.domain4.example.com"},
{From: "https://a.domain1.domain2.domain5.example.com"},
{From: "https://b.domain2.domain2.domain5.example.com"},
{From: "https://c.domain3.domain2.domain5.example.com"},
{From: "https://a.domain1.domain2.domain6.example.com"},
{From: "https://b.domain2.domain2.domain6.example.com"},
{From: "https://c.domain3.domain2.domain6.example.com"},
},
expected: []string{
"a-domain3",
"b-domain3",
"c-domain3",
"a-domain4",
"b-domain4",
"c-domain4",
"a-domain5",
"b-domain5",
"c-domain5",
"a-domain6",
"b-domain6",
"c-domain6",
},
},
{
name: "conflicting subdomain names nested at different levels",
input: []*configpb.Route{
{From: "https://a.domain1.domain2.example.com"},
{From: "https://a.domain1.example.com"},
{From: "https://a.example.com"},
{From: "https://a.domain3.domain2.example.com"},
{From: "https://a.domain3.example.com"},
},
expected: []string{
"a-domain2-domain1",
"a-domain1",
"a",
"a-domain2-domain3",
"a-domain3",
},
},
{
name: "conflicting subdomain names nested at different levels, unique paths",
input: []*configpb.Route{
{From: "https://a.domain1.domain2.example.com"},
{From: "https://a.domain1.example.com"},
{From: "https://a.example.com"},
},
expected: []string{
"a-domain2-domain1",
"a-domain1",
"a",
},
},
{
name: "same domain, separate prefix options",
input: []*configpb.Route{
{From: testExample, Prefix: "/a"},
{From: testExample, Prefix: "/b"},
{From: testExample, Prefix: "/c"},
},
expected: []string{"test-a", "test-b", "test-c"},
},
{
name: "same domain, mixed prefix/path options",
input: []*configpb.Route{
{From: testExample, Prefix: "/a"},
{From: testExample, Path: "/b"},
{From: testExample, Prefix: "/c"},
{From: testExample, Path: "/d"},
},
expected: []string{"test-a", "test-b", "test-c", "test-d"},
},
{
name: "same domain, name-conflicting prefix/path options (1 prefix/1 path)",
input: []*configpb.Route{
{From: testExample, Prefix: "/a/"},
{From: testExample, Path: "/a"},
},
expected: []string{"test-a-prefix", "test-a"},
},
{
name: "same domain, name-conflicting prefix/path options (more prefixes than paths)",
input: []*configpb.Route{
{From: testExample, Prefix: "/a/"},
{From: testExample, Prefix: "/b/"},
{From: testExample, Prefix: "/c/"},
{From: testExample, Path: "/a"},
},
expected: []string{"test-a", "test-b", "test-c", "test-a-path"},
},
{
name: "same domain, name-conflicting prefix/path options (more paths than prefixes)",
input: []*configpb.Route{
{From: testExample, Path: "/a"},
{From: testExample, Path: "/b"},
{From: testExample, Path: "/c"},
{From: testExample, Prefix: "/a/"},
},
expected: []string{"test-a", "test-b", "test-c", "test-a-prefix"},
},
{
name: "same domain, name-conflicting path options, duplicate names",
input: []*configpb.Route{
{From: testExample, Path: "/a"},
{From: testExample, Path: "/a/"},
},
expected: []string{"test-a", "test-a (2)"},
},
{
name: "same domain, name-conflicting prefix options, duplicate names",
input: []*configpb.Route{
{From: testExample, Prefix: "/a"},
{From: testExample, Prefix: "/a/"},
},
expected: []string{"test-a", "test-a (2)"},
},
{
name: "missing domain name",
input: []*configpb.Route{{From: "https://:1234"}},
expected: []string{"route-0"},
},
{
name: "invalid URL",
input: []*configpb.Route{{From: "https://\x7f"}},
expected: []string{"route-0"},
},
{
name: "regex paths",
input: []*configpb.Route{
{From: testExample, Regex: `/a/(.*)/b`},
{From: testExample, Regex: `/a/(foo|bar)/b`},
{From: testExample, Regex: `/(authorize.*|login|logout)`},
{From: testExample, Regex: `/foo.+=-())(*+=,;:@~!'''-+_/.*`},
{From: testExample, Regex: `/*`},
{From: testExample, Regex: `/other/(.*)`},
{From: testExample, Regex: `/other/.*`},
{From: testExample, Regex: `/other/([^/]+)`},
{From: testExample, Regex: `/other/([^/]*)`},
{From: testExample, Regex: `/other/([^\/]+)`},
{From: testExample, Regex: `/other/([^\/]*)`},
{From: testExample, Regex: `/other/[^/]+`},
{From: testExample, Regex: `/other/[^/]*`},
{From: testExample, Regex: `/other/[^\/]+`},
{From: testExample, Regex: `/other/[^\/]*`},
{From: testExample, Regex: `/foo/bar/baz/.*`},
{From: testExample, Regex: `/.*`},
{From: testExample, Regex: `/.*`},
{From: testExample, Regex: `/(.*)`},
{From: testExample, Regex: `/.+`},
{From: testExample, Regex: `/(.+)`},
{From: testExample, Regex: `/([^/]+)`},
{From: testExample, Regex: `/([^/]*)`},
{From: testExample, Regex: `/([^\/]+)`},
{From: testExample, Regex: `/([^\/]*)`},
{From: testExample, Regex: `/[^/]+`},
{From: testExample, Regex: `/[^/]*`},
{From: testExample, Regex: `/[^\/]+`},
{From: testExample, Regex: `/[^\/]*`},
{From: testExample, Regex: `.+`},
{From: testExample, Regex: `(.+)`},
{From: testExample, Regex: `([^/]+)`},
{From: testExample, Regex: `([^/]*)`},
{From: testExample, Regex: `([^\/]+)`},
{From: testExample, Regex: `([^\/]*)`},
{From: testExample, Regex: `[^/]+`},
{From: testExample, Regex: `[^/]*`},
{From: testExample, Regex: `[^\/]+`},
{From: testExample, Regex: `[^\/]*`},
{From: testExample, Regex: `\w+`},
{From: testExample, Regex: `\w*`},
{From: testExample, Regex: `/\w+`},
{From: testExample, Regex: `/\w*`},
{From: testExample, Regex: `/(\w+)`},
{From: testExample, Regex: `/(\w*)`},
{From: testExample, Regex: `foo/.*`},
{From: testExample, Regex: `/foo/.*`},
{From: testExample, Regex: `/foo/\w+`},
{From: testExample, Regex: `/foo/\w*`},
},
expected: slices.Collect(func(yield func(string) bool) {
yield("test-re-1")
yield("test-re-2")
yield("test-re-3")
yield("test-re-4")
yield("test-re-5")
yield("test-re-other-prefix")
for i := 2; i <= 10; i++ {
yield(fmt.Sprintf("test-re-other-prefix (%d)", i))
}
yield("test-re-foo-bar-baz-prefix")
yield("test-re-any")
for i := 2; i <= 29; i++ {
yield(fmt.Sprintf("test-re-any (%d)", i))
}
yield("test-re-foo-prefix")
yield("test-re-foo-prefix (2)")
yield("test-re-foo-prefix (3)")
yield("test-re-foo-prefix (4)")
}),
},
{
name: "duplicate routes",
input: []*configpb.Route{
{From: "https://route1.localhost.pomerium.io:8443"},
{From: "https://route1.localhost.pomerium.io:8443"},
{From: "https://route2.localhost.pomerium.io:8443"},
{From: "https://route3.localhost.pomerium.io:8443"},
{From: "https://route4.localhost.pomerium.io:8443"},
},
expected: []string{
"route1",
"route1 (2)",
"route2",
"route3",
"route4",
},
},
}
for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, importutil.GenerateRouteNames(tc.input))
})
}
}