From e448909042bdda5f85cdf528bd10fbab0857dbcd Mon Sep 17 00:00:00 2001 From: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com> Date: Thu, 17 Aug 2023 08:13:57 -0700 Subject: [PATCH] authorize: remove incorrect "valid-client-certificate" reason (#4470) Fix the logic around when to add the default invalid_client_certificate rule: this should only be added if mTLS is enabled and the enforcement mode is not set to "policy". Add a unit test for this logic. --- authorize/authorize.go | 2 +- authorize/authorize_test.go | 53 +++++++++++++++++++++++++++++++++++ config/options.go | 20 +++++++++++++ config/options_test.go | 56 +++++++++++++++++++++++++++++++++++++ 4 files changed, 130 insertions(+), 1 deletion(-) diff --git a/authorize/authorize.go b/authorize/authorize.go index 1e028d4c3..c694ab5f9 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -117,7 +117,7 @@ func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Ev // It is important to add an invalid_client_certificate rule even when the // mTLS enforcement behavior is set to reject connections at the listener // level, because of the per-route TLSDownstreamClientCA setting. - addDefaultClientCertificateRule := + addDefaultClientCertificateRule := opts.HasAnyDownstreamMTLSClientCA() && opts.DownstreamMTLS.GetEnforcement() != config.MTLSEnforcementPolicy clientCertConstraints, err := evaluator.ClientCertConstraintsFromConfig(&opts.DownstreamMTLS) diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index d2dbf3114..350ae5040 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -7,7 +7,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/policy/criteria" ) func TestNew(t *testing.T) { @@ -138,3 +141,53 @@ func testPolicies(t *testing.T) []config.Policy { } return policies } + +func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) { + t.Parallel() + + resultFalse := evaluator.NewRuleResult(false) // no mention of client certificates + resultClientCertificateRequired := evaluator.NewRuleResult(true, + criteria.ReasonClientCertificateRequired) + + cases := []struct { + label string + opts *config.Options + expected evaluator.RuleResult + }{ + {"zero", &config.Options{}, resultFalse}, + {"default", config.NewDefaultOptions(), resultFalse}, + {"client CA, default enforcement", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + }, resultClientCertificateRequired}, + {"client CA, reject connection", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{ + CA: "ZmFrZSBDQQ==", + Enforcement: config.MTLSEnforcementRejectConnection, + }, + }, resultClientCertificateRequired}, + {"client CA, policy", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{ + CA: "ZmFrZSBDQQ==", + Enforcement: config.MTLSEnforcementPolicy, + }, + }, resultFalse}, + } + for i := range cases { + c := &cases[i] + t.Run(c.label, func(t *testing.T) { + store := store.New() + c.opts.Policies = []config.Policy{{ + To: mustParseWeightedURLs(t, "http://example.com"), + }} + e, err := newPolicyEvaluator(c.opts, store) + require.NoError(t, err) + + r, err := e.Evaluate(context.Background(), &evaluator.Request{ + Policy: &c.opts.Policies[0], + HTTP: evaluator.RequestHTTP{}, + }) + require.NoError(t, err) + assert.Equal(t, c.expected, r.Deny) + }) + } +} diff --git a/config/options.go b/config/options.go index 17249dacc..1f2838890 100644 --- a/config/options.go +++ b/config/options.go @@ -976,6 +976,26 @@ func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) { return string(bs[:idx]), string(bs[idx+1:]), true } +// HasAnyDownstreamMTLSClientCA returns true if there is a global downstream +// client CA or there are any per-route downstream client CAs. +func (o *Options) HasAnyDownstreamMTLSClientCA() bool { + // All the CA settings should already have been validated. + ca, _ := o.DownstreamMTLS.GetCA() + if len(ca) > 0 { + return true + } + allPolicies := o.GetAllPolicies() + for i := range allPolicies { + // We don't need to check TLSDownstreamClientCAFile here because + // Policy.Validate() will populate TLSDownstreamClientCA when + // TLSDownstreamClientCAFile is set. + if allPolicies[i].TLSDownstreamClientCA != "" { + return true + } + } + return false +} + // GetDataBrokerCertificate gets the optional databroker certificate. This method will return nil if no certificate is // specified. func (o *Options) GetDataBrokerCertificate() (*tls.Certificate, error) { diff --git a/config/options_test.go b/config/options_test.go index aa5ee9357..b817ee22c 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -759,6 +759,62 @@ func TestDeprecatedClientCAOptions(t *testing.T) { }) } +func TestHasAnyDownstreamMTLSClientCA(t *testing.T) { + t.Parallel() + + cases := []struct { + label string + opts *Options + expected bool + }{ + {"zero", &Options{}, false}, + {"default", NewDefaultOptions(), false}, + {"no client CAs", &Options{ + Policies: []Policy{ + {From: "https://example.com/one"}, + {From: "https://example.com/two"}, + {From: "https://example.com/three"}, + }, + }, false}, + {"global client CA only", &Options{ + DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + Policies: []Policy{ + {From: "https://example.com/one"}, + {From: "https://example.com/two"}, + {From: "https://example.com/three"}, + }, + }, true}, + {"per-route CA only", &Options{ + Policies: []Policy{ + {From: "https://example.com/one"}, + { + From: "https://example.com/two", + TLSDownstreamClientCA: "ZmFrZSBDQQ==", + }, + {From: "https://example.com/three"}, + }, + }, true}, + {"both global and per-route client CAs", &Options{ + DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + Policies: []Policy{ + {From: "https://example.com/one"}, + { + From: "https://example.com/two", + TLSDownstreamClientCA: "ZmFrZSBDQQ==", + }, + {From: "https://example.com/three"}, + }, + }, true}, + } + for i := range cases { + c := &cases[i] + t.Run(c.label, func(t *testing.T) { + actual := c.opts.HasAnyDownstreamMTLSClientCA() + assert.Equal(t, c.expected, actual) + }) + } +} + func TestOptions_DefaultURL(t *testing.T) { t.Parallel()