mirror of
https://github.com/pomerium/pomerium.git
synced 2025-04-28 09:56:31 +02:00
1797 lines
57 KiB
Go
1797 lines
57 KiB
Go
package config
|
|
|
|
import (
|
|
"context"
|
|
"crypto"
|
|
"crypto/ed25519"
|
|
"crypto/rand"
|
|
"crypto/sha256"
|
|
"crypto/tls"
|
|
"crypto/x509"
|
|
"crypto/x509/pkix"
|
|
"encoding/base64"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"hash/fnv"
|
|
"math/big"
|
|
mathrand "math/rand/v2"
|
|
"net/http"
|
|
"net/url"
|
|
"os"
|
|
"path/filepath"
|
|
"reflect"
|
|
"slices"
|
|
"strings"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
|
|
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
|
|
"github.com/google/go-cmp/cmp"
|
|
"github.com/google/go-cmp/cmp/cmpopts"
|
|
"github.com/spf13/viper"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
"google.golang.org/protobuf/proto"
|
|
"google.golang.org/protobuf/reflect/protoreflect"
|
|
"google.golang.org/protobuf/types/known/fieldmaskpb"
|
|
|
|
"github.com/pomerium/csrf"
|
|
"github.com/pomerium/pomerium/internal/testutil"
|
|
"github.com/pomerium/pomerium/pkg/cryptutil"
|
|
configpb "github.com/pomerium/pomerium/pkg/grpc/config"
|
|
"github.com/pomerium/pomerium/pkg/identity/oauth/apple"
|
|
"github.com/pomerium/protoutil/protorand"
|
|
)
|
|
|
|
var (
	// cmpOptIgnoreUnexported makes go-cmp skip the unexported fields of
	// Options and Policy, so comparisons only look at exported configuration.
	cmpOptIgnoreUnexported = cmpopts.IgnoreUnexported(Options{}, Policy{})
)
|
|
|
|
func Test_Validate(t *testing.T) {
|
|
t.Parallel()
|
|
testOptions := func() *Options {
|
|
o := NewDefaultOptions()
|
|
|
|
o.SharedKey = "test"
|
|
o.Services = "all"
|
|
o.CertFile = "./testdata/example-cert.pem"
|
|
o.KeyFile = "./testdata/example-key.pem"
|
|
return o
|
|
}
|
|
good := testOptions()
|
|
badServices := testOptions()
|
|
badServices.Services = "blue"
|
|
badSecret := testOptions()
|
|
badSecret.SharedKey = ""
|
|
badSecret.Services = "authenticate"
|
|
badSecretAllServices := testOptions()
|
|
badSecretAllServices.SharedKey = ""
|
|
|
|
badPolicyFile := testOptions()
|
|
badPolicyFile.PolicyFile = "file"
|
|
invalidStorageType := testOptions()
|
|
invalidStorageType.DataBrokerStorageType = "foo"
|
|
missingStorageDSN := testOptions()
|
|
missingStorageDSN.DataBrokerStorageType = "postgres"
|
|
badSignoutRedirectURL := testOptions()
|
|
badSignoutRedirectURL.SignOutRedirectURLString = "--"
|
|
badCookieSettings := testOptions()
|
|
badCookieSettings.CookieSameSite = "none"
|
|
|
|
tests := []struct {
|
|
name string
|
|
testOpts *Options
|
|
wantErr bool
|
|
}{
|
|
{"good default with no env settings", good, false},
|
|
{"invalid service type", badServices, true},
|
|
{"missing shared secret", badSecret, true},
|
|
{"missing shared secret but all service", badSecretAllServices, false},
|
|
{"policy file specified", badPolicyFile, true},
|
|
{"invalid databroker storage type", invalidStorageType, true},
|
|
{"missing databroker storage dsn", missingStorageDSN, true},
|
|
{"invalid signout redirect url", badSignoutRedirectURL, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := tt.testOpts.Validate()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_bindEnvs(t *testing.T) {
|
|
o := new(Options)
|
|
o.viper = viper.New()
|
|
v := viper.New()
|
|
os.Clearenv()
|
|
defer os.Unsetenv("POMERIUM_DEBUG")
|
|
defer os.Unsetenv("POLICY")
|
|
defer os.Unsetenv("HEADERS")
|
|
t.Setenv("POMERIUM_DEBUG", "true")
|
|
t.Setenv("POLICY", "LSBmcm9tOiBodHRwczovL2h0dHBiaW4ubG9jYWxob3N0LnBvbWVyaXVtLmlvCiAgdG86IAogICAgLSBodHRwOi8vbG9jYWxob3N0OjgwODEsMQo=")
|
|
t.Setenv("HEADERS", `{"X-Custom-1":"foo", "X-Custom-2":"bar"}`)
|
|
err := bindEnvs(v)
|
|
if err != nil {
|
|
t.Fatalf("failed to bind options to env vars: %s", err)
|
|
}
|
|
err = v.Unmarshal(o, ViperPolicyHooks)
|
|
if err != nil {
|
|
t.Errorf("Could not unmarshal %#v: %s", o, err)
|
|
}
|
|
o.viper = v
|
|
if !o.Debug {
|
|
t.Errorf("Failed to load POMERIUM_DEBUG from environment")
|
|
}
|
|
if len(o.Policies) != 1 {
|
|
t.Error("failed to bind POLICY env")
|
|
}
|
|
if o.Services != "" {
|
|
t.Errorf("Somehow got SERVICES from environment without configuring it")
|
|
}
|
|
if o.HeadersEnv != `{"X-Custom-1":"foo", "X-Custom-2":"bar"}` {
|
|
t.Errorf("Failed to bind headers env var to HeadersEnv")
|
|
}
|
|
}
|
|
|
|
// Fixture types for the bindEnvsRecursive tests: Foo nests a Bar under the
// "field_one" key so that nested mapstructure keys can be exercised.
type (
	// Foo is the top-level fixture: one nested struct and one plain string.
	Foo struct {
		FieldOne Bar    `mapstructure:"field_one"`
		FieldTwo string `mapstructure:"field_two"`
	}

	// Bar is the nested fixture stored in Foo.FieldOne.
	Bar struct {
		Baz  int    `mapstructure:"baz"`
		Quux string `mapstructure:"quux"`
	}
)
|
|
|
|
func Test_bindEnvsRecursive(t *testing.T) {
|
|
v := viper.New()
|
|
_, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "")
|
|
require.NoError(t, err)
|
|
|
|
t.Setenv("FIELD_ONE_BAZ", "123")
|
|
t.Setenv("FIELD_ONE_QUUX", "hello")
|
|
t.Setenv("FIELD_TWO", "world")
|
|
|
|
var foo Foo
|
|
v.Unmarshal(&foo)
|
|
assert.Equal(t, Foo{
|
|
FieldOne: Bar{
|
|
Baz: 123,
|
|
Quux: "hello",
|
|
},
|
|
FieldTwo: "world",
|
|
}, foo)
|
|
}
|
|
|
|
func Test_bindEnvsRecursive_Override(t *testing.T) {
|
|
v := viper.New()
|
|
v.SetConfigType("yaml")
|
|
v.ReadConfig(strings.NewReader(`
|
|
field_one:
|
|
baz: 10
|
|
quux: abc
|
|
field_two: hello
|
|
`))
|
|
|
|
// Baseline: values populated from config file.
|
|
var foo1 Foo
|
|
v.Unmarshal(&foo1)
|
|
assert.Equal(t, Foo{
|
|
FieldOne: Bar{
|
|
Baz: 10,
|
|
Quux: "abc",
|
|
},
|
|
FieldTwo: "hello",
|
|
}, foo1)
|
|
|
|
_, err := bindEnvsRecursive(reflect.TypeOf(Foo{}), v, "", "")
|
|
require.NoError(t, err)
|
|
|
|
// Environment variables should selectively override config file keys.
|
|
t.Setenv("FIELD_ONE_QUUX", "def")
|
|
var foo2 Foo
|
|
v.Unmarshal(&foo2)
|
|
assert.Equal(t, Foo{
|
|
FieldOne: Bar{
|
|
Baz: 10,
|
|
Quux: "def",
|
|
},
|
|
FieldTwo: "hello",
|
|
}, foo2)
|
|
|
|
t.Setenv("FIELD_TWO", "world")
|
|
var foo3 Foo
|
|
v.Unmarshal(&foo3)
|
|
assert.Equal(t, Foo{
|
|
FieldOne: Bar{
|
|
Baz: 10,
|
|
Quux: "def",
|
|
},
|
|
FieldTwo: "world",
|
|
}, foo3)
|
|
}
|
|
|
|
func Test_parseHeaders(t *testing.T) {
|
|
// t.Parallel()
|
|
tests := []struct {
|
|
name string
|
|
want map[string]string
|
|
envHeaders string
|
|
viperHeaders any
|
|
wantErr bool
|
|
}{
|
|
{
|
|
"good env",
|
|
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
|
|
`{"X-Custom-1":"foo", "X-Custom-2":"bar"}`,
|
|
map[string]string{"X": "foo"},
|
|
false,
|
|
},
|
|
{
|
|
"good env not_json",
|
|
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
|
|
`X-Custom-1:foo,X-Custom-2:bar`,
|
|
map[string]string{"X": "foo"},
|
|
false,
|
|
},
|
|
{
|
|
"bad env",
|
|
map[string]string{},
|
|
"xyyyy",
|
|
map[string]string{"X": "foo"},
|
|
true,
|
|
},
|
|
{
|
|
"bad env not_json",
|
|
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
|
|
`X-Custom-1:foo,X-Custom-2bar`,
|
|
map[string]string{"X": "foo"},
|
|
true,
|
|
},
|
|
{
|
|
"bad viper",
|
|
map[string]string{},
|
|
"",
|
|
"notaheaderstruct",
|
|
true,
|
|
},
|
|
{
|
|
"good viper",
|
|
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
|
|
"",
|
|
map[string]string{"X-Custom-1": "foo", "X-Custom-2": "bar"},
|
|
false,
|
|
},
|
|
{
|
|
"new field name",
|
|
map[string]string{"X-Custom-1": "foo"},
|
|
"",
|
|
map[string]string{"X-Custom-1": "foo"},
|
|
false,
|
|
},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
var (
|
|
o *Options
|
|
mu sync.Mutex
|
|
)
|
|
mu.Lock()
|
|
defer mu.Unlock()
|
|
o = NewDefaultOptions()
|
|
o.viperSet("set_response_headers", tt.viperHeaders)
|
|
o.viperSet("HeadersEnv", tt.envHeaders)
|
|
o.HeadersEnv = tt.envHeaders
|
|
err := o.parseHeaders(context.Background())
|
|
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("Error condition unexpected: err=%s", err)
|
|
}
|
|
|
|
if !tt.wantErr && !cmp.Equal(tt.want, o.SetResponseHeaders) {
|
|
t.Errorf("Did get expected headers: %s", cmp.Diff(tt.want, o.SetResponseHeaders))
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_parsePolicyFile(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
opts := []cmp.Option{
|
|
cmpopts.IgnoreFields(Policy{}, "EnvoyOpts"),
|
|
cmpOptIgnoreUnexported,
|
|
}
|
|
|
|
source := "https://pomerium.io"
|
|
|
|
to, err := ParseWeightedURL("https://httpbin.org")
|
|
require.NoError(t, err)
|
|
|
|
tests := []struct {
|
|
name string
|
|
policyBytes []byte
|
|
want []Policy
|
|
wantErr bool
|
|
}{
|
|
{
|
|
"simple json",
|
|
[]byte(fmt.Sprintf(`{"policy":[{"from": "%s","to":"%s"}]}`, source, to.URL.String())),
|
|
[]Policy{{
|
|
From: source,
|
|
To: []WeightedURL{*to},
|
|
}},
|
|
false,
|
|
},
|
|
{"bad from", []byte(`{"policy":[{"from": "%","to":"httpbin.org"}]}`), nil, true},
|
|
{"bad to", []byte(`{"policy":[{"from": "pomerium.io","to":"%"}]}`), nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tempFile, _ := os.CreateTemp(t.TempDir(), "*.json")
|
|
defer tempFile.Close()
|
|
defer os.Remove(tempFile.Name())
|
|
tempFile.Write(tt.policyBytes)
|
|
var o Options
|
|
o.viper = viper.New()
|
|
o.viper.SetConfigFile(tempFile.Name())
|
|
if err := o.viper.ReadInConfig(); err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
err := o.parsePolicy()
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("parsePolicyEnv() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if err == nil {
|
|
if diff := cmp.Diff(o.Policies, tt.want, opts...); diff != "" {
|
|
t.Errorf("parsePolicyEnv() = diff:%s", diff)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_decodeSANMatcher(t *testing.T) {
|
|
// Verify that config file parsing will decode the SANMatcher type.
|
|
const yaml = `
|
|
downstream_mtls:
|
|
match_subject_alt_names:
|
|
- dns: 'example-1\..*'
|
|
- dns: '.*\.example-2'
|
|
`
|
|
cfg := filepath.Join(t.TempDir(), "config.yaml")
|
|
err := os.WriteFile(cfg, []byte(yaml), 0o644)
|
|
require.NoError(t, err)
|
|
|
|
o, err := optionsFromViper(cfg)
|
|
require.NoError(t, err)
|
|
|
|
assert.Equal(t, []SANMatcher{
|
|
{Type: SANTypeDNS, Pattern: `example-1\..*`},
|
|
{Type: SANTypeDNS, Pattern: `.*\.example-2`},
|
|
}, o.DownstreamMTLS.MatchSubjectAltNames)
|
|
}
|
|
|
|
func Test_Checksum(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
|
|
oldChecksum := o.Checksum()
|
|
o.SharedKey = "changemeplease"
|
|
newChecksum := o.Checksum()
|
|
|
|
if newChecksum == oldChecksum {
|
|
t.Errorf("Checksum() failed to update old = %d, new = %d", oldChecksum, newChecksum)
|
|
}
|
|
|
|
if newChecksum == 0 || oldChecksum == 0 {
|
|
t.Error("Checksum() not returning data")
|
|
}
|
|
|
|
if o.Checksum() != newChecksum {
|
|
t.Error("Checksum() inconsistent")
|
|
}
|
|
}
|
|
|
|
func TestOptionsFromViper(t *testing.T) {
|
|
opts := []cmp.Option{
|
|
cmpopts.IgnoreFields(Options{}, "CookieSecret", "GRPCInsecure", "GRPCAddr", "DataBrokerURLString", "DataBrokerURLStrings", "AuthorizeURLString", "AuthorizeURLStrings", "DefaultUpstreamTimeout", "CookieExpire", "Services", "Addr", "LogLevel", "KeyFile", "CertFile", "SharedKey", "ReadTimeout", "IdleTimeout", "GRPCClientTimeout", "ProgrammaticRedirectDomainWhitelist", "RuntimeFlags"),
|
|
cmpopts.IgnoreFields(Policy{}, "EnvoyOpts"),
|
|
cmpOptIgnoreUnexported,
|
|
}
|
|
|
|
tests := []struct {
|
|
name string
|
|
configBytes []byte
|
|
want *Options
|
|
wantErr bool
|
|
}{
|
|
{
|
|
"good",
|
|
[]byte(`{"autocert_dir":"","insecure_server":true,"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
|
|
&Options{
|
|
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
|
|
CookieName: "_pomerium",
|
|
InsecureServer: true,
|
|
CookieHTTPOnly: true,
|
|
AuthenticateCallbackPath: "/oauth2/callback",
|
|
DataBrokerStorageType: "memory",
|
|
EnvoyAdminAccessLogPath: os.DevNull,
|
|
EnvoyAdminProfilePath: os.DevNull,
|
|
},
|
|
false,
|
|
},
|
|
{
|
|
"good disable header",
|
|
[]byte(`{"autocert_dir":"","insecure_server":true,"set_response_headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
|
|
&Options{
|
|
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
|
|
CookieName: "_pomerium",
|
|
AuthenticateCallbackPath: "/oauth2/callback",
|
|
CookieHTTPOnly: true,
|
|
InsecureServer: true,
|
|
SetResponseHeaders: map[string]string{"disable": "true"},
|
|
DataBrokerStorageType: "memory",
|
|
EnvoyAdminAccessLogPath: os.DevNull,
|
|
EnvoyAdminProfilePath: os.DevNull,
|
|
},
|
|
false,
|
|
},
|
|
{
|
|
"good disable header",
|
|
[]byte(`{"autocert_dir":"","insecure_server":true,"set_response_headers": {"disable":"true"},"policy":[{"from": "https://from.example","to":"https://to.example"}]}`),
|
|
&Options{
|
|
Policies: []Policy{{From: "https://from.example", To: mustParseWeightedURLs(t, "https://to.example")}},
|
|
CookieName: "_pomerium",
|
|
AuthenticateCallbackPath: "/oauth2/callback",
|
|
CookieHTTPOnly: true,
|
|
InsecureServer: true,
|
|
SetResponseHeaders: map[string]string{"disable": "true"},
|
|
DataBrokerStorageType: "memory",
|
|
EnvoyAdminAccessLogPath: os.DevNull,
|
|
EnvoyAdminProfilePath: os.DevNull,
|
|
},
|
|
false,
|
|
},
|
|
{"bad url", []byte(`{"policy":[{"from": "https://","to":"https://to.example"}]}`), nil, true},
|
|
{"bad policy", []byte(`{"policy":[{"allow_public_unauthenticated_access": "dog","to":"https://to.example"}]}`), nil, true},
|
|
{"bad file", []byte(`{''''}`), nil, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
tempFile, _ := os.CreateTemp(t.TempDir(), "*.json")
|
|
defer tempFile.Close()
|
|
defer os.Remove(tempFile.Name())
|
|
tempFile.Write(tt.configBytes)
|
|
got, err := optionsFromViper(tempFile.Name())
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("optionsFromViper() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
if diff := cmp.Diff(got, tt.want, opts...); diff != "" {
|
|
t.Errorf("newOptionsFromConfig() = %s", diff)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_NewOptionsFromConfigEnvVar(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
envKeyPairs map[string]string
|
|
wantErr bool
|
|
}{
|
|
{"good", map[string]string{"INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"bad no shared secret", map[string]string{"INSECURE_SERVER": "true", "SERVICES": "authenticate"}, true},
|
|
{"good no shared secret in all mode", map[string]string{"INSECURE_SERVER": "true"}, false},
|
|
{"bad header", map[string]string{"HEADERS": "x;y;z", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"bad authenticate url", map[string]string{"AUTHENTICATE_SERVICE_URL": "authenticate.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"bad authorize url", map[string]string{"AUTHORIZE_SERVICE_URL": "authorize.example", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"bad cert base64", map[string]string{"CERTIFICATE": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"bad cert key base64", map[string]string{"CERTIFICATE_KEY": "bad cert", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"no certs no insecure mode set", map[string]string{"SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"good disable headers ", map[string]string{"HEADERS": "disable:true", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"bad whitespace in secret", map[string]string{"INSECURE_SERVER": "true", "SERVICES": "authenticate", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=\n"}, true},
|
|
{"same addr and grpc addr", map[string]string{"SERVICES": "databroker", "ADDRESS": "0", "GRPC_ADDRESS": "0", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"bad cert files", map[string]string{"INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM=", "CERTIFICATES": "./test-data/example-cert.pem"}, true},
|
|
{"good cert file", map[string]string{"CERTIFICATE_FILE": "./testdata/example-cert.pem", "CERTIFICATE_KEY_FILE": "./testdata/example-key.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"bad cert file", map[string]string{"CERTIFICATE_FILE": "./testdata/example-cert-bad.pem", "CERTIFICATE_KEY_FILE": "./testdata/example-key-bad.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"good client ca file", map[string]string{"DOWNSTREAM_MTLS_CA_FILE": "./testdata/ca.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, false},
|
|
{"bad client ca file", map[string]string{"DOWNSTREAM_MTLS_CA_FILE": "./testdata/bad-ca.pem", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
{"bad client ca b64", map[string]string{"DOWNSTREAM_MTLS_CA": "bad cert", "INSECURE_SERVER": "true", "SHARED_SECRET": "YixWi1MYh77NMECGGIJQevoonYtVF+ZPRkQZrrmeRqM="}, true},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
for k, v := range tt.envKeyPairs {
|
|
t.Setenv(k, v)
|
|
}
|
|
_, err := newOptionsFromConfig("")
|
|
if (err != nil) != tt.wantErr {
|
|
t.Errorf("newOptionsFromConfig() error = %v, wantErr %v", err, tt.wantErr)
|
|
return
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func Test_AutoCertOptionsFromEnvVar(t *testing.T) {
|
|
type test struct {
|
|
envs map[string]string
|
|
expected AutocertOptions
|
|
wantErr bool
|
|
cleanup func()
|
|
}
|
|
|
|
tests := map[string]func(t *testing.T) test{
|
|
"ok/simple": func(_ *testing.T) test {
|
|
envs := map[string]string{
|
|
"AUTOCERT": "true",
|
|
"AUTOCERT_DIR": "/test",
|
|
"AUTOCERT_MUST_STAPLE": "true",
|
|
|
|
"INSECURE_SERVER": "true",
|
|
}
|
|
return test{
|
|
envs: envs,
|
|
expected: AutocertOptions{
|
|
Enable: true,
|
|
Folder: "/test",
|
|
MustStaple: true,
|
|
},
|
|
wantErr: false,
|
|
}
|
|
},
|
|
"ok/custom-ca": func(t *testing.T) test {
|
|
certPEM, err := newCACertPEM()
|
|
require.NoError(t, err)
|
|
envs := map[string]string{
|
|
"AUTOCERT": "true",
|
|
"AUTOCERT_CA": "test-ca.example.com/directory",
|
|
"AUTOCERT_EMAIL": "test@example.com",
|
|
"AUTOCERT_EAB_KEY_ID": "keyID",
|
|
"AUTOCERT_EAB_MAC_KEY": "fake-key",
|
|
"AUTOCERT_TRUSTED_CA": base64.StdEncoding.EncodeToString(certPEM),
|
|
"AUTOCERT_DIR": "/test",
|
|
"AUTOCERT_MUST_STAPLE": "true",
|
|
|
|
"INSECURE_SERVER": "true",
|
|
}
|
|
return test{
|
|
envs: envs,
|
|
wantErr: false,
|
|
expected: AutocertOptions{
|
|
Enable: true,
|
|
CA: "test-ca.example.com/directory",
|
|
Email: "test@example.com",
|
|
EABKeyID: "keyID",
|
|
EABMACKey: "fake-key",
|
|
TrustedCA: base64.StdEncoding.EncodeToString(certPEM),
|
|
Folder: "/test",
|
|
MustStaple: true,
|
|
},
|
|
}
|
|
},
|
|
"ok/custom-ca-file": func(t *testing.T) test {
|
|
certPEM, err := newCACertPEM()
|
|
require.NoError(t, err)
|
|
f, err := os.CreateTemp(t.TempDir(), "pomerium-test-ca")
|
|
require.NoError(t, err)
|
|
n, err := f.Write(certPEM)
|
|
require.NoError(t, err)
|
|
require.Equal(t, len(certPEM), n)
|
|
envs := map[string]string{
|
|
"AUTOCERT": "true",
|
|
"AUTOCERT_CA": "test-ca.example.com/directory",
|
|
"AUTOCERT_EMAIL": "test@example.com",
|
|
"AUTOCERT_EAB_KEY_ID": "keyID",
|
|
"AUTOCERT_EAB_MAC_KEY": "fake-key",
|
|
"AUTOCERT_TRUSTED_CA_FILE": f.Name(),
|
|
"AUTOCERT_DIR": "/test",
|
|
"AUTOCERT_MUST_STAPLE": "true",
|
|
|
|
"INSECURE_SERVER": "true",
|
|
}
|
|
return test{
|
|
envs: envs,
|
|
wantErr: false,
|
|
expected: AutocertOptions{
|
|
Enable: true,
|
|
CA: "test-ca.example.com/directory",
|
|
Email: "test@example.com",
|
|
EABKeyID: "keyID",
|
|
EABMACKey: "fake-key",
|
|
TrustedCAFile: f.Name(),
|
|
Folder: "/test",
|
|
MustStaple: true,
|
|
},
|
|
cleanup: func() { os.Remove(f.Name()) },
|
|
}
|
|
},
|
|
}
|
|
|
|
for name, run := range tests {
|
|
tc := run(t)
|
|
t.Run(name, func(t *testing.T) {
|
|
for k, v := range tc.envs {
|
|
t.Setenv(k, v)
|
|
}
|
|
o, err := newOptionsFromConfig("")
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
if !cmp.Equal(tc.expected, o.AutocertOptions) {
|
|
t.Errorf("AutoCertOptionsFromEnvVar() diff = %s", cmp.Diff(tc.expected, o.AutocertOptions))
|
|
}
|
|
if tc.cleanup != nil {
|
|
tc.cleanup()
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestHTTPRedirectAddressStripQuotes(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
o.InsecureServer = true
|
|
o.HTTPRedirectAddr = `":80"`
|
|
assert.NoError(t, o.Validate())
|
|
assert.Equal(t, ":80", o.HTTPRedirectAddr)
|
|
}
|
|
|
|
func TestCertificatesArrayParsing(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
testCertFileRef := "./testdata/example-cert.pem"
|
|
testKeyFileRef := "./testdata/example-key.pem"
|
|
|
|
tests := []struct {
|
|
name string
|
|
certificateFiles []certificateFilePair
|
|
wantErr bool
|
|
}{
|
|
{"Handles file reference as params", []certificateFilePair{{KeyFile: testKeyFileRef, CertFile: testCertFileRef}}, false},
|
|
{"Returns an error otherwise", []certificateFilePair{{KeyFile: "abc", CertFile: "abc"}}, true},
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
o := NewDefaultOptions()
|
|
o.CertificateFiles = tt.certificateFiles
|
|
err := o.Validate()
|
|
|
|
if err != nil && tt.wantErr == false {
|
|
t.Fatal(err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestCompareByteSliceSlice(t *testing.T) {
|
|
type Bytes = [][]byte
|
|
|
|
tests := []struct {
|
|
expect int
|
|
a Bytes
|
|
b Bytes
|
|
}{
|
|
{
|
|
0,
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
},
|
|
{
|
|
-1,
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
Bytes{
|
|
{0, 1, 2, 4},
|
|
},
|
|
},
|
|
{
|
|
1,
|
|
Bytes{
|
|
{0, 1, 2, 4},
|
|
},
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
},
|
|
{
|
|
-1,
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
{4, 5, 6, 7},
|
|
},
|
|
},
|
|
{
|
|
1,
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
{4, 5, 6, 7},
|
|
},
|
|
Bytes{
|
|
{0, 1, 2, 3},
|
|
},
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
actual := compareByteSliceSlice(tt.a, tt.b)
|
|
if tt.expect != actual {
|
|
t.Errorf("expected compare(%v, %v) to be %v but got %v",
|
|
tt.a, tt.b, tt.expect, actual)
|
|
}
|
|
}
|
|
}
|
|
|
|
func TestHasAnyDownstreamMTLSClientCA(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
cases := []struct {
|
|
label string
|
|
opts *Options
|
|
expected bool
|
|
}{
|
|
{"zero", &Options{}, false},
|
|
{"default", NewDefaultOptions(), false},
|
|
{"no client CAs", &Options{
|
|
Policies: []Policy{
|
|
{From: "https://example.com/one"},
|
|
{From: "https://example.com/two"},
|
|
{From: "https://example.com/three"},
|
|
},
|
|
}, false},
|
|
{"global client CA only", &Options{
|
|
DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="},
|
|
Policies: []Policy{
|
|
{From: "https://example.com/one"},
|
|
{From: "https://example.com/two"},
|
|
{From: "https://example.com/three"},
|
|
},
|
|
}, true},
|
|
{"per-route CA only", &Options{
|
|
Policies: []Policy{
|
|
{From: "https://example.com/one"},
|
|
{
|
|
From: "https://example.com/two",
|
|
TLSDownstreamClientCA: "ZmFrZSBDQQ==",
|
|
},
|
|
{From: "https://example.com/three"},
|
|
},
|
|
}, true},
|
|
{"both global and per-route client CAs", &Options{
|
|
DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="},
|
|
Policies: []Policy{
|
|
{From: "https://example.com/one"},
|
|
{
|
|
From: "https://example.com/two",
|
|
TLSDownstreamClientCA: "ZmFrZSBDQQ==",
|
|
},
|
|
{From: "https://example.com/three"},
|
|
},
|
|
}, true},
|
|
}
|
|
for i := range cases {
|
|
c := &cases[i]
|
|
t.Run(c.label, func(t *testing.T) {
|
|
actual := c.opts.HasAnyDownstreamMTLSClientCA()
|
|
assert.Equal(t, c.expected, actual)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_DefaultURL(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
firstURL := func(f func() ([]*url.URL, error)) func() (*url.URL, error) {
|
|
return func() (*url.URL, error) {
|
|
urls, err := f()
|
|
if err != nil {
|
|
return nil, err
|
|
} else if len(urls) == 0 {
|
|
return nil, fmt.Errorf("no url defined")
|
|
}
|
|
return urls[0], nil
|
|
}
|
|
}
|
|
|
|
defaultOptions := &Options{}
|
|
opts := &Options{
|
|
AuthenticateURLString: "https://authenticate.example.com",
|
|
AuthorizeURLString: "https://authorize.example.com",
|
|
DataBrokerURLString: "https://databroker.example.com",
|
|
}
|
|
tests := []struct {
|
|
name string
|
|
f func() (*url.URL, error)
|
|
expectedURLStr string
|
|
}{
|
|
{"default authenticate url", defaultOptions.GetAuthenticateURL, "https://authenticate.pomerium.app"},
|
|
{"good authenticate url", opts.GetAuthenticateURL, "https://authenticate.example.com"},
|
|
{"good authorize url", firstURL(opts.GetAuthorizeURLs), "https://authorize.example.com"},
|
|
{"good databroker url", firstURL(opts.GetDataBrokerURLs), "https://databroker.example.com"},
|
|
}
|
|
|
|
for _, tc := range tests {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
u, err := tc.f()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, tc.expectedURLStr, u.String())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_UseStatelessAuthenticateFlow(t *testing.T) {
|
|
t.Run("enabled by default", func(t *testing.T) {
|
|
options := &Options{}
|
|
assert.True(t, options.UseStatelessAuthenticateFlow())
|
|
})
|
|
t.Run("enabled explicitly", func(t *testing.T) {
|
|
options := &Options{AuthenticateURLString: "https://authenticate.pomerium.app"}
|
|
assert.True(t, options.UseStatelessAuthenticateFlow())
|
|
})
|
|
t.Run("disabled", func(t *testing.T) {
|
|
options := &Options{AuthenticateURLString: "https://authenticate.example.com"}
|
|
assert.False(t, options.UseStatelessAuthenticateFlow())
|
|
})
|
|
t.Run("force enabled", func(t *testing.T) {
|
|
options := &Options{AuthenticateURLString: "https://authenticate.example.com"}
|
|
t.Setenv("DEBUG_FORCE_AUTHENTICATE_FLOW", "stateless")
|
|
assert.True(t, options.UseStatelessAuthenticateFlow())
|
|
})
|
|
t.Run("force disabled", func(t *testing.T) {
|
|
options := &Options{}
|
|
t.Setenv("DEBUG_FORCE_AUTHENTICATE_FLOW", "stateful")
|
|
assert.False(t, options.UseStatelessAuthenticateFlow())
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetOauthOptions(t *testing.T) {
|
|
opts := &Options{AuthenticateURLString: "https://authenticate.example.com"}
|
|
oauthOptions, err := opts.GetOauthOptions()
|
|
require.NoError(t, err)
|
|
|
|
// Test that oauth redirect url hostname must point to authenticate url hostname.
|
|
u, err := opts.GetAuthenticateURL()
|
|
require.NoError(t, err)
|
|
assert.Equal(t, u.Hostname(), oauthOptions.RedirectURL.Hostname())
|
|
}
|
|
|
|
func TestOptions_GetAllRouteableGRPCHosts(t *testing.T) {
|
|
opts := &Options{
|
|
AuthenticateURLString: "https://authenticate.example.com",
|
|
AuthorizeURLString: "https://authorize.example.com",
|
|
DataBrokerURLString: "https://databroker.example.com",
|
|
Services: "all",
|
|
}
|
|
hosts, err := opts.GetAllRouteableGRPCHosts()
|
|
assert.NoError(t, err)
|
|
|
|
assert.Equal(t, []string{
|
|
"authorize.example.com",
|
|
"authorize.example.com:443",
|
|
"databroker.example.com",
|
|
"databroker.example.com:443",
|
|
}, hosts)
|
|
}
|
|
|
|
func TestOptions_GetAllRouteableHTTPHosts(t *testing.T) {
|
|
to := WeightedURLs{{URL: url.URL{Scheme: "https", Host: "to.example.com"}}}
|
|
p1 := Policy{From: "https://from1.example.com", To: to}
|
|
assert.NoError(t, p1.Validate())
|
|
p2 := Policy{From: "https://from2.example.com", To: to}
|
|
assert.NoError(t, p2.Validate())
|
|
p3 := Policy{From: "https://from3.example.com", TLSDownstreamServerName: "from.example.com", To: to}
|
|
assert.NoError(t, p3.Validate())
|
|
p4 := Policy{From: "https://from4.example.com", MCP: &MCP{}, To: to}
|
|
assert.NoError(t, p4.Validate())
|
|
|
|
opts := &Options{
|
|
AuthenticateURLString: "https://authenticate.example.com",
|
|
AuthorizeURLString: "https://authorize.example.com",
|
|
DataBrokerURLString: "https://databroker.example.com",
|
|
Policies: []Policy{p1, p2, p3, p4},
|
|
Services: "all",
|
|
}
|
|
hosts, mcpHosts, err := opts.GetAllRouteableHTTPHosts()
|
|
assert.NoError(t, err)
|
|
assert.Empty(t, cmp.Diff(mcpHosts, map[string]bool{"from4.example.com:443": true, "from4.example.com": true}))
|
|
|
|
assert.Equal(t, []string{
|
|
"authenticate.example.com",
|
|
"authenticate.example.com:443",
|
|
"from.example.com",
|
|
"from.example.com:443",
|
|
"from1.example.com",
|
|
"from1.example.com:443",
|
|
"from2.example.com",
|
|
"from2.example.com:443",
|
|
"from3.example.com",
|
|
"from3.example.com:443",
|
|
"from4.example.com",
|
|
"from4.example.com:443",
|
|
}, hosts)
|
|
}
|
|
|
|
func TestOptions_ApplySettings(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second)
|
|
defer clearTimeout()
|
|
|
|
t.Run("certificates", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
cert1, err := cryptutil.GenerateCertificate(nil, "example.com")
|
|
require.NoError(t, err)
|
|
cert1path := filepath.Join(t.TempDir(), "example.com.pem")
|
|
err = os.WriteFile(cert1path, cert1.Certificate[0], 0o600)
|
|
require.NoError(t, err)
|
|
options.CertificateFiles = append(options.CertificateFiles, certificateFilePair{
|
|
CertFile: cert1path,
|
|
})
|
|
cert2, err := cryptutil.GenerateCertificate(nil, "example.com")
|
|
require.NoError(t, err)
|
|
cert3, err := cryptutil.GenerateCertificate(nil, "not.example.com")
|
|
require.NoError(t, err)
|
|
|
|
certsIndex := cryptutil.NewCertificatesIndex()
|
|
xc1, _ := x509.ParseCertificate(cert1.Certificate[0])
|
|
certsIndex.Add(xc1)
|
|
|
|
settings := &configpb.Settings{
|
|
Certificates: []*configpb.Settings_Certificate{
|
|
{CertBytes: encodeCert(cert2)},
|
|
{CertBytes: encodeCert(cert3)},
|
|
},
|
|
}
|
|
options.ApplySettings(ctx, certsIndex, settings)
|
|
assert.Len(t, options.CertificateData, 1, "should prevent adding duplicate certificates")
|
|
})
|
|
|
|
t.Run("pass_identity_headers", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
PassIdentityHeaders: proto.Bool(true),
|
|
})
|
|
assert.Equal(t, proto.Bool(true), options.PassIdentityHeaders)
|
|
})
|
|
|
|
t.Run("branding", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
PrimaryColor: proto.String("#FFFFFF"),
|
|
})
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{})
|
|
assert.Equal(t, "#FFFFFF", options.BrandingOptions.GetPrimaryColor())
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
PrimaryColor: proto.String("#333333"),
|
|
})
|
|
assert.Equal(t, "#333333", options.BrandingOptions.GetPrimaryColor())
|
|
})
|
|
|
|
t.Run("jwt_groups_filter", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
JwtGroupsFilter: []string{"foo", "bar", "baz"},
|
|
})
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{})
|
|
assert.Equal(t, NewJWTGroupsFilter([]string{"foo", "bar", "baz"}), options.JWTGroupsFilter)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
JwtGroupsFilter: []string{"quux", "zulu"},
|
|
})
|
|
assert.Equal(t, NewJWTGroupsFilter([]string{"quux", "zulu"}), options.JWTGroupsFilter)
|
|
})
|
|
|
|
t.Run("jwt_issuer_format", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
assert.Equal(t, JWTIssuerFormatUnset, options.JWTIssuerFormat)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
JwtIssuerFormat: configpb.IssuerFormat_IssuerURI.Enum(),
|
|
})
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{})
|
|
assert.Equal(t, JWTIssuerFormatURI, options.JWTIssuerFormat)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
JwtIssuerFormat: configpb.IssuerFormat_IssuerHostOnly.Enum(),
|
|
})
|
|
assert.Equal(t, JWTIssuerFormatHostOnly, options.JWTIssuerFormat)
|
|
})
|
|
|
|
t.Run("bearer_token_format", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
options := NewDefaultOptions()
|
|
assert.Nil(t, options.BearerTokenFormat)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
BearerTokenFormat: configpb.BearerTokenFormat_BEARER_TOKEN_FORMAT_DEFAULT.Enum(),
|
|
})
|
|
assert.Equal(t, ptr(BearerTokenFormatDefault), options.BearerTokenFormat)
|
|
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{})
|
|
assert.Equal(t, ptr(BearerTokenFormatDefault), options.BearerTokenFormat, "should preserve existing bearer token format")
|
|
})
|
|
|
|
t.Run("idp_access_token_allowed_audiences", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
options := NewDefaultOptions()
|
|
assert.Nil(t, options.IDPAccessTokenAllowedAudiences)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{
|
|
IdpAccessTokenAllowedAudiences: &configpb.Settings_StringList{Values: []string{"x", "y", "z"}},
|
|
})
|
|
assert.Equal(t, ptr([]string{"x", "y", "z"}), options.IDPAccessTokenAllowedAudiences)
|
|
options.ApplySettings(ctx, nil, &configpb.Settings{})
|
|
assert.Equal(t, ptr([]string{"x", "y", "z"}), options.IDPAccessTokenAllowedAudiences,
|
|
"should preserve idp access token allowed audiences")
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetSetResponseHeaders(t *testing.T) {
|
|
t.Run("lax", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
assert.Equal(t, map[string]string{
|
|
"X-Frame-Options": "SAMEORIGIN",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}, options.GetSetResponseHeaders())
|
|
})
|
|
t.Run("strict", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.Cert = "CERT"
|
|
assert.Equal(t, map[string]string{
|
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
|
"X-Frame-Options": "SAMEORIGIN",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}, options.GetSetResponseHeaders())
|
|
})
|
|
t.Run("autocert-staging", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.Cert = "CERT"
|
|
options.AutocertOptions.UseStaging = true
|
|
assert.Equal(t, map[string]string{
|
|
"X-Frame-Options": "SAMEORIGIN",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}, options.GetSetResponseHeaders())
|
|
})
|
|
t.Run("disable", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1", "x-other": "xyz"}
|
|
assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders())
|
|
})
|
|
t.Run("empty", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.SetResponseHeaders = map[string]string{}
|
|
assert.Equal(t, map[string]string{}, options.GetSetResponseHeaders())
|
|
})
|
|
t.Run("no partial defaults", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.Cert = "CERT"
|
|
options.SetResponseHeaders = map[string]string{"X-Frame-Options": "DENY"}
|
|
assert.Equal(t, map[string]string{"X-Frame-Options": "DENY"},
|
|
options.GetSetResponseHeaders())
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetSetResponseHeadersForPolicy(t *testing.T) {
|
|
t.Run("disable but set in policy", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.SetResponseHeaders = map[string]string{DisableHeaderKey: "1"}
|
|
policy := &Policy{
|
|
SetResponseHeaders: map[string]string{"x": "y"},
|
|
}
|
|
assert.Equal(t, map[string]string{"x": "y"}, options.GetSetResponseHeadersForPolicy(policy))
|
|
})
|
|
t.Run("global defaults plus policy", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.Cert = "CERT"
|
|
policy := &Policy{
|
|
SetResponseHeaders: map[string]string{"Route": "xyz"},
|
|
}
|
|
assert.Equal(t, map[string]string{
|
|
"Route": "xyz",
|
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
|
"X-Frame-Options": "SAMEORIGIN",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}, options.GetSetResponseHeadersForPolicy(policy))
|
|
})
|
|
t.Run("global defaults partial override", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.Cert = "CERT"
|
|
policy := &Policy{
|
|
SetResponseHeaders: map[string]string{"X-Frame-Options": "DENY"},
|
|
}
|
|
assert.Equal(t, map[string]string{
|
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
|
"X-Frame-Options": "DENY",
|
|
"X-XSS-Protection": "1; mode=block",
|
|
}, options.GetSetResponseHeadersForPolicy(policy))
|
|
})
|
|
t.Run("multiple policies", func(t *testing.T) {
|
|
options := NewDefaultOptions()
|
|
options.SetResponseHeaders = map[string]string{"global": "foo"}
|
|
p1 := &Policy{
|
|
SetResponseHeaders: map[string]string{"route-1": "bar"},
|
|
}
|
|
p2 := &Policy{
|
|
SetResponseHeaders: map[string]string{"route-2": "baz"},
|
|
}
|
|
assert.Equal(t, map[string]string{
|
|
"global": "foo",
|
|
"route-1": "bar",
|
|
}, options.GetSetResponseHeadersForPolicy(p1))
|
|
assert.Equal(t, map[string]string{
|
|
"global": "foo",
|
|
"route-2": "baz",
|
|
}, options.GetSetResponseHeadersForPolicy(p2))
|
|
assert.Equal(t, map[string]string{"global": "foo"}, options.GetSetResponseHeaders())
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetSharedKey(t *testing.T) {
|
|
t.Run("default", func(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
bs, err := o.GetSharedKey()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, randomSharedKey, base64.StdEncoding.EncodeToString(bs))
|
|
})
|
|
t.Run("missing", func(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
o.Services = ServiceProxy
|
|
_, err := o.GetSharedKey()
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetSigningKey(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, tc := range []struct {
|
|
name string
|
|
input string
|
|
output []byte
|
|
err error
|
|
}{
|
|
{"missing", "", []byte{}, nil},
|
|
{"pem", `
|
|
-----BEGIN EC PRIVATE KEY-----
|
|
MHQCAQEEIGGh6FlBe8yy9dRJgm+35lj3naGFtDODOf6leCW1bRGwoAcGBSuBBAAK
|
|
oUQDQgAE7UlKcFatc9m3GinCrhhT2oRQZ/bEwS98iEUXr0DR8GdxH3e4fhnicsNB
|
|
jHOCur7NYTgf5VaPJwIqLGBmTwM0ew==
|
|
-----END EC PRIVATE KEY-----
|
|
|
|
-----BEGIN EC PRIVATE KEY-----
|
|
MHQCAQEEIBo4wSjkFqQrzf2APNnPol8EDZzkhpcMSaEWXg8iOkbOoAcGBSuBBAAK
|
|
oUQDQgAEr+bGqssRv8RxPV2jJbDpMw81AVXr5+Q2pIF4u6xD9r56lst8uHYThPsw
|
|
ypaqswFIkSzQSW8awdWJ5d+1DEJRUQ==
|
|
-----END EC PRIVATE KEY-----
|
|
`, []byte{
|
|
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
|
|
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65,
|
|
0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e,
|
|
0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62,
|
|
0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
|
|
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74,
|
|
0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51,
|
|
0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52,
|
|
0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42,
|
|
0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61,
|
|
0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d,
|
|
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a,
|
|
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
|
|
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46,
|
|
0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44,
|
|
0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f,
|
|
0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
|
|
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52,
|
|
0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31,
|
|
0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44,
|
|
0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77,
|
|
0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38,
|
|
0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d,
|
|
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d,
|
|
}, nil},
|
|
{"base64", `
|
|
LS0tLS1CRUdJTiBFQyBQUklWQVRFIEtFWS0tLS0tCk1IUUNBUUVFSUdHaDZGbEJlOHl5OWRSSmdtKzM1bGozbmFHRnRET0RPZjZsZUNXMWJSR3dvQWNHQlN1QkJBQUsKb1VRRFFnQUU3VWxLY0ZhdGM5bTNHaW5DcmhoVDJvUlFaL2JFd1M5OGlFVVhyMERSOEdkeEgzZTRmaG5pY3NOQgpqSE9DdXI3TllUZ2Y1VmFQSndJcUxHQm1Ud00wZXc9PQotLS0tLUVORCBFQyBQUklWQVRFIEtFWS0tLS0tCgotLS0tLUJFR0lOIEVDIFBSSVZBVEUgS0VZLS0tLS0KTUhRQ0FRRUVJQm80d1Nqa0ZxUXJ6ZjJBUE5uUG9sOEVEWnpraHBjTVNhRVdYZzhpT2tiT29BY0dCU3VCQkFBSwpvVVFEUWdBRXIrYkdxc3NSdjhSeFBWMmpKYkRwTXc4MUFWWHI1K1EycElGNHU2eEQ5cjU2bHN0OHVIWVRoUHN3CnlwYXFzd0ZJa1N6UVNXOGF3ZFdKNWQrMURFSlJVUT09Ci0tLS0tRU5EIEVDIFBSSVZBVEUgS0VZLS0tLS0=
|
|
`, []byte{
|
|
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
|
|
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x47, 0x47, 0x68, 0x36, 0x46, 0x6c, 0x42, 0x65,
|
|
0x38, 0x79, 0x79, 0x39, 0x64, 0x52, 0x4a, 0x67, 0x6d, 0x2b, 0x33, 0x35, 0x6c, 0x6a, 0x33, 0x6e,
|
|
0x61, 0x47, 0x46, 0x74, 0x44, 0x4f, 0x44, 0x4f, 0x66, 0x36, 0x6c, 0x65, 0x43, 0x57, 0x31, 0x62,
|
|
0x52, 0x47, 0x77, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
|
|
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x37, 0x55, 0x6c, 0x4b, 0x63, 0x46, 0x61, 0x74,
|
|
0x63, 0x39, 0x6d, 0x33, 0x47, 0x69, 0x6e, 0x43, 0x72, 0x68, 0x68, 0x54, 0x32, 0x6f, 0x52, 0x51,
|
|
0x5a, 0x2f, 0x62, 0x45, 0x77, 0x53, 0x39, 0x38, 0x69, 0x45, 0x55, 0x58, 0x72, 0x30, 0x44, 0x52,
|
|
0x38, 0x47, 0x64, 0x78, 0x48, 0x33, 0x65, 0x34, 0x66, 0x68, 0x6e, 0x69, 0x63, 0x73, 0x4e, 0x42,
|
|
0x0a, 0x6a, 0x48, 0x4f, 0x43, 0x75, 0x72, 0x37, 0x4e, 0x59, 0x54, 0x67, 0x66, 0x35, 0x56, 0x61,
|
|
0x50, 0x4a, 0x77, 0x49, 0x71, 0x4c, 0x47, 0x42, 0x6d, 0x54, 0x77, 0x4d, 0x30, 0x65, 0x77, 0x3d,
|
|
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x0a,
|
|
0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x42, 0x45, 0x47, 0x49, 0x4e, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x0a, 0x4d,
|
|
0x48, 0x51, 0x43, 0x41, 0x51, 0x45, 0x45, 0x49, 0x42, 0x6f, 0x34, 0x77, 0x53, 0x6a, 0x6b, 0x46,
|
|
0x71, 0x51, 0x72, 0x7a, 0x66, 0x32, 0x41, 0x50, 0x4e, 0x6e, 0x50, 0x6f, 0x6c, 0x38, 0x45, 0x44,
|
|
0x5a, 0x7a, 0x6b, 0x68, 0x70, 0x63, 0x4d, 0x53, 0x61, 0x45, 0x57, 0x58, 0x67, 0x38, 0x69, 0x4f,
|
|
0x6b, 0x62, 0x4f, 0x6f, 0x41, 0x63, 0x47, 0x42, 0x53, 0x75, 0x42, 0x42, 0x41, 0x41, 0x4b, 0x0a,
|
|
0x6f, 0x55, 0x51, 0x44, 0x51, 0x67, 0x41, 0x45, 0x72, 0x2b, 0x62, 0x47, 0x71, 0x73, 0x73, 0x52,
|
|
0x76, 0x38, 0x52, 0x78, 0x50, 0x56, 0x32, 0x6a, 0x4a, 0x62, 0x44, 0x70, 0x4d, 0x77, 0x38, 0x31,
|
|
0x41, 0x56, 0x58, 0x72, 0x35, 0x2b, 0x51, 0x32, 0x70, 0x49, 0x46, 0x34, 0x75, 0x36, 0x78, 0x44,
|
|
0x39, 0x72, 0x35, 0x36, 0x6c, 0x73, 0x74, 0x38, 0x75, 0x48, 0x59, 0x54, 0x68, 0x50, 0x73, 0x77,
|
|
0x0a, 0x79, 0x70, 0x61, 0x71, 0x73, 0x77, 0x46, 0x49, 0x6b, 0x53, 0x7a, 0x51, 0x53, 0x57, 0x38,
|
|
0x61, 0x77, 0x64, 0x57, 0x4a, 0x35, 0x64, 0x2b, 0x31, 0x44, 0x45, 0x4a, 0x52, 0x55, 0x51, 0x3d,
|
|
0x3d, 0x0a, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d, 0x45, 0x4e, 0x44, 0x20, 0x45, 0x43, 0x20, 0x50, 0x52,
|
|
0x49, 0x56, 0x41, 0x54, 0x45, 0x20, 0x4b, 0x45, 0x59, 0x2d, 0x2d, 0x2d, 0x2d, 0x2d,
|
|
}, nil},
|
|
} {
|
|
t.Run(tc.name, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
output, err := (&Options{SigningKey: tc.input}).GetSigningKey()
|
|
assert.Equal(t, tc.err, err)
|
|
assert.Equal(t, tc.output, output)
|
|
|
|
dir := t.TempDir()
|
|
err = os.WriteFile(filepath.Join(dir, "cert"), []byte(tc.input), 0o0666)
|
|
assert.NoError(t, err)
|
|
|
|
output, err = (&Options{SigningKeyFile: filepath.Join(dir, "cert")}).GetSigningKey()
|
|
assert.Equal(t, tc.err, err)
|
|
assert.Equal(t, tc.output, output)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_GetCookieSecret(t *testing.T) {
|
|
t.Run("default", func(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
bs, err := o.GetCookieSecret()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, randomSharedKey, base64.StdEncoding.EncodeToString(bs))
|
|
})
|
|
t.Run("missing", func(t *testing.T) {
|
|
o := NewDefaultOptions()
|
|
o.Services = ServiceProxy
|
|
_, err := o.GetCookieSecret()
|
|
assert.Error(t, err)
|
|
})
|
|
}
|
|
|
|
func TestOptions_GetCookieSameSite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, tc := range []struct {
|
|
input string
|
|
expected http.SameSite
|
|
}{
|
|
{"", http.SameSiteDefaultMode},
|
|
{"Lax", http.SameSiteLaxMode},
|
|
{"lax", http.SameSiteLaxMode},
|
|
{"Strict", http.SameSiteStrictMode},
|
|
{"strict", http.SameSiteStrictMode},
|
|
{"None", http.SameSiteNoneMode},
|
|
{"none", http.SameSiteNoneMode},
|
|
{"UnKnOwN", http.SameSiteDefaultMode},
|
|
} {
|
|
t.Run(tc.input, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
o := NewDefaultOptions()
|
|
o.CookieSameSite = tc.input
|
|
assert.Equal(t, tc.expected, o.GetCookieSameSite())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_GetCSRFSameSite(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
for _, tc := range []struct {
|
|
cookieSameSite string
|
|
provider string
|
|
expected csrf.SameSiteMode
|
|
}{
|
|
{"", "", csrf.SameSiteDefaultMode},
|
|
{"Lax", "", csrf.SameSiteLaxMode},
|
|
{"lax", "", csrf.SameSiteLaxMode},
|
|
{"Strict", "", csrf.SameSiteStrictMode},
|
|
{"strict", "", csrf.SameSiteStrictMode},
|
|
{"None", "", csrf.SameSiteNoneMode},
|
|
{"none", "", csrf.SameSiteNoneMode},
|
|
{"UnKnOwN", "", csrf.SameSiteDefaultMode},
|
|
{"", apple.Name, csrf.SameSiteNoneMode},
|
|
} {
|
|
t.Run(tc.cookieSameSite, func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
o := NewDefaultOptions()
|
|
o.CookieSameSite = tc.cookieSameSite
|
|
o.Provider = tc.provider
|
|
assert.Equal(t, tc.expected, o.GetCSRFSameSite())
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_RequestParams(t *testing.T) {
|
|
cases := []struct {
|
|
label string
|
|
config string
|
|
expected map[string]string
|
|
}{
|
|
{"not present", "", nil},
|
|
{"explicitly empty", "idp_request_params: {}", map[string]string{}},
|
|
}
|
|
cfg := filepath.Join(t.TempDir(), "config.yaml")
|
|
for i := range cases {
|
|
c := &cases[i]
|
|
t.Run(c.label, func(t *testing.T) {
|
|
err := os.WriteFile(cfg, []byte(c.config), 0o644)
|
|
require.NoError(t, err)
|
|
o, err := newOptionsFromConfig(cfg)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, c.expected, o.RequestParams)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_RequestParamsFromEnv(t *testing.T) {
|
|
t.Setenv("IDP_REQUEST_PARAMS", `{"x":"y"}`)
|
|
|
|
options, err := newOptionsFromConfig("")
|
|
if assert.NoError(t, err) {
|
|
assert.Equal(t, map[string]string{"x": "y"}, options.RequestParams)
|
|
}
|
|
}
|
|
|
|
func TestOptions_RuntimeFlags(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
extra := DefaultRuntimeFlags()
|
|
extra["another"] = true
|
|
|
|
cases := []struct {
|
|
label string
|
|
config string
|
|
expected RuntimeFlags
|
|
}{
|
|
{"not present", "", DefaultRuntimeFlags()},
|
|
{"explicitly empty", `{"runtime_flags": {}}`, DefaultRuntimeFlags()},
|
|
{"all", `{"runtime_flags":{"another":true}}`, extra},
|
|
}
|
|
cfg := filepath.Join(t.TempDir(), "config.yaml")
|
|
for _, c := range cases {
|
|
t.Run(c.label, func(t *testing.T) {
|
|
err := os.WriteFile(cfg, []byte(c.config), 0o644)
|
|
require.NoError(t, err)
|
|
o, err := newOptionsFromConfig(cfg)
|
|
require.NoError(t, err)
|
|
assert.Equal(t, c.expected, o.RuntimeFlags)
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestOptions_GetDataBrokerStorageConnectionString(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
t.Run("validate", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
o := NewDefaultOptions()
|
|
o.Services = "databroker"
|
|
o.DataBrokerStorageType = "postgres"
|
|
o.SharedKey = cryptutil.NewBase64Key()
|
|
|
|
assert.ErrorContains(t, o.Validate(), "missing databroker storage backend dsn",
|
|
"should validate DSN")
|
|
|
|
o.DataBrokerStorageConnectionString = "DSN"
|
|
assert.NoError(t, o.Validate(),
|
|
"should have no error when the dsn is set")
|
|
|
|
o.DataBrokerStorageConnectionString = ""
|
|
o.DataBrokerStorageConnectionStringFile = "DSN_FILE"
|
|
assert.NoError(t, o.Validate(),
|
|
"should have no error when the dsn file is set")
|
|
})
|
|
t.Run("literal", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
o := NewDefaultOptions()
|
|
o.DataBrokerStorageConnectionString = "DSN"
|
|
|
|
dsn, err := o.GetDataBrokerStorageConnectionString()
|
|
assert.NoError(t, err)
|
|
assert.Equal(t, "DSN", dsn)
|
|
})
|
|
t.Run("file", func(t *testing.T) {
|
|
t.Parallel()
|
|
|
|
dir := t.TempDir()
|
|
fp := filepath.Join(dir, "DSN_FILE")
|
|
|
|
o := NewDefaultOptions()
|
|
o.DataBrokerStorageConnectionStringFile = fp
|
|
o.DataBrokerStorageConnectionString = "IGNORED"
|
|
|
|
dsn, err := o.GetDataBrokerStorageConnectionString()
|
|
assert.Error(t, err,
|
|
"should return an error when the file doesn't exist")
|
|
assert.Empty(t, dsn)
|
|
|
|
os.WriteFile(fp, []byte(`
|
|
DSN
|
|
`), 0o644)
|
|
|
|
dsn, err = o.GetDataBrokerStorageConnectionString()
|
|
assert.NoError(t, err,
|
|
"should not return an error when the file exists")
|
|
assert.Equal(t, "DSN", dsn,
|
|
"should return the trimmed contents of the file")
|
|
})
|
|
}
|
|
|
|
// encodeCert returns the PEM encoding, as a "CERTIFICATE" block, of the first
// DER certificate in cert.Certificate (the leaf, per crypto/tls). It panics if
// the chain is empty.
func encodeCert(cert *tls.Certificate) []byte {
	leaf := &pem.Block{
		Type:  "CERTIFICATE",
		Bytes: cert.Certificate[0],
	}
	return pem.EncodeToMemory(leaf)
}
|
|
|
|
func TestRoute_FromToProto(t *testing.T) {
|
|
routeGen := protorand.New[*configpb.Route]()
|
|
routeGen.MaxCollectionElements = 2
|
|
routeGen.UseGoDurationLimits = true
|
|
routeGen.ExcludeMask(&fieldmaskpb.FieldMask{
|
|
Paths: []string{
|
|
"from", "to", "load_balancing_weights", "redirect", "response", // set below
|
|
"ppl_policies", "name", // no equivalent field
|
|
"envoy_opts",
|
|
},
|
|
})
|
|
redirectGen := protorand.New[*configpb.RouteRedirect]()
|
|
responseGen := protorand.New[*configpb.RouteDirectResponse]()
|
|
|
|
randomDomain := func() string {
|
|
numSegments := mathrand.IntN(5) + 1
|
|
segments := make([]string, numSegments)
|
|
for i := range segments {
|
|
b := make([]rune, mathrand.IntN(10)+10)
|
|
for j := range b {
|
|
b[j] = rune(mathrand.IntN(26) + 'a')
|
|
}
|
|
segments[i] = string(b)
|
|
}
|
|
return strings.Join(segments, ".")
|
|
}
|
|
|
|
newCompleteRoute := func() *configpb.Route {
|
|
pb, err := routeGen.Gen()
|
|
|
|
require.NoError(t, err)
|
|
pb.From = "https://" + randomDomain()
|
|
// EnvoyOpts is set to an empty non-nil message during conversion, if nil
|
|
pb.EnvoyOpts = &envoy_config_cluster_v3.Cluster{}
|
|
// JWT groups filter order is not significant. Upon conversion back to
|
|
// a protobuf the JWT groups will be sorted.
|
|
slices.Sort(pb.JwtGroupsFilter)
|
|
|
|
switch mathrand.IntN(3) {
|
|
case 0:
|
|
pb.To = make([]string, mathrand.IntN(3)+1)
|
|
for i := range pb.To {
|
|
pb.To[i] = "https://" + randomDomain()
|
|
}
|
|
pb.LoadBalancingWeights = make([]uint32, len(pb.To))
|
|
for i := range pb.LoadBalancingWeights {
|
|
pb.LoadBalancingWeights[i] = mathrand.Uint32N(10000) + 1
|
|
}
|
|
case 1:
|
|
pb.Redirect, err = redirectGen.Gen()
|
|
require.NoError(t, err)
|
|
case 2:
|
|
pb.Response, err = responseGen.Gen()
|
|
require.NoError(t, err)
|
|
}
|
|
return pb
|
|
}
|
|
|
|
t.Run("Round Trip", func(t *testing.T) {
|
|
for range 100 {
|
|
route := newCompleteRoute()
|
|
|
|
policy, err := NewPolicyFromProto(route)
|
|
require.NoError(t, err)
|
|
|
|
route2, err := policy.ToProto()
|
|
require.NoError(t, err)
|
|
route2.Name = ""
|
|
|
|
testutil.AssertProtoEqual(t, route, route2)
|
|
}
|
|
})
|
|
|
|
t.Run("Multiple routes", func(t *testing.T) {
|
|
for range 100 {
|
|
route1 := newCompleteRoute()
|
|
route2 := newCompleteRoute()
|
|
|
|
{
|
|
// create a new policy every time, since reusing the target will mutate
|
|
// the underlying route
|
|
policy1, err := NewPolicyFromProto(route1)
|
|
require.NoError(t, err)
|
|
target, err := policy1.ToProto()
|
|
require.NoError(t, err)
|
|
target.Name = ""
|
|
testutil.AssertProtoEqual(t, route1, target)
|
|
}
|
|
{
|
|
policy2, err := NewPolicyFromProto(route2)
|
|
require.NoError(t, err)
|
|
target, err := policy2.ToProto()
|
|
require.NoError(t, err)
|
|
target.Name = ""
|
|
testutil.AssertProtoEqual(t, route2, target)
|
|
}
|
|
{
|
|
policy1, err := NewPolicyFromProto(route1)
|
|
require.NoError(t, err)
|
|
target, err := policy1.ToProto()
|
|
require.NoError(t, err)
|
|
target.Name = ""
|
|
testutil.AssertProtoEqual(t, route1, target)
|
|
}
|
|
{
|
|
policy2, err := NewPolicyFromProto(route2)
|
|
require.NoError(t, err)
|
|
target, err := policy2.ToProto()
|
|
require.NoError(t, err)
|
|
target.Name = ""
|
|
testutil.AssertProtoEqual(t, route2, target)
|
|
}
|
|
}
|
|
})
|
|
}
|
|
|
|
func TestOptions_FromToProto(t *testing.T) {
|
|
generate := func(ratio float64) *configpb.Settings {
|
|
t.Helper()
|
|
gen := protorand.New[*configpb.Settings]()
|
|
gen.MaxCollectionElements = 2
|
|
gen.MaxDepth = 3
|
|
gen.UseGoDurationLimits = true
|
|
gen.ExcludeMask(&fieldmaskpb.FieldMask{
|
|
Paths: []string{
|
|
"tls_custom_ca_file",
|
|
"tls_client_cert_file",
|
|
"tls_client_key_file",
|
|
"tls_downstream_client_ca_file",
|
|
},
|
|
})
|
|
|
|
settings, err := gen.GenPartial(ratio)
|
|
require.NoError(t, err)
|
|
unsetFalseOptionalBoolFields(settings)
|
|
fixZeroValuedEnums(settings)
|
|
generateCertificates(t, settings)
|
|
// JWT groups filter order is not significant. Upon conversion back to
|
|
// a protobuf the JWT groups will be sorted.
|
|
slices.Sort(settings.JwtGroupsFilter)
|
|
|
|
return settings
|
|
}
|
|
|
|
t.Run("all fields", func(t *testing.T) {
|
|
t.Parallel()
|
|
for range 100 {
|
|
settings := generate(1)
|
|
var options Options
|
|
options.ApplySettings(context.Background(), nil, settings)
|
|
settings2 := options.ToProto()
|
|
testutil.AssertProtoEqual(t, settings, settings2.Settings)
|
|
}
|
|
})
|
|
|
|
t.Run("some fields", func(t *testing.T) {
|
|
t.Parallel()
|
|
for range 100 {
|
|
settings := generate(mathrand.Float64())
|
|
var options Options
|
|
options.ApplySettings(context.Background(), nil, settings)
|
|
settings2 := options.ToProto()
|
|
testutil.AssertProtoEqual(t, settings, settings2.Settings)
|
|
}
|
|
})
|
|
}
|
|
|
|
// unset any optional bool fields with a value of false, to match
|
|
func unsetFalseOptionalBoolFields(msg proto.Message) {
|
|
msg.ProtoReflect().Range(func(fd protoreflect.FieldDescriptor, v protoreflect.Value) bool {
|
|
if fd.Cardinality() == protoreflect.Optional && fd.Kind() == protoreflect.BoolKind {
|
|
if v.IsValid() && !v.Bool() {
|
|
msg.ProtoReflect().Clear(fd)
|
|
}
|
|
}
|
|
return true
|
|
})
|
|
}
|
|
|
|
func fixZeroValuedEnums(msg *configpb.Settings) {
|
|
if msg.DownstreamMtls != nil && msg.DownstreamMtls.Enforcement != nil {
|
|
// there is no "unknown" equivalent, so if the value is randomly set to
|
|
// unknown it would be a lossy conversion
|
|
if *msg.DownstreamMtls.Enforcement == configpb.MtlsEnforcementMode_UNKNOWN {
|
|
msg.DownstreamMtls.Enforcement = nil
|
|
// if this was the only present field in the message, don't leave it empty
|
|
if proto.Size(msg.DownstreamMtls) == 0 {
|
|
msg.DownstreamMtls = nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func generateCertificates(t testing.TB, msg *configpb.Settings) {
|
|
if msg.AutocertCa != nil {
|
|
*msg.AutocertCa, _ = generateRandomCA(t, *msg.AutocertCa)
|
|
}
|
|
if msg.DownstreamMtls != nil {
|
|
var caKey string
|
|
if msg.DownstreamMtls.Ca != nil {
|
|
*msg.DownstreamMtls.Ca, caKey = generateRandomCA(t, *msg.DownstreamMtls.Ca)
|
|
}
|
|
if msg.DownstreamMtls.Crl != nil {
|
|
if caKey != "" {
|
|
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, *msg.DownstreamMtls.Ca, caKey)
|
|
} else {
|
|
randCa, randKey := generateRandomCA(t, *msg.DownstreamMtls.Crl+"_temp_ca")
|
|
*msg.DownstreamMtls.Crl = generateCRL(t, *msg.DownstreamMtls.Crl, randCa, randKey)
|
|
}
|
|
}
|
|
}
|
|
genCertInPlace := func(cert *configpb.Settings_Certificate, b64 bool) {
|
|
cert.Id = "" // no equivalent field
|
|
switch {
|
|
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) > 0:
|
|
crt, key := generateRandomCert(t, string(cert.CertBytes)+string(cert.KeyBytes), b64)
|
|
cert.CertBytes = []byte(crt)
|
|
cert.KeyBytes = []byte(key)
|
|
case len(cert.CertBytes) > 0 && len(cert.KeyBytes) == 0:
|
|
crt, _ := generateRandomCert(t, string(cert.CertBytes), b64)
|
|
cert.CertBytes = []byte(crt)
|
|
case len(cert.CertBytes) == 0 && len(cert.KeyBytes) > 0:
|
|
// invalid, but convert anyway
|
|
crt, _ := generateRandomCert(t, string(cert.KeyBytes), b64)
|
|
cert.KeyBytes = []byte(crt)
|
|
}
|
|
}
|
|
for i, cert := range msg.Certificates {
|
|
genCertInPlace(cert, false)
|
|
if cert.CertBytes == nil && cert.KeyBytes == nil {
|
|
msg.Certificates = slices.Delete(msg.Certificates, i, i+1)
|
|
}
|
|
}
|
|
if msg.MetricsCertificate != nil {
|
|
genCertInPlace(msg.MetricsCertificate, false)
|
|
if msg.MetricsCertificate.CertBytes == nil && msg.MetricsCertificate.KeyBytes == nil {
|
|
msg.MetricsCertificate = nil
|
|
}
|
|
}
|
|
}
|
|
|
|
func generateRandomCA(t testing.TB, randomInput string) (string, string) {
|
|
seed := sha256.Sum256([]byte(randomInput))
|
|
priv := ed25519.NewKeyFromSeed(seed[:])
|
|
h := fnv.New128()
|
|
h.Write([]byte(randomInput))
|
|
sum := h.Sum(nil)
|
|
var sn big.Int
|
|
sn.SetBytes(sum)
|
|
|
|
now := time.Now()
|
|
tmpl := &x509.Certificate{
|
|
IsCA: true,
|
|
SerialNumber: &sn,
|
|
Subject: pkix.Name{CommonName: randomInput},
|
|
Issuer: pkix.Name{CommonName: randomInput},
|
|
NotBefore: now,
|
|
NotAfter: now.Add(12 * time.Hour),
|
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCRLSign,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageClientAuth,
|
|
x509.ExtKeyUsageServerAuth,
|
|
},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
require.NoError(t, err)
|
|
return base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: der,
|
|
})), base64.StdEncoding.EncodeToString(pem.EncodeToMemory(&pem.Block{
|
|
Type: "PRIVATE KEY",
|
|
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
|
|
}))
|
|
}
|
|
|
|
func generateCRL(t testing.TB, randomInput string, issuerCrt, issuerKey string) string {
|
|
h := fnv.New128()
|
|
h.Write([]byte(randomInput))
|
|
sum := h.Sum(nil)
|
|
var sn big.Int
|
|
sn.SetBytes(sum)
|
|
issuer, err := cryptutil.CertificateFromBase64(issuerCrt, issuerKey)
|
|
require.NoError(t, err)
|
|
b, err := x509.CreateRevocationList(rand.Reader, &x509.RevocationList{
|
|
Number: big.NewInt(0x2000),
|
|
RevokedCertificates: []pkix.RevokedCertificate{
|
|
{
|
|
SerialNumber: &sn,
|
|
RevocationTime: time.Now(),
|
|
},
|
|
},
|
|
}, issuer.Leaf, issuer.PrivateKey.(crypto.Signer))
|
|
require.NoError(t, err)
|
|
return base64.StdEncoding.EncodeToString(b)
|
|
}
|
|
|
|
func generateRandomCert(t testing.TB, randomInput string, b64 bool) (string, string) {
|
|
seed := sha256.Sum256([]byte(randomInput))
|
|
priv := ed25519.NewKeyFromSeed(seed[:])
|
|
h := fnv.New128()
|
|
h.Write([]byte(randomInput))
|
|
sum := h.Sum(nil)
|
|
var sn big.Int
|
|
sn.SetBytes(sum)
|
|
now := time.Now()
|
|
tmpl := &x509.Certificate{
|
|
SerialNumber: &sn,
|
|
Subject: pkix.Name{
|
|
CommonName: randomInput,
|
|
},
|
|
Issuer: pkix.Name{
|
|
CommonName: randomInput,
|
|
},
|
|
NotBefore: now,
|
|
NotAfter: now.Add(12 * time.Hour),
|
|
KeyUsage: x509.KeyUsageDigitalSignature,
|
|
ExtKeyUsage: []x509.ExtKeyUsage{
|
|
x509.ExtKeyUsageClientAuth,
|
|
},
|
|
BasicConstraintsValid: true,
|
|
}
|
|
certDer, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, priv.Public(), priv)
|
|
require.NoError(t, err)
|
|
crtPem := pem.EncodeToMemory(&pem.Block{
|
|
Type: "CERTIFICATE",
|
|
Bytes: certDer,
|
|
})
|
|
keyPem := pem.EncodeToMemory(&pem.Block{
|
|
Type: "PRIVATE KEY",
|
|
Bytes: must(x509.MarshalPKCS8PrivateKey(priv)),
|
|
})
|
|
if b64 {
|
|
return base64.StdEncoding.EncodeToString(crtPem), base64.StdEncoding.EncodeToString(keyPem)
|
|
}
|
|
return string(crtPem), string(keyPem)
|
|
}
|
|
|
|
// must returns t unchanged when err is nil and panics with err otherwise.
func must[T any](t T, err error) T {
	if err == nil {
		return t
	}
	panic(err)
}
|
|
|
|
// ptr returns a pointer to a fresh copy of v. It lets test literals take the
// address of constants and other non-addressable values.
func ptr[T any](v T) *T {
	p := new(T)
	*p = v
	return p
}
|