pomerium/config/options_test.go
Kenneth Jenkins e1eca4e97c
config: fix jwt_issuer_format conversion (#5524)
Remove the previous conversion logic in NewPolicyFromProto() for the 
jwt_issuer_format field. This would prevent the new "unset" state from
working correctly. Add a unit test to verify that all three values
(unset, "hostOnly" and "uri") will successfully round trip to the proto
format and back again.

Also add a test case for the Options.ApplySettings() method to verify 
that an unset jwt_issuer_format will not overwrite the existing value
(if any) in the settings.
2025-03-12 16:13:16 -07:00

1764 lines
56 KiB
Go

package config
import (
"context"
"crypto"
"crypto/ed25519"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/base64"
"encoding/pem"
"fmt"
"hash/fnv"
"math/big"
mathrand "math/rand/v2"
"net/http"
"net/url"
"os"
"path/filepath"
"reflect"
"slices"
"strings"
"sync"
"testing"
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/spf13/viper"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/reflect/protoreflect"
"google.golang.org/protobuf/types/known/fieldmaskpb"
"github.com/pomerium/csrf"
"github.com/pomerium/pomerium/internal/testutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
"github.com/pomerium/protoutil/protorand"
)
var cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}, Policy{})
func Test_Validate(t *testing.T) {
t.Parallel()
testOptions := func() *Options {
o := NewDefaultOptions()
o.SharedKey = "test"
o.Services = "all"
o.CertFile = "./testdata/example-cert.pem"
o.KeyFile = "./testdata/example-key.pem"
return o
}
good := testOptions()
badServices := testOptions()
badServices.Services = "blue"
badSecret := testOptions()
badSecret.SharedKey = ""
badSecret.Services = "authenticate"
badSecretAllServices := testOptions()
badSecretAllServices.SharedKey = ""
badPolicyFile := testOptions()
badPolicyFile.PolicyFile = "file"
invalidStorageType := testOptions()
invalidStorageType.DataBrokerStorageType = "foo"
missingStorageDSN := testOptions()
missingStorageDSN.DataBrokerStorageType = "postgres"
badSignoutRedirectURL := testOptions()
badSignoutRedirectURL.SignOutRedirectURLString = "--"
badCookieSettings := testOptions()
badCookieSettings.CookieSameSite = "none"
tests := []struct {
name string
testOpts *Options
wantErr bool
}{
{"good default with no env settings", good, false},
{"invalid service type", badServices, true},
{"missing shared secret", badSecret, true},
{"missing shared secret but all service", badSecretAllServices, false},
{"policy file specified", badPolicyFile, true},
{"invalid databroker storage type", invalidStorageType, true},
{"missing databroker storage dsn", missingStorageDSN, true},
{"invalid signout redirect url", badSignoutRedirectURL, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := tt.testOpts.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_bindEnvs(t *testing.T) {
o := new(Options)
o.viper = viper.New()
v := viper.New()
os.Clearenv()
defer os.Unsetenv("POMERIUM_DEBUG")
defer os.Unsetenv("POLICY")
defer os.Unsetenv("HEADERS")
t.Setenv("POMERIUM_DEBUG", "true")
t.Setenv("POLICY", "LSBmcm9tOiBodHRwczovL2h0dHBiaW4ubG9jYWxob3N0LnBvbWVyaXVtLmlvCiAgdG86IAogICAgLSBodHRwOi8vbG9jYWxob3N0OjgwODEsMQo=")
t.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`)
err := bindEnvs(v)
if err != nil {
t.Fatalf("failed to bind options to env vars: %s", err)
}
err = v.Unmarshal(o, ViperPolicyHooks)
if err != nil {
t.Errorf("Could not unmarshal %#v: %s", o, err)
}
o.viper = v
if !o.Debug {
t.Errorf("Failed to load POMERIUM_DEBUG from environment")
}
if len(o.Policies) != 1 {
t.Error("failed to bind POLICY env")
}
if o.Services != "" {
t.Errorf("Somehow got SERVICES from environment without configuring it")
}
if o.HeadersEnv != `{"X-Custom-1":"foo", "X-Custom-2":"bar"}` {
t.Errorf("Failed to bind headers env var to HeadersEnv")
}
}
type Foo struct {
FieldOne Bar `mapstructure:"field_one"`
FieldTwo string `mapstructure:"field_two"`
}
type Bar struct {
Baz int `mapstructure:"baz"`
Quux string `mapstructure:"quux"`
}
func Test_bindEnvsRecursive(t *testing.T) {
v := viper.New()
_, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "")
require.NoError(t, err)
t.Setenv("FIELD_ONE_BAZ", "123")
t.Setenv("FIELD_ONE_QUUX", "hello")
t.Setenv("FIELD_TWO", "world")
var foo Foo
v.Unmarshal(&foo)
assert.Equal(t, Foo{
FieldOne: Bar{
Baz: 123,
Quux: "hello",
},
FieldTwo: "world",
}, foo)
}
func Test_bindEnvsRecursive_Override(t *testing.T) {
v := viper.New()
v.SetConfigType("yaml")
v.ReadConfig(strings.NewReader(`
field_one:
baz: 10
quux: abc
field_two: hello
`))
// Baseline: values populated from config file.
var foo1 Foo
v.Unmarshal(&foo1)
assert.Equal(t, Foo{
FieldOne: Bar{
Baz: 10,
Quux: "abc",
},
FieldTwo: "hello",
}, foo1)
_, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "")
require.NoError(t, err)
// Environment variables should selectively override config file keys.
t.Setenv("FIELD_ONE_QUUX", "def")
var foo2 Foo
v.Unmarshal(&foo2)
assert.Equal(t, Foo{
FieldOne: Bar{
Baz: 10,
Quux: "def",
},
FieldTwo: "hello",
}, foo2)
t.Setenv("FIELD_TWO", "world")
var foo3 Foo
v.Unmarshal(&foo3)
assert.Equal(t, Foo{
FieldOne: Bar{
Baz: 10,
Quux: "def",
},
FieldTwo: "world",
}, foo3)
}
func Test_parseHeaders(t *testing.T) {
// t.Parallel()
tests := []struct {
name string
want map[string]string
envHeaders string
viperHeaders any
wantErr bool
}{
{
"good env",
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
`{"X-Custom-1":"foo", "X-Custom-2":"bar"}`,
map[string]string{"X": "foo"},
false,
},
{
"good env not_json",
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
`X-Custom-1:foo,X-Custom-2:bar`,
map[string]string{"X": "foo"},
false,
},
{
"bad env",
map[string]string{},
"xyyyy",
map[string]string{"X": "foo"},
true,
},
{
"bad env not_json",
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
`X-Custom-1:foo,X-Custom-2bar`,
map[string]string{"X": "foo"},
true,
},
{
"bad viper",
map[string]string{},
"",
"notaheaderstruct",
true,
},
{
"good viper",
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
"",
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
false,
},
{
"new field name",
map[string]string{"X-Custom-1": "foo"},
"",
map[string]string{"X-Custom-1": "foo"},
false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var (
o *Options
mu sync.Mutex
)
mu.Lock()
defer mu.Unlock()
o = NewDefaultOptions()
o.viperSet("set_response_headers", tt.viperHeaders)
o.viperSet("HeadersEnv", tt.envHeaders)
o.HeadersEnv = tt.envHeaders
err := o.parseHeaders(context.Background())
if (err != nil) != tt.wantErr {
t.Errorf("Error condition unexpected: err=%s", err)
}
if !tt.wantErr && !cmp.Equal(tt.want, o.SetResponseHeaders) {
t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.SetResponseHeaders))
}
})
}
}
func Test_parsePolicyFile(t *testing.T) {
t.Parallel()
opts := []cmp.Option{
cmpopts.IgnoreFields(Policy{}, "EnvoyOpts"),
cmpOptIgnoreUnexported,
}
source := "https://pomerium.io"
to, err := ParseWeightedURL("https://httpbin.org")
require.NoError(t, err)
tests := []struct {
name string
policyBytes []byte
want []Policy
wantErr bool
}{
{
"simple json",
[]byte(fmt.Sprintf(`{"policy":[{"from": "%s","to":"%s"}]}`, source, to.URL.String())),
[]Policy{{
From: source,
To: []WeightedURL{*to},
}},
false,
},
{"bad from", []byte(`{"policy":[{"from": "%","to":"httpbin.org"}]}`), nil, true},
{"bad to", []byte(`{"policy":[{"from": "pomerium.io","to":"%"}]}`), nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempFile, _ := os.CreateTemp("", "*.json")
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.policyBytes)
var o Options
o.viper = viper.New()
o.viper.SetConfigFile(tempFile.Name())
if err := o.viper.ReadInConfig(); err != nil {
t.Fatal(err)
}
err := o.parsePolicy()
if (err != nil) != tt.wantErr {
t.Errorf("parsePolicyEnv() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err == nil {
if diff := cmp.Diff(o.Policies, tt.want, opts...); diff != "" {
t.Errorf("parsePolicyEnv() = diff:%s", diff)
}
}
})
}
}
func Test_decodeSANMatcher(t *testing.T) {
// Verify that config file parsing will decode the SANMatcher type.
const yaml = `
downstream_mtls:
match_subject_alt_names:
- dns: 'example-1\..*'
- dns: '.*\.example-2'
`
cfg := filepath.Join(t.TempDir(), "config.yaml")
err := os.WriteFile(cfg, []byte(yaml), 0o644)
require.NoError(t, err)
o, err := optionsFromViper(cfg)
require.NoError(t, err)
assert.Equal(t, []SANMatcher{
{Type: SANTypeDNS, Pattern: `example-1\..*`},
{Type: SANTypeDNS, Pattern: `.*\.example-2`},
}, o.DownstreamMTLS.MatchSubjectAltNames)
}
func Test_Checksum(t *testing.T) {
o := NewDefaultOptions()
oldChecksum := o.Checksum()
o.SharedKey = "changemeplease"
newChecksum := o.Checksum()
if newChecksum == oldChecksum {
t.Errorf("Checksum() failed to update old = %d, new = %d", oldChecksum, newChecksum)
}
if newChecksum == 0 || oldChecksum == 0 {
t.Error("Checksum() not returning data")
}
if o.Checksum() != newChecksum {
t.Error("Checksum() inconsistent")
}
}
func TestOptionsFromViper(t *testing.T) {
opts := []cmp.Option{
cmpopts.IgnoreFields(Options{}, "CookieSecret", "GRPCInsecure", "GRPCAddr", "DataBrokerURLString", "DataBrokerURLStrings", "AuthorizeURLString", "AuthorizeURLStrings", "DefaultUpstreamTimeout", "CookieExpire", "Services", "Addr", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "IdleTimeout", "GRPCClientTimeout", "ProgrammaticRedirectDomainWhitelist", "RuntimeFlags"),
cmpopts.IgnoreFields(Policy{}, "EnvoyOpts"),
cmpOptIgnoreUnexported,
}
tests := []struct {
name string
configBytes []byte
want *Options
wantErr bool
}{
{
"good",
[]byte(`{"autocert_dir":"","insecure_server":true,"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
CookieName: "_pomerium",
InsecureServer: true,
CookieHTTPOnly: true,
AuthenticateCallbackPath: "/oauth2/callback",
DataBrokerStorageType: "memory",
EnvoyAdminAccessLogPath: os.DevNull,
EnvoyAdminProfilePath: os.DevNull,
},
false,
},
{
"good disable header",
[]byte(`{"autocert_dir":"","insecure_server":true,"set_response_headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
CookieName: "_pomerium",
AuthenticateCallbackPath: "/oauth2/callback",
CookieHTTPOnly: true,
InsecureServer: true,
SetResponseHeaders: map[string]string{"disable": "true"},
DataBrokerStorageType: "memory",
EnvoyAdminAccessLogPath: os.DevNull,
EnvoyAdminProfilePath: os.DevNull,
},
false,
},
{
"good disable header",
[]byte(`{"autocert_dir":"","insecure_server":true,"set_response_headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
&Options{
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
CookieName: "_pomerium",
AuthenticateCallbackPath: "/oauth2/callback",
CookieHTTPOnly: true,
InsecureServer: true,
SetResponseHeaders: map[string]string{"disable": "true"},
DataBrokerStorageType: "memory",
EnvoyAdminAccessLogPath: os.DevNull,
EnvoyAdminProfilePath: os.DevNull,
},
false,
},
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
{"bad file", []byte(`{''''}`), nil, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tempFile, _ := os.CreateTemp("", "*.json")
defer tempFile.Close()
defer os.Remove(tempFile.Name())
tempFile.Write(tt.configBytes)
got, err := optionsFromViper(tempFile.Name())
if (err != nil) != tt.wantErr {
t.Errorf("optionsFromViper() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
t.Errorf("newOptionsFromConfig() = %s", diff)
}
})
}
}
func Test_NewOptionsFromConfigEnvVar(t *testing.T) {
tests := []struct {
name string
envKeyPairs map[string]string
wantErr bool
}{
{"good", map[string]string{"INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad no shared secret", map[string]string{"INSECURE_SERVER": "true", "SERVICES": "authenticate"}, true},
{"good no shared secret in all mode", map[string]string{"INSECURE_SERVER": "true"}, false},
{"bad header", map[string]string{"HEADERS": "x;y;z", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad authenticate url", map[string]string{"AUTHENTICATE_SERVICE_URL": "authenticate.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad authorize url", map[string]string{"AUTHORIZE_SERVICE_URL": "authorize.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad cert base64", map[string]string{"CERTIFICATE": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad cert key base64", map[string]string{"CERTIFICATE_KEY": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"no certs no insecure mode set", map[string]string{"SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"good disable headers ", map[string]string{"HEADERS": "disable:true", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad whitespace in secret", map[string]string{"INSECURE_SERVER": "true", "SERVICES": "authenticate", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=\n"}, true},
{"same addr and grpc addr", map[string]string{"SERVICES": "databroker", "ADDRESS": "0", "GRPC_ADDRESS": "0", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad cert files", map[string]string{"INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "CERTIFICATES": "./test-data/example-cert.pem"}, true},
{"good cert file", map[string]string{"CERTIFICATE_FILE": "./testdata/example-cert.pem", "CERTIFICATE_KEY_FILE": "./testdata/example-key.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad cert file", map[string]string{"CERTIFICATE_FILE": "./testdata/example-cert-bad.pem", "CERTIFICATE_KEY_FILE": "./testdata/example-key-bad.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"good client ca file", map[string]string{"DOWNSTREAM_MTLS_CA_FILE": "./testdata/ca.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
{"bad client ca file", map[string]string{"DOWNSTREAM_MTLS_CA_FILE": "./testdata/bad-ca.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
{"bad client ca b64", map[string]string{"DOWNSTREAM_MTLS_CA": "bad cert", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
for k, v := range tt.envKeyPairs {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
_, err := newOptionsFromConfig("")
if (err != nil) != tt.wantErr {
t.Errorf("newOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr)
return
}
})
}
}
func Test_AutoCertOptionsFromEnvVar(t *testing.T) {
type test struct {
envs map[string]string
expected AutocertOptions
wantErr bool
cleanup func()
}
tests := map[string]func(t *testing.T) test{
"ok/simple": func(_ *testing.T) test {
envs := map[string]string{
"AUTOCERT": "true",
"AUTOCERT_DIR": "/test",
"AUTOCERT_MUST_STAPLE": "true",
"INSECURE_SERVER": "true",
}
return test{
envs: envs,
expected: AutocertOptions{
Enable: true,
Folder: "/test",
MustStaple: true,
},
wantErr: false,
}
},
"ok/custom-ca": func(t *testing.T) test {
certPEM, err := newCACertPEM()
require.NoError(t, err)
envs := map[string]string{
"AUTOCERT": "true",
"AUTOCERT_CA": "test-ca.example.com/directory",
"AUTOCERT_EMAIL": "test@example.com",
"AUTOCERT_EAB_KEY_ID": "keyID",
"AUTOCERT_EAB_MAC_KEY": "fake-key",
"AUTOCERT_TRUSTED_CA": base64.StdEncoding.EncodeToString(certPEM),
"AUTOCERT_DIR": "/test",
"AUTOCERT_MUST_STAPLE": "true",
"INSECURE_SERVER": "true",
}
return test{
envs: envs,
wantErr: false,
expected: AutocertOptions{
Enable: true,
CA: "test-ca.example.com/directory",
Email: "test@example.com",
EABKeyID: "keyID",
EABMACKey: "fake-key",
TrustedCA: base64.StdEncoding.EncodeToString(certPEM),
Folder: "/test",
MustStaple: true,
},
}
},
"ok/custom-ca-file": func(t *testing.T) test {
certPEM, err := newCACertPEM()
require.NoError(t, err)
f, err := os.CreateTemp("", "pomerium-test-ca")
require.NoError(t, err)
n, err := f.Write(certPEM)
require.NoError(t, err)
require.Equal(t, len(certPEM), n)
envs := map[string]string{
"AUTOCERT": "true",
"AUTOCERT_CA": "test-ca.example.com/directory",
"AUTOCERT_EMAIL": "test@example.com",
"AUTOCERT_EAB_KEY_ID": "keyID",
"AUTOCERT_EAB_MAC_KEY": "fake-key",
"AUTOCERT_TRUSTED_CA_FILE": f.Name(),
"AUTOCERT_DIR": "/test",
"AUTOCERT_MUST_STAPLE": "true",
"INSECURE_SERVER": "true",
}
return test{
envs: envs,
wantErr: false,
expected: AutocertOptions{
Enable: true,
CA: "test-ca.example.com/directory",
Email: "test@example.com",
EABKeyID: "keyID",
EABMACKey: "fake-key",
TrustedCAFile: f.Name(),
Folder: "/test",
MustStaple: true,
},
cleanup: func() { os.Remove(f.Name()) },
}
},
}
for name, run := range tests {
tc := run(t)
t.Run(name, func(t *testing.T) {
for k, v := range tc.envs {
os.Setenv(k, v)
defer os.Unsetenv(k)
}
o, err := newOptionsFromConfig("")
if err != nil {
t.Fatal(err)
}
if !cmp.Equal(tc.expected, o.AutocertOptions) {
t.Errorf("AutoCertOptionsFromEnvVar() diff = %s", cmp.Diff(tc.expected, o.AutocertOptions))
}
if tc.cleanup != nil {
tc.cleanup()
}
})
}
}
func TestHTTPRedirectAddressStripQuotes(t *testing.T) {
o := NewDefaultOptions()
o.InsecureServer = true
o.HTTPRedirectAddr = `":80"`
assert.NoError(t, o.Validate())
assert.Equal(t, ":80", o.HTTPRedirectAddr)
}
func TestCertificatesArrayParsing(t *testing.T) {
t.Parallel()
testCertFileRef := "./testdata/example-cert.pem"
testKeyFileRef := "./testdata/example-key.pem"
tests := []struct {
name string
certificateFiles []certificateFilePair
wantErr bool
}{
{"Handles file reference as params", []certificateFilePair{{KeyFile: testKeyFileRef, CertFile: testCertFileRef}}, false},
{"Returns an error otherwise", []certificateFilePair{{KeyFile: "abc", CertFile: "abc"}}, true},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
o := NewDefaultOptions()
o.CertificateFiles = tt.certificateFiles
err := o.Validate()
if err != nil && tt.wantErr == false {
t.Fatal(err)
}
})
}
}
func TestCompareByteSliceSlice(t *testing.T) {
type Bytes = [][]byte
tests := []struct {
expect int
a Bytes
b Bytes
}{
{
0,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 3},
},
},
{
-1,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 4},
},
},
{
1,
Bytes{
{0, 1, 2, 4},
},
Bytes{
{0, 1, 2, 3},
},
},
{
-1,
Bytes{
{0, 1, 2, 3},
},
Bytes{
{0, 1, 2, 3},
{4, 5, 6, 7},
},
},
{
1,
Bytes{
{0, 1, 2, 3},
{4, 5, 6, 7},
},
Bytes{
{0, 1, 2, 3},
},
},
}
for _, tt := range tests {
actual := compareByteSliceSlice(tt.a, tt.b)
if tt.expect != actual {
t.Errorf("expected compare(%v, %v) to be %v but got %v",
tt.a, tt.b, tt.expect, actual)
}
}
}
func TestHasAnyDownstreamMTLSClientCA(t *testing.T) {
t.Parallel()
cases := []struct {
label string
opts *Options
expected bool
}{
{"zero", &Options{}, false},
{"default", NewDefaultOptions(), false},
{"no client CAs", &Options{
Policies: []Policy{
{From: "https://example.com/one"},
{From: "https://example.com/two"},
{From: "https://example.com/three"},
},
}, false},
{"global client CA only", &Options{
DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="},
Policies: []Policy{
{From: "https://example.com/one"},
{From: "https://example.com/two"},
{From: "https://example.com/three"},
},
}, true},
{"per-route CA only", &Options{
Policies: []Policy{
{From: "https://example.com/one"},
{
From: "https://example.com/two",
TLSDownstreamClientCA: "ZmFrZSBDQQ==",
},
{From: "https://example.com/three"},
},
}, true},
{"both global and per-route client CAs", &Options{
DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="},
Policies: []Policy{
{From: "https://example.com/one"},
{
From: "https://example.com/two",
TLSDownstreamClientCA: "ZmFrZSBDQQ==",
},
{From: "https://example.com/three"},
},
}, true},
}
for i := range cases {
c := &cases[i]
t.Run(c.label, func(t *testing.T) {
actual := c.opts.HasAnyDownstreamMTLSClientCA()
assert.Equal(t, c.expected, actual)
})
}
}
func TestOptions_DefaultURL(t *testing.T) {
t.Parallel()
firstURL := func(f func() ([]*url.URL, error)) func() (*url.URL, error) {
return func() (*url.URL, error) {
urls, err := f()
if err != nil {
return nil, err
} else if len(urls) == 0 {
return nil, fmt.Errorf("no url defined")
}
return urls[0], nil
}
}
defaultOptions := &Options{}
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
}
tests := []struct {
name string
f func() (*url.URL, error)
expectedURLStr string
}{
{"default authenticate url", defaultOptions.GetAuthenticateURL, "https://authenticate.pomerium.app"},
{"good authenticate url", opts.GetAuthenticateURL, "https://authenticate.example.com"},
{"good authorize url", firstURL(opts.GetAuthorizeURLs), "https://authorize.example.com"},
{"good databroker url", firstURL(opts.GetDataBrokerURLs), "https://databroker.example.com"},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
u, err := tc.f()
require.NoError(t, err)
assert.Equal(t, tc.expectedURLStr, u.String())
})
}
}
func TestOptions_UseStatelessAuthenticateFlow(t *testing.T) {
t.Run("enabled by default", func(t *testing.T) {
options := &Options{}
assert.True(t, options.UseStatelessAuthenticateFlow())
})
t.Run("enabled explicitly", func(t *testing.T) {
options := &Options{AuthenticateURLString: "https://authenticate.pomerium.app"}
assert.True(t, options.UseStatelessAuthenticateFlow())
})
t.Run("disabled", func(t *testing.T) {
options := &Options{AuthenticateURLString: "https://authenticate.example.com"}
assert.False(t, options.UseStatelessAuthenticateFlow())
})
t.Run("force enabled", func(t *testing.T) {
options := &Options{AuthenticateURLString: "https://authenticate.example.com"}
t.Setenv("DEBUG_FORCE_AUTHENTICATE_FLOW", "stateless")
assert.True(t, options.UseStatelessAuthenticateFlow())
})
t.Run("force disabled", func(t *testing.T) {
options := &Options{}
t.Setenv("DEBUG_FORCE_AUTHENTICATE_FLOW", "stateful")
assert.False(t, options.UseStatelessAuthenticateFlow())
})
}
func TestOptions_GetOauthOptions(t *testing.T) {
opts := &Options{AuthenticateURLString: "https://authenticate.example.com"}
oauthOptions, err := opts.GetOauthOptions()
require.NoError(t, err)
// Test that oauth redirect url hostname must point to authenticate url hostname.
u, err := opts.GetAuthenticateURL()
require.NoError(t, err)
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
}
func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Services: "all",
}
hosts, err := opts.GetAllRouteableGRPCHosts()
assert.NoError(t, err)
assert.Equal(t, []string{
"authorize.example.com",
"authorize.example.com:443",
"databroker.example.com",
"databroker.example.com:443",
}, hosts)
}
func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
p1 := Policy{From: "https://from1.example.com"}
p1.Validate()
p2 := Policy{From: "https://from2.example.com"}
p2.Validate()
p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com"}
p3.Validate()
opts := &Options{
AuthenticateURLString: "https://authenticate.example.com",
AuthorizeURLString: "https://authorize.example.com",
DataBrokerURLString: "https://databroker.example.com",
Policies: []Policy{p1, p2, p3},
Services: "all",
}
hosts, err := opts.GetAllRouteableHTTPHosts()
assert.NoError(t, err)
assert.Equal(t, []string{
"authenticate.example.com",
"authenticate.example.com:443",
"from.example.com",
"from.example.com:443",
"from1.example.com",
"from1.example.com:443",
"from2.example.com",
"from2.example.com:443",
"from3.example.com",
"from3.example.com:443",
}, hosts)
}
func TestOptions_ApplySettings(t *testing.T) {
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second)
defer clearTimeout()
t.Run("certificates", func(t *testing.T) {
options := NewDefaultOptions()
cert1, err := cryptutil.GenerateCertificate(nil, "example.com")
require.NoError(t, err)
cert1path := filepath.Join(t.TempDir(), "example.com.pem")
err = os.WriteFile(cert1path, cert1.Certificate[0], 0o600)
require.NoError(t, err)
options.CertificateFiles = append(options.CertificateFiles, certificateFilePair{
CertFile: cert1path,
})
cert2, err := cryptutil.GenerateCertificate(nil, "example.com")
require.NoError(t, err)
cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com")
require.NoError(t, err)
certsIndex := cryptutil.NewCertificatesIndex()
xc1, _ := x509.ParseCertificate(cert1.Certificate[0])
certsIndex.Add(xc1)
settings := &configpb.Settings{
Certificates: []*configpb.Settings_Certificate{
{CertBytes: encodeCert(cert2)},
{CertBytes: encodeCert(cert3)},
},
}
options.ApplySettings(ctx, certsIndex, settings)
assert.Len(t, options.CertificateData, 1, "should prevent adding duplicate certificates")
})
t.Run("pass_identity_headers", func(t *testing.T) {
options := NewDefaultOptions()
options.ApplySettings(ctx, nil, &configpb.Settings{
PassIdentityHeaders: proto.Bool(true),
})
assert.Equal(t, proto.Bool(true), options.PassIdentityHeaders)
})
t.Run("branding", func(t *testing.T) {
options := NewDefaultOptions()
options.ApplySettings(ctx, nil, &configpb.Settings{
PrimaryColor: proto.String("#FFFFFF"),
})
options.ApplySettings(ctx, nil, &configpb.Settings{})
assert.Equal(t, "#FFFFFF", options.BrandingOptions.GetPrimaryColor())
options.ApplySettings(ctx, nil, &configpb.Settings{
PrimaryColor: proto.String("#333333"),
})
assert.Equal(t, "#333333", options.BrandingOptions.GetPrimaryColor())
})
t.Run("jwt_groups_filter", func(t *testing.T) {
options := NewDefaultOptions()
options.ApplySettings(ctx, nil, &configpb.Settings{
JwtGroupsFilter: []string{"foo", "bar", "baz"},
})
options.ApplySettings(ctx, nil, &configpb.Settings{})
assert.Equal(t, NewJWTGroupsFilter([]string{"foo", "bar", "baz"}), options.JWTGroupsFilter)
options.ApplySettings(ctx, nil, &configpb.Settings{
JwtGroupsFilter: []string{"quux", "zulu"},
})
assert.Equal(t, NewJWTGroupsFilter([]string{"quux", "zulu"}), options.JWTGroupsFilter)
})
t.Run("jwt_issuer_format", func(t *testing.T) {
options := NewDefaultOptions()
assert.Equal(t, JWTIssuerFormatUnset, options.JWTIssuerFormat)
options.ApplySettings(ctx, nil, &configpb.Settings{
JwtIssuerFormat: configpb.IssuerFormat_IssuerURI.Enum(),
})
options.ApplySettings(ctx, nil, &configpb.Settings{})
assert.Equal(t, JWTIssuerFormatURI, options.JWTIssuerFormat)
options.ApplySettings(ctx, nil, &configpb.Settings{
JwtIssuerFormat: configpb.IssuerFormat_IssuerHostOnly.Enum(),
})
assert.Equal(t, JWTIssuerFormatHostOnly, options.JWTIssuerFormat)
})
}
func TestOptions_GetSetResponseHeaders(t *testing.T) {
t.Run("lax", func(t *testing.T) {
options := NewDefaultOptions()
assert.Equal(t, map[string]string{
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeaders())
})
t.Run("strict", func(t *testing.T) {
options := NewDefaultOptions()
options.Cert = "CERT"
assert.Equal(t, map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeaders())
})
t.Run("autocert-staging", func(t *testing.T) {
options := NewDefaultOptions()
options.Cert = "CERT"
options.AutocertOptions.UseStaging = true
assert.Equal(t, map[string]string{
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeaders())
})
t.Run("disable", func(t *testing.T) {
options := NewDefaultOptions()
options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1", "x-other": "xyz"}
assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders())
})
t.Run("empty", func(t *testing.T) {
options := NewDefaultOptions()
options.SetResponseHeaders = map[string]string{}
assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders())
})
t.Run("no partial defaults", func(t *testing.T) {
options := NewDefaultOptions()
options.Cert = "CERT"
options.SetResponseHeaders = map[string]string{"X-Frame-Options": "DENY"}
assert.Equal(t, map[string]string{"X-Frame-Options": "DENY"},
options.GetSetResponseHeaders())
})
}
func TestOptions_GetSetResponseHeadersForPolicy(t *testing.T) {
t.Run("disable but set in policy", func(t *testing.T) {
options := NewDefaultOptions()
options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1"}
policy := &Policy{
SetResponseHeaders: map[string]string{"x": "y"},
}
assert.Equal(t, map[string]string{"x": "y"}, options.GetSetResponseHeadersForPolicy(policy))
})
t.Run("global defaults plus policy", func(t *testing.T) {
options := NewDefaultOptions()
options.Cert = "CERT"
policy := &Policy{
SetResponseHeaders: map[string]string{"Route": "xyz"},
}
assert.Equal(t, map[string]string{
"Route": "xyz",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeadersForPolicy(policy))
})
t.Run("global defaults partial override", func(t *testing.T) {
options := NewDefaultOptions()
options.Cert = "CERT"
policy := &Policy{
SetResponseHeaders: map[string]string{"X-Frame-Options": "DENY"},
}
assert.Equal(t, map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Frame-Options": "DENY",
"X-XSS-Protection": "1; mode=block",
}, options.GetSetResponseHeadersForPolicy(policy))
})
t.Run("multiple policies", func(t *testing.T) {
options := NewDefaultOptions()
options.SetResponseHeaders = map[string]string{"global": "foo"}
p1 := &Policy{
SetResponseHeaders: map[string]string{"route-1": "bar"},
}
p2 := &Policy{
SetResponseHeaders: map[string]string{"route-2": "baz"},
}
assert.Equal(t, map[string]string{
"global": "foo",
"route-1": "bar",
}, options.GetSetResponseHeadersForPolicy(p1))
assert.Equal(t, map[string]string{
"global": "foo",
"route-2": "baz",
}, options.GetSetResponseHeadersForPolicy(p2))
assert.Equal(t, map[string]string{"global": "foo"}, options.GetSetResponseHeaders())
})
}
func TestOptions_GetSharedKey(t *testing.T) {
t.Run("default", func(t *testing.T) {
o := NewDefaultOptions()
bs, err := o.GetSharedKey()
assert.NoError(t, err)
assert.Equal(t, randomSharedKey, base64.StdEncoding.EncodeToString(bs))
})
t.Run("missing", func(t *testing.T) {
o := NewDefaultOptions()
o.Services = ServiceProxy
_, err := o.GetSharedKey()
assert.Error(t, err)
})
}
func TestOptions_GetSigningKey(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
name string
input string
output []byte
err error
}{
{"missing", "", []byte{}, nil},
{"pem", `
-----BEGIN EC PRIVATE KEY-----
MHQCAQEEIGGh6FlBe8yy9dRJgm+35lj3naGFtDODOf6leCW1bRGwoAcGBSuBBAAK
oUQDQgAE7UlKcFatc9m3GinCrhhT2oRQZ/bEwS98iEUXr0DR8GdxH3e4fhnicsNB
jHOCur7NYTgf5VaPJwIqLGBmTwM0ew==
-----END EC PRIVATE KEY-----
-----BEGIN EC PRIVATE KEY-----
MHQCAQEEIBo4wSjkFqQrzf2APNnPol8EDZzkhpcMSaEWXg8iOkbOoAcGBSuBBAAK
oUQDQgAEr+bGqssRv8RxPV2jJbDpMw81AVXr5+Q2pIF4u6xD9r56lst8uHYThPsw
ypaqswFIkSzQSW8awdWJ5d+1DEJRUQ==
-----END EC PRIVATE KEY-----
`, []byte{
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65,
0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e,
0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62,
0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74,
0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51,
0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52,
0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42,
0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61,
0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d,
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a,
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46,
0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44,
0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f,
0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52,
0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31,
0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44,
0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77,
0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38,
0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d,
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d,
}, nil},
{"base64", `
LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IUUNBUUVFSUdHaDZGbEJlOHl5OWRSSmdtKzM1bGozbmFHRnRET0RPZjZsZUNXMWJSR3dvQWNHQlN1QkJBQUsKb1VRRFFnQUU3VWxLY0ZhdGM5bTNHaW5DcmhoVDJvUlFaL2JFd1M5OGlFVVhyMERSOEdkeEgzZTRmaG5pY3NOQgpqSE9DdXI3TllUZ2Y1VmFQSndJcUxHQm1Ud00wZXc9PQotLS0tLUVORCBFQyBQUklWQVRFIEtFWS0tLS0tCgotLS0tLUJFR0lOIEVDIFBSSVZBVEUgS0VZLS0tLS0KTUhRQ0FRRUVJQm80d1Nqa0ZxUXJ6ZjJBUE5uUG9sOEVEWnpraHBjTVNhRVdYZzhpT2tiT29BY0dCU3VCQkFBSwpvVVFEUWdBRXIrYkdxc3NSdjhSeFBWMmpKYkRwTXc4MUFWWHI1K1EycElGNHU2eEQ5cjU2bHN0OHVIWVRoUHN3CnlwYXFzd0ZJa1N6UVNXOGF3ZFdKNWQrMURFSlJVUT09Ci0tLS0tRU5EIEVDIFBSSVZBVEUgS0VZLS0tLS0=
`, []byte{
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65,
0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e,
0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62,
0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74,
0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51,
0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52,
0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42,
0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61,
0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d,
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a,
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46,
0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44,
0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f,
0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52,
0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31,
0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44,
0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77,
0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38,
0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d,
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d,
}, nil},
} {
tc := tc
t.Run(tc.name, func(t *testing.T) {
t.Parallel()
output, err := (&Options{SigningKey: tc.input}).GetSigningKey()
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.output, output)
dir := t.TempDir()
err = os.WriteFile(filepath.Join(dir, "cert"), []byte(tc.input), 0o0666)
assert.NoError(t, err)
output, err = (&Options{SigningKeyFile: filepath.Join(dir, "cert")}).GetSigningKey()
assert.Equal(t, tc.err, err)
assert.Equal(t, tc.output, output)
})
}
}
func TestOptions_GetCookieSecret(t *testing.T) {
t.Run("default", func(t *testing.T) {
o := NewDefaultOptions()
bs, err := o.GetCookieSecret()
assert.NoError(t, err)
assert.Equal(t, randomSharedKey, base64.StdEncoding.EncodeToString(bs))
})
t.Run("missing", func(t *testing.T) {
o := NewDefaultOptions()
o.Services = ServiceProxy
_, err := o.GetCookieSecret()
assert.Error(t, err)
})
}
func TestOptions_GetCookieSameSite(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
input string
expected http.SameSite
}{
{"", http.SameSiteDefaultMode},
{"Lax", http.SameSiteLaxMode},
{"lax", http.SameSiteLaxMode},
{"Strict", http.SameSiteStrictMode},
{"strict", http.SameSiteStrictMode},
{"None", http.SameSiteNoneMode},
{"none", http.SameSiteNoneMode},
{"UnKnOwN", http.SameSiteDefaultMode},
} {
tc := tc
t.Run(tc.input, func(t *testing.T) {
t.Parallel()
o := NewDefaultOptions()
o.CookieSameSite = tc.input
assert.Equal(t, tc.expected, o.GetCookieSameSite())
})
}
}
func TestOptions_GetCSRFSameSite(t *testing.T) {
t.Parallel()
for _, tc := range []struct {
cookieSameSite string
provider string
expected csrf.SameSiteMode
}{
{"", "", csrf.SameSiteDefaultMode},
{"Lax", "", csrf.SameSiteLaxMode},
{"lax", "", csrf.SameSiteLaxMode},
{"Strict", "", csrf.SameSiteStrictMode},
{"strict", "", csrf.SameSiteStrictMode},
{"None", "", csrf.SameSiteNoneMode},
{"none", "", csrf.SameSiteNoneMode},
{"UnKnOwN", "", csrf.SameSiteDefaultMode},
{"", apple.Name, csrf.SameSiteNoneMode},
} {
tc := tc
t.Run(tc.cookieSameSite, func(t *testing.T) {
t.Parallel()
o := NewDefaultOptions()
o.CookieSameSite = tc.cookieSameSite
o.Provider = tc.provider
assert.Equal(t, tc.expected, o.GetCSRFSameSite())
})
}
}
func TestOptions_RequestParams(t *testing.T) {
cases := []struct {
label string
config string
expected map[string]string
}{
{"not present", "", nil},
{"explicitly empty", "idp_request_params: {}", map[string]string{}},
}
cfg := filepath.Join(t.TempDir(), "config.yaml")
for i := range cases {
c := &cases[i]
t.Run(c.label, func(t *testing.T) {
err := os.WriteFile(cfg, []byte(c.config), 0o644)
require.NoError(t, err)
o, err := newOptionsFromConfig(cfg)
require.NoError(t, err)
assert.Equal(t, c.expected, o.RequestParams)
})
}
}
func TestOptions_RequestParamsFromEnv(t *testing.T) {
t.Setenv("IDP_REQUEST_PARAMS", `{"x":"y"}`)
options, err := newOptionsFromConfig("")
if assert.NoError(t, err) {
assert.Equal(t, map[string]string{"x": "y"}, options.RequestParams)
}
}
func TestOptions_RuntimeFlags(t *testing.T) {
t.Parallel()
extra := DefaultRuntimeFlags()
extra["another"] = true
cases := []struct {
label string
config string
expected RuntimeFlags
}{
{"not present", "", DefaultRuntimeFlags()},
{"explicitly empty", `{"runtime_flags": {}}`, DefaultRuntimeFlags()},
{"all", `{"runtime_flags":{"another":true}}`, extra},
}
cfg := filepath.Join(t.TempDir(), "config.yaml")
for _, c := range cases {
t.Run(c.label, func(t *testing.T) {
err := os.WriteFile(cfg, []byte(c.config), 0o644)
require.NoError(t, err)
o, err := newOptionsFromConfig(cfg)
require.NoError(t, err)
assert.Equal(t, c.expected, o.RuntimeFlags)
})
}
}
func TestOptions_GetDataBrokerStorageConnectionString(t *testing.T) {
t.Parallel()
t.Run("validate", func(t *testing.T) {
t.Parallel()
o := NewDefaultOptions()
o.Services = "databroker"
o.DataBrokerStorageType = "postgres"
o.SharedKey = cryptutil.NewBase64Key()
assert.ErrorContains(t, o.Validate(), "missing databroker storage backend dsn",
"should validate DSN")
o.DataBrokerStorageConnectionString = "DSN"
assert.NoError(t, o.Validate(),
"should have no error when the dsn is set")
o.DataBrokerStorageConnectionString = ""
o.DataBrokerStorageConnectionStringFile = "DSN_FILE"
assert.NoError(t, o.Validate(),
"should have no error when the dsn file is set")
})
t.Run("literal", func(t *testing.T) {
t.Parallel()
o := NewDefaultOptions()
o.DataBrokerStorageConnectionString = "DSN"
dsn, err := o.GetDataBrokerStorageConnectionString()
assert.NoError(t, err)
assert.Equal(t, "DSN", dsn)
})
t.Run("file", func(t *testing.T) {
t.Parallel()
dir := t.TempDir()
fp := filepath.Join(dir, "DSN_FILE")
o := NewDefaultOptions()
o.DataBrokerStorageConnectionStringFile = fp
o.DataBrokerStorageConnectionString = "IGNORED"
dsn, err := o.GetDataBrokerStorageConnectionString()
assert.Error(t, err,
"should return an error when the file doesn't exist")
assert.Empty(t, dsn)
os.WriteFile(fp, []byte(`
DSN
`), 0o644)
dsn, err = o.GetDataBrokerStorageConnectionString()
assert.NoError(t, err,
"should not return an error when the file exists")
assert.Equal(t, "DSN", dsn,
"should return the trimmed contents of the file")
})
}
func encodeCert(cert *tls.Certificate) []byte {
return pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert.Certificate[0]})
}
func TestRoute_FromToProto(t *testing.T) {
routeGen := protorand.New[*configpb.Route]()
routeGen.MaxCollectionElements = 2
routeGen.UseGoDurationLimits = true
routeGen.ExcludeMask(&fieldmaskpb.FieldMask{
Paths: []string{
"from", "to", "load_balancing_weights", "redirect", "response", // set below
"ppl_policies", "name", // no equivalent field
"envoy_opts",
},
})
redirectGen := protorand.New[*configpb.RouteRedirect]()
responseGen := protorand.New[*configpb.RouteDirectResponse]()
randomDomain := func() string {
numSegments := mathrand.IntN(5) + 1
segments := make([]string, numSegments)
for i := range segments {
b := make([]rune, mathrand.IntN(10)+10)
for j := range b {
b[j] = rune(mathrand.IntN(26) + 'a')
}
segments[i] = string(b)
}
return strings.Join(segments, ".")
}
newCompleteRoute := func() *configpb.Route {
pb, err := routeGen.Gen()
require.NoError(t, err)
pb.From = "https://" + randomDomain()
// EnvoyOpts is set to an empty non-nil message during conversion, if nil
pb.EnvoyOpts = &envoy_config_cluster_v3.Cluster{}
// JWT groups filter order is not significant. Upon conversion back to
// a protobuf the JWT groups will be sorted.
slices.Sort(pb.JwtGroupsFilter)
switch mathrand.IntN(3) {
case 0:
pb.To = make([]string, mathrand.IntN(3)+1)
for i := range pb.To {
pb.To[i] = "https://" + randomDomain()
}
pb.LoadBalancingWeights = make([]uint32, len(pb.To))
for i := range pb.LoadBalancingWeights {
pb.LoadBalancingWeights[i] = mathrand.Uint32N(10000) + 1
}
case 1:
pb.Redirect, err = redirectGen.Gen()
require.NoError(t, err)
case 2:
pb.Response, err = responseGen.Gen()
require.NoError(t, err)
}
return pb
}
t.Run("Round Trip", func(t *testing.T) {
for range 100 {
route := newCompleteRoute()
policy, err := NewPolicyFromProto(route)
require.NoError(t, err)
route2, err := policy.ToProto()
require.NoError(t, err)
route2.Name = ""
testutil.AssertProtoEqual(t, route, route2)
}
})
t.Run("Multiple routes", func(t *testing.T) {
for range 100 {
route1 := newCompleteRoute()
route2 := newCompleteRoute()
{
// create a new policy every time, since reusing the target will mutate
// the underlying route
policy1, err := NewPolicyFromProto(route1)
require.NoError(t, err)
target, err := policy1.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route1, target)
}
{
policy2, err := NewPolicyFromProto(route2)
require.NoError(t, err)
target, err := policy2.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route2, target)
}
{
policy1, err := NewPolicyFromProto(route1)
require.NoError(t, err)
target, err := policy1.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route1, target)
}
{
policy2, err := NewPolicyFromProto(route2)
require.NoError(t, err)
target, err := policy2.ToProto()
require.NoError(t, err)
target.Name = ""
testutil.AssertProtoEqual(t, route2, target)
}
}
})
}
func TestOptions_FromToProto(t *testing.T) {
generate := func(ratio float64) *configpb.Settings {
t.Helper()
gen := protorand.New[*configpb.Settings]()
gen.MaxCollectionElements = 2
gen.MaxDepth = 3
gen.UseGoDurationLimits = true
gen.ExcludeMask(&fieldmaskpb.FieldMask{
Paths: []string{
"tls_custom_ca_file",
"tls_client_cert_file",
"tls_client_key_file",
"tls_downstream_client_ca_file",
},
})
settings, err := gen.GenPartial(ratio)
require.NoError(t, err)
unsetFalseOptionalBoolFields(settings)
fixZeroValuedEnums(settings)
generateCertificates(t, settings)
// JWT groups filter order is not significant. Upon conversion back to
// a protobuf the JWT groups will be sorted.
slices.Sort(settings.JwtGroupsFilter)
return settings
}
t.Run("all fields", func(t *testing.T) {
t.Parallel()
for range 100 {
settings := generate(1)
var options Options
options.ApplySettings(context.Background(), nil, settings)
settings2 := options.ToProto()
testutil.AssertProtoEqual(t, settings, settings2.Settings)
}
})
t.Run("some fields", func(t *testing.T) {
t.Parallel()
for range 100 {
settings := generate(mathrand.Float64())
var options Options
options.ApplySettings(context.Background(), nil, settings)
settings2 := options.ToProto()
testutil.AssertProtoEqual(t, settings, settings2.Settings)
}
})
}
// unset any optional bool fields with a value of false, to match
func unsetFalseOptionalBoolFields(msg proto.Message) {
msg.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
if fd.Cardinality() == protoreflect.Optional && fd.Kind() == protoreflect.BoolKind {
if v.IsValid() && !v.Bool() {
msg.ProtoReflect().Clear(fd)
}
}
return true
})
}
func fixZeroValuedEnums(msg *configpb.Settings) {
if msg.DownstreamMtls != nil && msg.DownstreamMtls.Enforcement != nil {
// there is no "unknown" equivalent, so if the value is randomly set to
// unknown it would be a lossy conversion
if *msg.DownstreamMtls.Enforcement == configpb.MtlsEnforcementMode_UNKNOWN {
msg.DownstreamMtls.Enforcement = nil
// if this was the only present field in the message, don't leave it empty
if proto.Size(msg.DownstreamMtls) == 0 {
msg.DownstreamMtls = nil
}
}
}
}
func generateCertificates(t testing.TB, msg *configpb.Settings) {
if msg.AutocertCa != nil {
*msg.AutocertCa, _ = generateRandomCA(t, *msg.AutocertCa)
}
if msg.DownstreamMtls != nil {
var caKey string
if msg.DownstreamMtls.Ca != nil {
*msg.DownstreamMtls.Ca, caKey = generateRandomCA(t, *msg.DownstreamMtls.Ca)
}
if msg.DownstreamMtls.Crl != nil {
if caKey != "" {
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, *msg.DownstreamMtls.Ca, caKey)
} else {
randCa, randKey := generateRandomCA(t, *msg.DownstreamMtls.Crl+"_temp_ca")
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, randCa, randKey)
}
}
}
genCertInPlace := func(cert *configpb.Settings_Certificate, b64 bool) {
cert.Id = "" // no equivalent field
switch {
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) > 0:
crt, key := generateRandomCert(t, string(cert.CertBytes)+string(cert.KeyBytes), b64)
cert.CertBytes = []byte(crt)
cert.KeyBytes = []byte(key)
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) == 0:
crt, _ := generateRandomCert(t, string(cert.CertBytes), b64)
cert.CertBytes = []byte(crt)
case len(cert.CertBytes) == 0 && len(cert.KeyBytes) > 0:
// invalid, but convert anyway
crt, _ := generateRandomCert(t, string(cert.KeyBytes), b64)
cert.KeyBytes = []byte(crt)
}
}
for i, cert := range msg.Certificates {
genCertInPlace(cert, false)
if cert.CertBytes == nil && cert.KeyBytes == nil {
msg.Certificates = slices.Delete(msg.Certificates, i, i+1)
}
}
if msg.MetricsCertificate != nil {
genCertInPlace(msg.MetricsCertificate, false)
if msg.MetricsCertificate.CertBytes == nil && msg.MetricsCertificate.KeyBytes == nil {
msg.MetricsCertificate = nil
}
}
}
func generateRandomCA(t testing.TB, randomInput string) (string, string) {
seed := sha256.Sum256([]byte(randomInput))
priv := ed25519.NewKeyFromSeed(seed[:])
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
now := time.Now()
tmpl := &x509.Certificate{
IsCA: true,
SerialNumber: &sn,
Subject: pkix.Name{CommonName: randomInput},
Issuer: pkix.Name{CommonName: randomInput},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCRLSign,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
x509.ExtKeyUsageServerAuth,
},
BasicConstraintsValid: true,
}
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: der,
})), base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
}))
}
func generateCRL(t testing.TB, randomInput string, issuerCrt, issuerKey string) string {
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
issuer, err := cryptutil.CertificateFromBase64(issuerCrt, issuerKey)
require.NoError(t, err)
b, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{
Number: big.NewInt(0x2000),
RevokedCertificates: []pkix.RevokedCertificate{
{
SerialNumber: &sn,
RevocationTime: time.Now(),
},
},
}, issuer.Leaf, issuer.PrivateKey.(crypto.Signer))
require.NoError(t, err)
return base64.StdEncoding.EncodeToString(b)
}
func generateRandomCert(t testing.TB, randomInput string, b64 bool) (string, string) {
seed := sha256.Sum256([]byte(randomInput))
priv := ed25519.NewKeyFromSeed(seed[:])
h := fnv.New128()
h.Write([]byte(randomInput))
sum := h.Sum(nil)
var sn big.Int
sn.SetBytes(sum)
now := time.Now()
tmpl := &x509.Certificate{
SerialNumber: &sn,
Subject: pkix.Name{
CommonName: randomInput,
},
Issuer: pkix.Name{
CommonName: randomInput,
},
NotBefore: now,
NotAfter: now.Add(12 * time.Hour),
KeyUsage: x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageClientAuth,
},
BasicConstraintsValid: true,
}
certDer, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
require.NoError(t, err)
crtPem := pem.EncodeToMemory(&pem.Block{
Type: "CERTIFICATE",
Bytes: certDer,
})
keyPem := pem.EncodeToMemory(&pem.Block{
Type: "PRIVATE KEY",
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
})
if b64 {
return base64.StdEncoding.EncodeToString(crtPem), base64.StdEncoding.EncodeToString(keyPem)
}
return string(crtPem), string(keyPem)
}
func must[T any](t T, err error) T {
if err != nil {
panic(err)
}
return t
}