diff --git a/authorize/check_response.go b/authorize/check_response.go index 2e931a04b..2990f0a11 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -2,12 +2,14 @@ package authorize import ( "context" + "encoding/json" "errors" "io" "net/http" "net/http/httptest" "net/url" "sort" + "strconv" "strings" envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" @@ -132,34 +134,51 @@ func (a *Authorize) deniedResponse( ) (*envoy_service_auth_v3.CheckResponse, error) { respHeader := []*envoy_config_core_v3.HeaderValueOption{} - // create a http response writer recorder - w := httptest.NewRecorder() - r := getHTTPRequestFromCheckRequest(in) + var respBody []byte + switch { + case isJSONWebRequest(in): + respBody, _ = json.Marshal(map[string]any{ + "error": reason, + "request_id": requestid.FromContext(ctx), + }) + respHeader = append(respHeader, + mkHeader("Content-Type", "application/json")) + case isGRPCWebRequest(in): + respHeader = append(respHeader, + mkHeader("Content-Type", "application/grpc-web+json"), + mkHeader("grpc-status", strconv.Itoa(int(codes.Unauthenticated))), + mkHeader("grpc-message", codes.Unauthenticated.String())) + default: + // create a http response writer recorder + w := httptest.NewRecorder() + r := getHTTPRequestFromCheckRequest(in) - // build the user info / debug endpoint - debugEndpoint, _ := a.userInfoEndpointURL(in) // if there's an error, we just wont display it + // build the user info / debug endpoint + debugEndpoint, _ := a.userInfoEndpointURL(in) // if there's an error, we just wont display it - // run the request through our go error handler - httpErr := httputil.HTTPError{ - Status: int(code), - Err: errors.New(reason), - DebugURL: debugEndpoint, - RequestID: requestid.FromContext(ctx), - BrandingOptions: a.currentOptions.Load().BrandingOptions, + // run the request through our go error handler + httpErr := httputil.HTTPError{ + Status: int(code), + Err: errors.New(reason), + DebugURL: debugEndpoint, + RequestID: requestid.FromContext(ctx), + BrandingOptions: a.currentOptions.Load().BrandingOptions, + } + httpErr.ErrorResponse(ctx, w, r) + + // transpose the go http response writer into a envoy response + resp := w.Result() + defer resp.Body.Close() + + var err error + respBody, err = io.ReadAll(resp.Body) + if err != nil { + log.Error(ctx).Err(err).Msg("error executing error template") + return nil, err + } + // convert go headers to envoy headers + respHeader = append(respHeader, toEnvoyHeaders(resp.Header)...) } - httpErr.ErrorResponse(ctx, w, r) - - // transpose the go http response writer into a envoy response - resp := w.Result() - defer resp.Body.Close() - - respBody, err := io.ReadAll(resp.Body) - if err != nil { - log.Error(ctx).Err(err).Msg("error executing error template") - return nil, err - } - // convert go headers to envoy headers - respHeader = append(respHeader, toEnvoyHeaders(resp.Header)...) // add any additional headers for k, v := range headers { @@ -333,3 +352,50 @@ func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest) bool return mediaType == "text/html" } + +func isGRPCWebRequest(in *envoy_service_auth_v3.CheckRequest) bool { + hdrs := in.GetAttributes().GetRequest().GetHttp().GetHeaders() + if hdrs == nil { + return false + } + + v := getHeader(hdrs, "Accept") + if v == "" { + return false + } + + accept, err := rfc7231.ParseAccept(v) + if err != nil { + return false + } + + return accept.Acceptable("application/grpc-web-text") +} + +func isJSONWebRequest(in *envoy_service_auth_v3.CheckRequest) bool { + hdrs := in.GetAttributes().GetRequest().GetHttp().GetHeaders() + if hdrs == nil { + return false + } + + v := getHeader(hdrs, "Accept") + if v == "" { + return false + } + + accept, err := rfc7231.ParseAccept(v) + if err != nil { + return false + } + + return accept.Acceptable("application/json") +} + +func getHeader(hdrs map[string]string, key string) string { + for k, v := range hdrs { + if strings.EqualFold(k, key) { + return v + } + } + return "" +} diff --git a/authorize/check_response_test.go b/authorize/check_response_test.go index 0e67833c1..69200a37e 100644 --- a/authorize/check_response_test.go +++ b/authorize/check_response_test.go @@ -6,19 +6,17 @@ import ( "net/http/httptest" "testing" - envoy_config_core_v3 "github.com/envoyproxy/go-control-plane/envoy/config/core/v3" envoy_service_auth_v3 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v3" - envoy_type_v3 "github.com/envoyproxy/go-control-plane/envoy/type/v3" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/genproto/googleapis/rpc/status" - "google.golang.org/grpc/codes" "google.golang.org/protobuf/encoding/protojson" "github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/atomicutil" + "github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/testutil" hpke_handlers "github.com/pomerium/pomerium/pkg/hpke/handlers" "github.com/pomerium/pomerium/pkg/policy/criteria" @@ -182,6 +180,8 @@ func TestAuthorize_okResponse(t *testing.T) { } func TestAuthorize_deniedResponse(t *testing.T) { + t.Parallel() + a := &Authorize{currentOptions: config.NewAtomicOptions(), state: atomicutil.NewValue(new(authorizeState))} a.currentOptions.Store(&config.Options{ Policies: []config.Policy{{ @@ -192,48 +192,118 @@ func TestAuthorize_deniedResponse(t *testing.T) { }}, }) - tests := []struct { - name string - in *envoy_service_auth_v3.CheckRequest - code int32 - reason string - headers map[string]string - want *envoy_service_auth_v3.CheckResponse - }{ - { - "html denied", - nil, - http.StatusBadRequest, - "Access Denied", - nil, - &envoy_service_auth_v3.CheckResponse{ - Status: &status.Status{Code: int32(codes.PermissionDenied), Message: "Access Denied"}, - HttpResponse: &envoy_service_auth_v3.CheckResponse_DeniedResponse{ - DeniedResponse: &envoy_service_auth_v3.DeniedHttpResponse{ - Status: &envoy_type_v3.HttpStatus{ - Code: envoy_type_v3.StatusCode(codes.InvalidArgument), + t.Run("json", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + ctx = requestid.WithValue(ctx, "REQUESTID") + + res, err := a.deniedResponse(ctx, &envoy_service_auth_v3.CheckRequest{ + Attributes: &envoy_service_auth_v3.AttributeContext{ + Request: &envoy_service_auth_v3.AttributeContext_Request{ + Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ + Headers: map[string]string{ + "Accept": "application/json", }, - Headers: []*envoy_config_core_v3.HeaderValueOption{ - mkHeader("Content-Type", "text/html; charset=UTF-8"), - mkHeader("X-Pomerium-Intercepted-Response", "true"), - }, - Body: "Access Denied", }, }, }, - }, - } - for _, tc := range tests { - tc := tc - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - got, err := a.deniedResponse(context.TODO(), tc.in, tc.code, tc.reason, tc.headers) - require.NoError(t, err) - assert.Equal(t, tc.want.Status.Code, got.Status.Code) - assert.Equal(t, tc.want.Status.Message, got.Status.Message) - testutil.AssertProtoEqual(t, tc.want.GetDeniedResponse().GetHeaders(), got.GetDeniedResponse().GetHeaders()) - }) - } + }, http.StatusBadRequest, "ERROR", nil) + assert.NoError(t, err) + testutil.AssertProtoJSONEqual(t, `{ + "deniedResponse": { + "body": "{\"error\":\"ERROR\",\"request_id\":\"REQUESTID\"}", + "headers": [ + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "Content-Type", "value": "application/json" } + } + ], + "status": { + "code": "BadRequest" + } + }, + "status": { + "code": 7, + "message": "Access Denied" + } + }`, res) + }) + + t.Run("grpc-web", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + ctx = requestid.WithValue(ctx, "REQUESTID") + + res, err := a.deniedResponse(ctx, &envoy_service_auth_v3.CheckRequest{ + Attributes: &envoy_service_auth_v3.AttributeContext{ + Request: &envoy_service_auth_v3.AttributeContext_Request{ + Http: &envoy_service_auth_v3.AttributeContext_HttpRequest{ + Headers: map[string]string{ + "Accept": "application/grpc-web-text", + }, + }, + }, + }, + }, http.StatusBadRequest, "ERROR", nil) + assert.NoError(t, err) + testutil.AssertProtoJSONEqual(t, `{ + "deniedResponse": { + "headers": [ + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "Content-Type", "value": "application/grpc-web+json" } + }, + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "grpc-status", "value": "16" } + }, + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "grpc-message", "value": "Unauthenticated" } + } + ], + "status": { + "code": "BadRequest" + } + }, + "status": { + "code": 7, + "message": "Access Denied" + } + }`, res) + }) + + t.Run("html", func(t *testing.T) { + t.Parallel() + ctx := context.Background() + ctx = requestid.WithValue(ctx, "REQUESTID") + + res, err := a.deniedResponse(ctx, &envoy_service_auth_v3.CheckRequest{}, http.StatusBadRequest, "ERROR", nil) + assert.NoError(t, err) + assert.Contains(t, res.GetDeniedResponse().GetBody(), "") + res.HttpResponse.(*envoy_service_auth_v3.CheckResponse_DeniedResponse).DeniedResponse.Body = "" + testutil.AssertProtoJSONEqual(t, `{ + "deniedResponse": { + "headers": [ + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "Content-Type", "value": "text/html; charset=UTF-8" } + }, + { + "appendAction": "OVERWRITE_IF_EXISTS_OR_ADD", + "header": { "key": "X-Pomerium-Intercepted-Response", "value": "true" } + } + ], + "status": { + "code": "BadRequest" + } + }, + "status": { + "code": 7, + "message": "Access Denied" + } + }`, res) + }) } func mustParseWeightedURLs(t *testing.T, urls ...string) []config.WeightedURL {