diff --git a/authorize/check_response.go b/authorize/check_response.go index 97004c36a..aee0524cf 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -75,7 +75,7 @@ func (a *Authorize) handleResultAllowed( _ *envoy_service_auth_v3.CheckRequest, result *evaluator.Result, ) (*envoy_service_auth_v3.CheckResponse, error) { - return a.okResponse(result.Headers), nil + return a.okResponse(result.Headers, result.HeadersToRemove), nil } func (a *Authorize) handleResultDenied( @@ -115,12 +115,13 @@ func invalidClientCertReason(reasons criteria.Reasons) bool { reasons.Has(criteria.ReasonInvalidClientCertificate) } -func (a *Authorize) okResponse(headers http.Header) *envoy_service_auth_v3.CheckResponse { +func (a *Authorize) okResponse(headersToSet http.Header, headersToRemove []string) *envoy_service_auth_v3.CheckResponse { return &envoy_service_auth_v3.CheckResponse{ Status: &status.Status{Code: int32(codes.OK), Message: "OK"}, HttpResponse: &envoy_service_auth_v3.CheckResponse_OkResponse{ OkResponse: &envoy_service_auth_v3.OkHttpResponse{ - Headers: toEnvoyHeaders(headers), + Headers: toEnvoyHeaders(headersToSet), + HeadersToRemove: headersToRemove, }, }, } @@ -298,7 +299,7 @@ func (a *Authorize) requireWebAuthnResponse( // If we're already on a webauthn route, return OK. // https://github.com/pomerium/pomerium-console/issues/3210 if checkRequestURL.Path == urlutil.WebAuthnURLPath || checkRequestURL.Path == urlutil.DeviceEnrolledPath { - return a.okResponse(result.Headers), nil + return a.okResponse(result.Headers, result.HeadersToRemove), nil } if !a.shouldRedirect(in, request) { diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 6c100541a..caa510e33 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -314,11 +314,26 @@ func TestAuthorize_okResponse(t *testing.T) { Status: &status.Status{Code: 0, Message: "OK"}, }, }, + { + "ok reply with headers to remove", + &evaluator.Result{ + Allow: evaluator.NewRuleResult(true), + HeadersToRemove: []string{"x-header-to-remove"}, + }, + &envoy_service_auth_v3.CheckResponse{ + Status: &status.Status{Code: 0, Message: "OK"}, + HttpResponse: &envoy_service_auth_v3.CheckResponse_OkResponse{ + OkResponse: &envoy_service_auth_v3.OkHttpResponse{ + HeadersToRemove: []string{"x-header-to-remove"}, + }, + }, + }, + }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got := a.okResponse(tc.reply.Headers) + got := a.okResponse(tc.reply.Headers, tc.reply.HeadersToRemove) assert.Equal(t, tc.want.Status.Code, got.Status.Code) assert.Equal(t, tc.want.Status.Message, got.Status.Message) want, _ := protojson.Marshal(tc.want.GetOkResponse()) diff --git a/authorize/evaluator/evaluator.go b/authorize/evaluator/evaluator.go index 5bcf79361..11a35384d 100644 --- a/authorize/evaluator/evaluator.go +++ b/authorize/evaluator/evaluator.go @@ -147,6 +147,7 @@ type Result struct { Allow RuleResult Deny RuleResult Headers http.Header + HeadersToRemove []string Traces []contextutil.PolicyEvaluationTrace AdditionalLogFields map[log.AuthorizeLogField]any } @@ -322,6 +323,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error) Allow: policyOutput.Allow, Deny: policyOutput.Deny, Headers: headersOutput.Headers, + HeadersToRemove: headersOutput.HeadersToRemove, Traces: policyOutput.Traces, AdditionalLogFields: headersOutput.AdditionalLogFields, } diff --git a/authorize/evaluator/headers_evaluator.go b/authorize/evaluator/headers_evaluator.go index 32a23bca3..977cedc2e 100644 --- a/authorize/evaluator/headers_evaluator.go +++ b/authorize/evaluator/headers_evaluator.go @@ -17,6 +17,7 @@ import ( // HeadersResponse is the output from the headers.rego script. type HeadersResponse struct { Headers http.Header + HeadersToRemove []string AdditionalLogFields map[log.AuthorizeLogField]any } diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index a8f8f5684..ae3347eee 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -123,7 +123,7 @@ func (e *headersEvaluatorEvaluation) fillMCPHeaders(ctx context.Context) (err er return nil } - e.response.Headers.Del("Authorization") + e.response.HeadersToRemove = append(e.response.HeadersToRemove, "Authorization") return nil } diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 4380a322a..122accf08 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -521,7 +521,7 @@ func TestHeadersEvaluator(t *testing.T) { }) require.NoError(t, err) // Should delete Authorization header when no upstream OAuth2 is configured - assert.Empty(t, output.Headers.Get("Authorization")) + assert.Contains(t, output.HeadersToRemove, "Authorization") }) t.Run("no mcp config", func(t *testing.T) { diff --git a/authorize/grpc.go b/authorize/grpc.go index 7d0316910..8190ddf44 100644 --- a/authorize/grpc.go +++ b/authorize/grpc.go @@ -84,7 +84,7 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe a.logAuthorizeCheck(ctx, req, &evaluator.Result{ Allow: evaluator.NewRuleResult(true, criteria.ReasonMCPHandshake), }, s, u) - return a.okResponse(make(http.Header)), nil + return a.okResponse(make(http.Header), nil), nil } }