diff --git a/proxy/forward_auth.go b/proxy/forward_auth.go index 5b64dbfa9..c36a8a80b 100644 --- a/proxy/forward_auth.go +++ b/proxy/forward_auth.go @@ -101,38 +101,31 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { return httputil.NewError(http.StatusForbidden, errors.New(http.StatusText(http.StatusForbidden))) } - // the route to validate will be pulled from the uri queryparam - // or inferred from forwarding headers - uriString := r.FormValue("uri") - if uriString == "" { - if r.Header.Get(httputil.HeaderForwardedProto) == "" || r.Header.Get(httputil.HeaderForwardedHost) == "" { - return httputil.NewError(http.StatusBadRequest, errors.New("no uri to validate")) - } - uriString = r.Header.Get(httputil.HeaderForwardedProto) + "://" + - r.Header.Get(httputil.HeaderForwardedHost) + - r.Header.Get(httputil.HeaderForwardedURI) - } - - uri, err := urlutil.ParseAndValidateURL(uriString) + uri, err := getURIStringFromRequest(r) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - authorized, err := p.isAuthorized(w, r) + ar, err := p.isAuthorized(w, r) if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - if authorized { + if ar.authorized { w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "Access to %s is allowed.", uri.Host) return nil } + unAuthenticated := ar.statusCode == http.StatusUnauthorized + if unAuthenticated { + p.sessionStore.ClearSession(w, r) + } + _, err = sessions.FromContext(r.Context()) hasSession := err == nil - if hasSession { + if hasSession && !unAuthenticated { return httputil.NewError(http.StatusForbidden, errors.New("access denied")) } @@ -140,19 +133,46 @@ func (p *Proxy) Verify(verifyOnly bool) http.Handler { return httputil.NewError(http.StatusUnauthorized, err) } - // Traefik set the uri in the header, we must add it to redirect uri if present. Otherwise, request like - // https://example.com/foo will be redirected to https://example.com after authentication. - if xfu := r.Header.Get(httputil.HeaderForwardedURI); xfu != "" { - uri.Path += xfu - } - // redirect to authenticate - authN := *p.authenticateSigninURL - q := authN.Query() - q.Set(urlutil.QueryCallbackURI, uri.String()) - q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination - q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience - authN.RawQuery = q.Encode() - httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) + p.forwardAuthRedirectToSignInWithURI(w, r, uri) return nil }) } + +// forwardAuthRedirectToSignInWithURI redirects request to authenticate signin url, +// with all necessary information extracted from given input uri. +func (p *Proxy) forwardAuthRedirectToSignInWithURI(w http.ResponseWriter, r *http.Request, uri *url.URL) { + // Traefik set the uri in the header, we must add it to redirect uri if present. Otherwise, request like + // https://example.com/foo will be redirected to https://example.com after authentication. + if xfu := r.Header.Get(httputil.HeaderForwardedURI); xfu != "" { + uri.Path += xfu + } + + // redirect to authenticate + authN := *p.authenticateSigninURL + q := authN.Query() + q.Set(urlutil.QueryCallbackURI, uri.String()) + q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination + q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience + authN.RawQuery = q.Encode() + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) +} + +func getURIStringFromRequest(r *http.Request) (*url.URL, error) { + // the route to validate will be pulled from the uri queryparam + // or inferred from forwarding headers + uriString := r.FormValue("uri") + if uriString == "" { + if r.Header.Get(httputil.HeaderForwardedProto) == "" || r.Header.Get(httputil.HeaderForwardedHost) == "" { + return nil, errors.New("no uri to validate") + } + uriString = r.Header.Get(httputil.HeaderForwardedProto) + "://" + + r.Header.Get(httputil.HeaderForwardedHost) + + r.Header.Get(httputil.HeaderForwardedURI) + } + + uri, err := urlutil.ParseAndValidateURL(uriString) + if err != nil { + return nil, err + } + return uri, nil +} diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go index 067307193..305b869a6 100644 --- a/proxy/forward_auth_test.go +++ b/proxy/forward_auth_test.go @@ -10,7 +10,9 @@ import ( envoy_service_auth_v2 "github.com/envoyproxy/go-control-plane/envoy/service/auth/v2" "github.com/google/go-cmp/cmp" + "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/grpc" + "google.golang.org/grpc/codes" "gopkg.in/square/go-jose.v2/jwt" "github.com/pomerium/pomerium/config" @@ -37,6 +39,7 @@ func TestProxy_ForwardAuth(t *testing.T) { allowClient := &mockCheckClient{ response: &envoy_service_auth_v2.CheckResponse{ + Status: &status.Status{Code: int32(codes.OK), Message: "OK"}, HttpResponse: &envoy_service_auth_v2.CheckResponse_OkResponse{}, }, } diff --git a/proxy/middleware.go b/proxy/middleware.go index 95649e149..892a4477c 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -18,6 +18,11 @@ import ( "github.com/pomerium/pomerium/internal/urlutil" ) +type authorizeResponse struct { + authorized bool + statusCode int32 +} + // AuthenticateSession is middleware to enforce a valid authentication // session state is retrieved from the users's request context. func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { @@ -45,10 +50,10 @@ func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error { return nil } -func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (bool, error) { +func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) { tm, err := ptypes.TimestampProto(time.Now()) if err != nil { - return false, httputil.NewError(http.StatusInternalServerError, fmt.Errorf("error creating protobuf timestamp from current time: %w", err)) + return nil, httputil.NewError(http.StatusInternalServerError, fmt.Errorf("error creating protobuf timestamp from current time: %w", err)) } httpAttrs := &envoy_service_auth_v2.AttributeContext_HttpRequest{ @@ -76,18 +81,23 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (bool, erro }, }) if err != nil { - return false, httputil.NewError(http.StatusInternalServerError, err) + return nil, httputil.NewError(http.StatusInternalServerError, err) } + ar := &authorizeResponse{} switch res.HttpResponse.(type) { case *envoy_service_auth_v2.CheckResponse_OkResponse: for _, hdr := range res.GetOkResponse().GetHeaders() { w.Header().Set(hdr.GetHeader().GetKey(), hdr.GetHeader().GetValue()) } - return true, nil + ar.authorized = true + ar.statusCode = res.GetStatus().Code + case *envoy_service_auth_v2.CheckResponse_DeniedResponse: + ar.statusCode = int32(res.GetDeniedResponse().GetStatus().Code) default: - return false, nil + ar.statusCode = http.StatusInternalServerError } + return ar, nil } // SetResponseHeaders sets a map of response headers.