diff --git a/authorize/authorize.go b/authorize/authorize.go index 1e028d4c3..c694ab5f9 100644 --- a/authorize/authorize.go +++ b/authorize/authorize.go @@ -117,7 +117,7 @@ func newPolicyEvaluator(opts *config.Options, store *store.Store) (*evaluator.Ev // It is important to add an invalid_client_certificate rule even when the // mTLS enforcement behavior is set to reject connections at the listener // level, because of the per-route TLSDownstreamClientCA setting. - addDefaultClientCertificateRule := + addDefaultClientCertificateRule := opts.HasAnyDownstreamMTLSClientCA() && opts.DownstreamMTLS.GetEnforcement() != config.MTLSEnforcementPolicy clientCertConstraints, err := evaluator.ClientCertConstraintsFromConfig(&opts.DownstreamMTLS) diff --git a/authorize/authorize_test.go b/authorize/authorize_test.go index d2dbf3114..350ae5040 100644 --- a/authorize/authorize_test.go +++ b/authorize/authorize_test.go @@ -7,7 +7,10 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/pomerium/pomerium/authorize/evaluator" + "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/pkg/policy/criteria" ) func TestNew(t *testing.T) { @@ -138,3 +141,53 @@ func testPolicies(t *testing.T) []config.Policy { } return policies } + +func TestNewPolicyEvaluator_addDefaultClientCertificateRule(t *testing.T) { + t.Parallel() + + resultFalse := evaluator.NewRuleResult(false) // no mention of client certificates + resultClientCertificateRequired := evaluator.NewRuleResult(true, + criteria.ReasonClientCertificateRequired) + + cases := []struct { + label string + opts *config.Options + expected evaluator.RuleResult + }{ + {"zero", &config.Options{}, resultFalse}, + {"default", config.NewDefaultOptions(), resultFalse}, + {"client CA, default enforcement", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + }, resultClientCertificateRequired}, + {"client CA, reject connection", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{ + CA: "ZmFrZSBDQQ==", + Enforcement: config.MTLSEnforcementRejectConnection, + }, + }, resultClientCertificateRequired}, + {"client CA, policy", &config.Options{ + DownstreamMTLS: config.DownstreamMTLSSettings{ + CA: "ZmFrZSBDQQ==", + Enforcement: config.MTLSEnforcementPolicy, + }, + }, resultFalse}, + } + for i := range cases { + c := &cases[i] + t.Run(c.label, func(t *testing.T) { + store := store.New() + c.opts.Policies = []config.Policy{{ + To: mustParseWeightedURLs(t, "http://example.com"), + }} + e, err := newPolicyEvaluator(c.opts, store) + require.NoError(t, err) + + r, err := e.Evaluate(context.Background(), &evaluator.Request{ + Policy: &c.opts.Policies[0], + HTTP: evaluator.RequestHTTP{}, + }) + require.NoError(t, err) + assert.Equal(t, c.expected, r.Deny) + }) + } +} diff --git a/config/constants.go b/config/constants.go index 93fdd62db..362ab380e 100644 --- a/config/constants.go +++ b/config/constants.go @@ -38,4 +38,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( decodeJWTClaimHeadersHookFunc(), decodeCodecTypeHookFunc(), decodePPLPolicyHookFunc(), + decodeSANMatcherHookFunc(), )) diff --git a/config/custom.go b/config/custom.go index 812a8f423..1cd8a6902 100644 --- a/config/custom.go +++ b/config/custom.go @@ -508,6 +508,26 @@ func parseJSONPB(src map[string]interface{}, dst proto.Message, opts protojson.U return opts.Unmarshal(data, dst) } +// decodeSANMatcherHookFunc returns a decode hook for the SANMatcher type. +func decodeSANMatcherHookFunc() mapstructure.DecodeHookFunc { + return func(f, t reflect.Type, data interface{}) (interface{}, error) { + if t != reflect.TypeOf(SANMatcher{}) { + return data, nil + } + + b, err := json.Marshal(data) + if err != nil { + return nil, err + } + + var m SANMatcher + if err := json.Unmarshal(b, &m); err != nil { + return nil, err + } + return m, nil + } +} + // serializable converts mapstructure nested map into map[string]interface{} that is serializable to JSON func serializable(in interface{}) (interface{}, error) { switch typed := in.(type) { diff --git a/config/envoyconfig/luascripts/set-client-certificate-metadata.lua b/config/envoyconfig/luascripts/set-client-certificate-metadata.lua index d270c5d17..eda5310f0 100644 --- a/config/envoyconfig/luascripts/set-client-certificate-metadata.lua +++ b/config/envoyconfig/luascripts/set-client-certificate-metadata.lua @@ -1,6 +1,9 @@ function envoy_on_request(request_handle) local metadata = request_handle:streamInfo():dynamicMetadata() local ssl = request_handle:streamInfo():downstreamSslConnection() + if ssl == nil then + return + end metadata:set("com.pomerium.client-certificate-info", "presented", ssl:peerCertificatePresented()) metadata:set("com.pomerium.client-certificate-info", "chain", diff --git a/config/envoyconfig/testdata/main_http_connection_manager_filter.json b/config/envoyconfig/testdata/main_http_connection_manager_filter.json index 40b981e61..cc9cc0bb2 100644 --- a/config/envoyconfig/testdata/main_http_connection_manager_filter.json +++ b/config/envoyconfig/testdata/main_http_connection_manager_filter.json @@ -38,7 +38,7 @@ "typedConfig": { "@type": "type.googleapis.com/envoy.extensions.filters.http.lua.v3.Lua", "defaultSourceCode": { - "inlineString": "function envoy_on_request(request_handle)\n local metadata = request_handle:streamInfo():dynamicMetadata()\n local ssl = request_handle:streamInfo():downstreamSslConnection()\n metadata:set(\"com.pomerium.client-certificate-info\", \"presented\",\n ssl:peerCertificatePresented())\n metadata:set(\"com.pomerium.client-certificate-info\", \"chain\",\n ssl:urlEncodedPemEncodedPeerCertificateChain())\nend\n\nfunction envoy_on_response(response_handle) end\n" + "inlineString": "function envoy_on_request(request_handle)\n local metadata = request_handle:streamInfo():dynamicMetadata()\n local ssl = request_handle:streamInfo():downstreamSslConnection()\n if ssl == nil then\n return\n end\n metadata:set(\"com.pomerium.client-certificate-info\", \"presented\",\n ssl:peerCertificatePresented())\n metadata:set(\"com.pomerium.client-certificate-info\", \"chain\",\n ssl:urlEncodedPemEncodedPeerCertificateChain())\nend\n\nfunction envoy_on_response(response_handle) end\n" } } }, diff --git a/config/options.go b/config/options.go index 17249dacc..1f2838890 100644 --- a/config/options.go +++ b/config/options.go @@ -976,6 +976,26 @@ func (o *Options) GetMetricsBasicAuth() (username, password string, ok bool) { return string(bs[:idx]), string(bs[idx+1:]), true } +// HasAnyDownstreamMTLSClientCA returns true if there is a global downstream +// client CA or there are any per-route downstream client CAs. +func (o *Options) HasAnyDownstreamMTLSClientCA() bool { + // All the CA settings should already have been validated. + ca, _ := o.DownstreamMTLS.GetCA() + if len(ca) > 0 { + return true + } + allPolicies := o.GetAllPolicies() + for i := range allPolicies { + // We don't need to check TLSDownstreamClientCAFile here because + // Policy.Validate() will populate TLSDownstreamClientCA when + // TLSDownstreamClientCAFile is set. + if allPolicies[i].TLSDownstreamClientCA != "" { + return true + } + } + return false +} + // GetDataBrokerCertificate gets the optional databroker certificate. This method will return nil if no certificate is // specified. func (o *Options) GetDataBrokerCertificate() (*tls.Certificate, error) { diff --git a/config/options_test.go b/config/options_test.go index 3e88e3e85..b817ee22c 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -341,6 +341,27 @@ func Test_parsePolicyFile(t *testing.T) { } } +func Test_decodeSANMatcher(t *testing.T) { + // Verify that config file parsing will decode the SANMatcher type. + const yaml = ` +downstream_mtls: + match_subject_alt_names: + - dns: 'example-1\..*' + - dns: '.*\.example-2' +` + cfg := filepath.Join(t.TempDir(), "config.yaml") + err := os.WriteFile(cfg, []byte(yaml), 0644) + require.NoError(t, err) + + o, err := optionsFromViper(cfg) + require.NoError(t, err) + + assert.Equal(t, []SANMatcher{ + {Type: SANTypeDNS, Pattern: `example-1\..*`}, + {Type: SANTypeDNS, Pattern: `.*\.example-2`}, + }, o.DownstreamMTLS.MatchSubjectAltNames) +} + func Test_Checksum(t *testing.T) { o := NewDefaultOptions() @@ -738,6 +759,62 @@ func TestDeprecatedClientCAOptions(t *testing.T) { }) } +func TestHasAnyDownstreamMTLSClientCA(t *testing.T) { + t.Parallel() + + cases := []struct { + label string + opts *Options + expected bool + }{ + {"zero", &Options{}, false}, + {"default", NewDefaultOptions(), false}, + {"no client CAs", &Options{ + Policies: []Policy{ + {From: "https://example.com/one"}, + {From: "https://example.com/two"}, + {From: "https://example.com/three"}, + }, + }, false}, + {"global client CA only", &Options{ + DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + Policies: []Policy{ + {From: "https://example.com/one"}, + {From: "https://example.com/two"}, + {From: "https://example.com/three"}, + }, + }, true}, + {"per-route CA only", &Options{ + Policies: []Policy{ + {From: "https://example.com/one"}, + { + From: "https://example.com/two", + TLSDownstreamClientCA: "ZmFrZSBDQQ==", + }, + {From: "https://example.com/three"}, + }, + }, true}, + {"both global and per-route client CAs", &Options{ + DownstreamMTLS: DownstreamMTLSSettings{CA: "ZmFrZSBDQQ=="}, + Policies: []Policy{ + {From: "https://example.com/one"}, + { + From: "https://example.com/two", + TLSDownstreamClientCA: "ZmFrZSBDQQ==", + }, + {From: "https://example.com/three"}, + }, + }, true}, + } + for i := range cases { + c := &cases[i] + t.Run(c.label, func(t *testing.T) { + actual := c.opts.HasAnyDownstreamMTLSClientCA() + assert.Equal(t, c.expected, actual) + }) + } +} + func TestOptions_DefaultURL(t *testing.T) { t.Parallel()