diff --git a/proxy/handlers.go b/proxy/handlers.go index bf55a8f90..c05c94618 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -3,6 +3,7 @@ package proxy import ( "encoding/base64" "fmt" + "io" "net/http" "net/url" @@ -22,6 +23,7 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { // dashboard endpoints can be used by user's to view, or modify their session h.Path("/").HandlerFunc(p.UserDashboard).Methods(http.MethodGet) h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost) + h.Path("/jwt").Handler(httputil.HandlerFunc(p.jwtAssertion)).Methods(http.MethodGet) // Authenticate service callback handlers and middleware // callback used to set route-scoped session and redirect back to destination @@ -161,3 +163,25 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error w.Write([]byte(response)) return nil } + +func (p *Proxy) jwtAssertion(w http.ResponseWriter, r *http.Request) error { + res, err := p.authorizeCheck(r) + if err != nil { + return httputil.NewError(http.StatusInternalServerError, err) + } + + headers := append(res.GetOkResponse().GetHeaders(), res.GetDeniedResponse().GetHeaders()...) + for _, h := range headers { + if h.GetHeader().GetKey() == httputil.HeaderPomeriumJWTAssertion { + w.Header().Set("Content-Type", "application/jwt") + w.WriteHeader(http.StatusOK) + _, _ = io.WriteString(w, h.GetHeader().GetValue()) + return nil + } + } + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusNotFound) + _, _ = io.WriteString(w, "jwt not found") + return nil +} diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 5cf9ae264..5231d7156 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -11,6 +11,10 @@ import ( "testing" "time" + envoy_api_v2_core "github.com/envoyproxy/go-control-plane/envoy/api/v2/core" + envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" + "github.com/stretchr/testify/assert" + mstore "github.com/pomerium/pomerium/internal/sessions/mock" "github.com/pomerium/pomerium/pkg/cryptutil" @@ -64,6 +68,39 @@ func TestProxy_Signout(t *testing.T) { } } +func TestProxy_jwt(t *testing.T) { + authzClient := &mockCheckClient{ + response: &envoy_service_auth_v2.CheckResponse{ + HttpResponse: &envoy_service_auth_v2.CheckResponse_OkResponse{ + OkResponse: &envoy_service_auth_v2.OkHttpResponse{ + Headers: []*envoy_api_v2_core.HeaderValueOption{ + {Header: &envoy_api_v2_core.HeaderValue{ + Key: httputil.HeaderPomeriumJWTAssertion, + Value: "MOCK_JWT", + }}, + }, + }, + }, + }, + } + + req, _ := http.NewRequest("GET", "https://www.example.com/.pomerium/jwt", nil) + w := httptest.NewRecorder() + + proxy := &Proxy{ + state: newAtomicProxyState(&proxyState{ + authzClient: authzClient, + }), + } + err := proxy.jwtAssertion(w, req) + if !assert.NoError(t, err) { + return + } + + assert.Equal(t, "application/jwt", w.Header().Get("Content-Type")) + assert.Equal(t, w.Body.String(), "MOCK_JWT") +} + func TestProxy_UserDashboard(t *testing.T) { opts := testOptions(t) err := ValidateOptions(opts) diff --git a/proxy/middleware.go b/proxy/middleware.go index fed252a5e..28328422c 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -22,6 +22,28 @@ type authorizeResponse struct { } func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) { + res, err := p.authorizeCheck(r) + if err != nil { + return nil, httputil.NewError(http.StatusInternalServerError, err) + } + + ar := &authorizeResponse{} + switch res.HttpResponse.(type) { + case *envoy_service_auth_v2.CheckResponse_OkResponse: + for _, hdr := range res.GetOkResponse().GetHeaders() { + w.Header().Set(hdr.GetHeader().GetKey(), hdr.GetHeader().GetValue()) + } + ar.authorized = true + ar.statusCode = res.GetStatus().Code + case *envoy_service_auth_v2.CheckResponse_DeniedResponse: + ar.statusCode = int32(res.GetDeniedResponse().GetStatus().Code) + default: + ar.statusCode = http.StatusInternalServerError + } + return ar, nil +} + +func (p *Proxy) authorizeCheck(r *http.Request) (*envoy_service_auth_v2.CheckResponse, error) { state := p.state.Load() tm, err := ptypes.TimestampProto(time.Now()) @@ -45,7 +67,7 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize httpAttrs.Path += "?" + r.URL.RawQuery } - res, err := state.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{ + return state.authzClient.Check(r.Context(), &envoy_service_auth_v2.CheckRequest{ Attributes: &envoy_service_auth_v2.AttributeContext{ Request: &envoy_service_auth_v2.AttributeContext_Request{ Time: tm, @@ -53,24 +75,6 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize }, }, }) - if err != nil { - return nil, httputil.NewError(http.StatusInternalServerError, err) - } - - ar := &authorizeResponse{} - switch res.HttpResponse.(type) { - case *envoy_service_auth_v2.CheckResponse_OkResponse: - for _, hdr := range res.GetOkResponse().GetHeaders() { - w.Header().Set(hdr.GetHeader().GetKey(), hdr.GetHeader().GetValue()) - } - ar.authorized = true - ar.statusCode = res.GetStatus().Code - case *envoy_service_auth_v2.CheckResponse_DeniedResponse: - ar.statusCode = int32(res.GetDeniedResponse().GetStatus().Code) - default: - ar.statusCode = http.StatusInternalServerError - } - return ar, nil } // jwtClaimMiddleware logs and propagates JWT claim information via request headers