From c1b3b45d12e2fd4f599e07027cfa7b5804271e10 Mon Sep 17 00:00:00 2001 From: bobby <1544881+desimone@users.noreply.github.com> Date: Sat, 22 Aug 2020 10:02:12 -0700 Subject: [PATCH] proxy: remove unused handlers (#1317) proxy: remove unused handlers authenticate: remove unused references to refresh_token Signed-off-by: Bobby DeSimone --- authenticate/handlers.go | 17 +---- authenticate/handlers_test.go | 2 +- integration/control_plane_test.go | 18 ++++- internal/urlutil/query_params.go | 1 - proxy/handlers.go | 71 +++++--------------- proxy/handlers_test.go | 32 ++++++--- proxy/middleware.go | 45 ------------- proxy/middleware_test.go | 108 ------------------------------ scripts/programmatic_access.py | 4 +- 9 files changed, 63 insertions(+), 235 deletions(-) diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 1cf4694d7..f35b10c13 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -105,13 +105,11 @@ func (a *Authenticate) wellKnown(w http.ResponseWriter, r *http.Request) error { wellKnownURLS := struct { // URL string referencing the client's JSON Web Key (JWK) Set // RFC7517 document, which contains the client's public keys. - JSONWebKeySetURL string `json:"jwks_uri"` - OAuth2Callback string `json:"authentication_callback_endpoint"` - ProgrammaticRefreshAPI string `json:"api_refresh_endpoint"` + JSONWebKeySetURL string `json:"jwks_uri"` + OAuth2Callback string `json:"authentication_callback_endpoint"` }{ state.redirectURL.ResolveReference(&url.URL{Path: "/.well-known/pomerium/jwks.json"}).String(), state.redirectURL.ResolveReference(&url.URL{Path: "/oauth2/callback"}).String(), - state.redirectURL.ResolveReference(&url.URL{Path: "/api/v1/refresh"}).String(), } w.Header().Set("Content-Type", "application/json") w.Header().Set("X-Content-Type-Options", "nosniff") @@ -234,17 +232,6 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { newSession.Programmatic = true - - pbSession, err := session.Get(ctx, state.dataBrokerClient, s.ID) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - encSession, err := state.encryptedEncoder.Marshal(pbSession.GetOauthToken()) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - callbackParams.Set(urlutil.QueryRefreshToken, string(encSession)) callbackParams.Set(urlutil.QueryIsProgrammatic, "true") } diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index d08a057aa..019d6ba03 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -548,7 +548,7 @@ func TestWellKnownEndpoint(t *testing.T) { rr := httptest.NewRecorder() h.ServeHTTP(rr, req) body := rr.Body.String() - expected := `{"jwks_uri":"https://auth.example.com/.well-known/pomerium/jwks.json","authentication_callback_endpoint":"https://auth.example.com/oauth2/callback","api_refresh_endpoint":"https://auth.example.com/api/v1/refresh"}` + expected := `{"jwks_uri":"https://auth.example.com/.well-known/pomerium/jwks.json","authentication_callback_endpoint":"https://auth.example.com/oauth2/callback"}` assert.Equal(t, body, expected) } diff --git a/integration/control_plane_test.go b/integration/control_plane_test.go index 659a84608..875d357e1 100644 --- a/integration/control_plane_test.go +++ b/integration/control_plane_test.go @@ -18,7 +18,7 @@ func TestDashboard(t *testing.T) { t.Run("user dashboard", func(t *testing.T) { client := testcluster.NewHTTPClient() - req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium", nil) + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium/", nil) if err != nil { t.Fatal(err) } @@ -31,6 +31,22 @@ func TestDashboard(t *testing.T) { assert.Equal(t, http.StatusFound, res.StatusCode, "unexpected status code") }) + t.Run("dashboard strict slash redirect", func(t *testing.T) { + client := testcluster.NewHTTPClient() + + req, err := http.NewRequestWithContext(ctx, "GET", "https://httpdetails.localhost.pomerium.io/.pomerium", nil) + if err != nil { + t.Fatal(err) + } + + res, err := client.Do(req) + if !assert.NoError(t, err, "unexpected http error") { + return + } + defer res.Body.Close() + + assert.Equal(t, http.StatusMovedPermanently, res.StatusCode, "unexpected status code") + }) t.Run("image asset", func(t *testing.T) { client := testcluster.NewHTTPClient() diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 524c0628a..f951ff835 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -14,7 +14,6 @@ const ( QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted" QueryRedirectURI = "pomerium_redirect_uri" - QueryRefreshToken = "pomerium_refresh_token" QueryAccessTokenID = "pomerium_session_access_token_id" QueryAudience = "pomerium_session_audience" QueryProgrammaticToken = "pomerium_programmatic_token" diff --git a/proxy/handlers.go b/proxy/handlers.go index fee87f9a3..dd4db31f2 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -7,11 +7,9 @@ import ( "net/url" "github.com/gorilla/mux" - "github.com/pomerium/csrf" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/middleware" - "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" ) @@ -20,23 +18,7 @@ import ( func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { h := r.PathPrefix(dashboardPath).Subrouter() h.Use(middleware.SetHeaders(httputil.HeadersContentSecurityPolicy)) - // 1. Retrieve the user session and add it to the request context - h.Use(func(h http.Handler) http.Handler { - return sessions.RetrieveSession(p.state.Load().sessionStore)(h) - }) - // 2. AuthN - Verify the user is authenticated. Set email, group, & id headers - h.Use(p.AuthenticateSession) - // 3. Enforce CSRF protections for any non-idempotent http method - h.Use(func(h http.Handler) http.Handler { - opts := p.currentOptions.Load() - state := p.state.Load() - return csrf.Protect( - state.cookieSecret, - csrf.Secure(opts.CookieSecure), - csrf.CookieName(fmt.Sprintf("%s_csrf", opts.CookieName)), - csrf.ErrorHandler(httputil.HandlerFunc(httputil.CSRFFailureHandler)), - )(h) - }) + // dashboard endpoints can be used by user's to view, or modify their session h.Path("/").HandlerFunc(p.UserDashboard).Methods(http.MethodGet) h.Path("/sign_out").HandlerFunc(p.SignOut).Methods(http.MethodGet, http.MethodPost) @@ -48,13 +30,8 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { c.Use(func(h http.Handler) http.Handler { return middleware.ValidateSignature(p.state.Load().sharedKey)(h) }) - - c.Path("/"). - Handler(httputil.HandlerFunc(p.ProgrammaticCallback)). - Methods(http.MethodGet). - Queries(urlutil.QueryIsProgrammatic, "true") - c.Path("/").Handler(httputil.HandlerFunc(p.Callback)).Methods(http.MethodGet) + // Programmatic API handlers and middleware a := r.PathPrefix(dashboardPath + "/api").Subrouter() // login api handler generates a user-navigable login url to authenticate @@ -92,7 +69,7 @@ func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signoutURL).String(), http.StatusFound) } -// UserDashboard redirects to the authenticate dasbhoard. +// UserDashboard redirects to the authenticate dashboard. func (p *Proxy) UserDashboard(w http.ResponseWriter, r *http.Request) { state := p.state.Load() @@ -115,10 +92,23 @@ func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) error { redirectURLString := r.FormValue(urlutil.QueryRedirectURI) encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) - if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { + redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) + if err != nil { return httputil.NewError(http.StatusBadRequest, err) } - httputil.Redirect(w, r, redirectURLString, http.StatusFound) + + rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) + if err != nil { + return httputil.NewError(http.StatusBadRequest, err) + } + + // if programmatic, encode the session jwt as a query param + if isProgrammatic := r.FormValue(urlutil.QueryIsProgrammatic); isProgrammatic == "true" { + q := redirectURL.Query() + q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) + redirectURL.RawQuery = q.Encode() + } + httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) return nil } @@ -168,28 +158,3 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error w.Write([]byte(response)) return nil } - -// ProgrammaticCallback handles a successful call to the authenticate service. -// In addition to returning the individual route session (JWT) it also returns -// the refresh token. -func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) error { - redirectURLString := r.FormValue(urlutil.QueryRedirectURI) - encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) - - redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) - if err != nil { - return httputil.NewError(http.StatusBadRequest, err) - } - - q := redirectURL.Query() - q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) - q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken)) - redirectURL.RawQuery = q.Encode() - httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) - return nil -} diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 1871c3a3c..5cf9ae264 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -64,6 +64,29 @@ func TestProxy_Signout(t *testing.T) { } } +func TestProxy_UserDashboard(t *testing.T) { + opts := testOptions(t) + err := ValidateOptions(opts) + if err != nil { + t.Fatal(err) + } + proxy, err := New(&config.Config{Options: opts}) + if err != nil { + t.Fatal(err) + } + req := httptest.NewRequest(http.MethodGet, "/.pomerium/sign_out", nil) + rr := httptest.NewRecorder() + proxy.UserDashboard(rr, req) + if status := rr.Code; status != http.StatusFound { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusFound) + } + body := rr.Body.String() + want := proxy.state.Load().authenticateURL.String() + if !strings.Contains(body, want) { + t.Errorf("handler returned unexpected body: got %v want %s ", body, want) + } +} + func TestProxy_SignOut(t *testing.T) { t.Parallel() tests := []struct { @@ -105,13 +128,6 @@ func TestProxy_SignOut(t *testing.T) { }) } } -func uriParseHelper(s string) *url.URL { - uri, err := url.Parse(s) - if err != nil { - panic(err) - } - return uri -} func TestProxy_Callback(t *testing.T) { t.Parallel() @@ -464,7 +480,7 @@ func TestProxy_ProgrammaticCallback(t *testing.T) { } w := httptest.NewRecorder() - httputil.HandlerFunc(p.ProgrammaticCallback).ServeHTTP(w, r) + httputil.HandlerFunc(p.Callback).ServeHTTP(w, r) if status := w.Code; status != tt.wantStatus { t.Errorf("status code: got %v want %v", status, tt.wantStatus) t.Errorf("\n%+v", w.Body.String()) diff --git a/proxy/middleware.go b/proxy/middleware.go index 7c3bc8545..fed252a5e 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -14,8 +14,6 @@ import ( "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/sessions" - "github.com/pomerium/pomerium/internal/telemetry/trace" - "github.com/pomerium/pomerium/internal/urlutil" ) type authorizeResponse struct { @@ -23,35 +21,6 @@ type authorizeResponse struct { statusCode int32 } -// AuthenticateSession is middleware to enforce a valid authentication -// session state is retrieved from the users's request context. -func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { - return httputil.HandlerFunc(func(w http.ResponseWriter, r *http.Request) error { - ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") - defer span.End() - - if _, err := sessions.FromContext(ctx); err != nil { - log.FromRequest(r).Debug().Err(err).Msg("proxy: session state") - return p.redirectToSignin(w, r) - } - next.ServeHTTP(w, r.WithContext(ctx)) - return nil - }) -} - -func (p *Proxy) redirectToSignin(w http.ResponseWriter, r *http.Request) error { - state := p.state.Load() - - signinURL := *state.authenticateSigninURL - q := signinURL.Query() - q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) - signinURL.RawQuery = q.Encode() - log.FromRequest(r).Debug().Str("url", signinURL.String()).Msg("proxy: redirectToSignin") - httputil.Redirect(w, r, urlutil.NewSignedURL(state.sharedKey, &signinURL).String(), http.StatusFound) - state.sessionStore.ClearSession(w, r) - return nil -} - func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorizeResponse, error) { state := p.state.Load() @@ -104,20 +73,6 @@ func (p *Proxy) isAuthorized(w http.ResponseWriter, r *http.Request) (*authorize return ar, nil } -// SetResponseHeaders sets a map of response headers. -func SetResponseHeaders(headers map[string]string) func(next http.Handler) http.Handler { - return func(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - ctx, span := trace.StartSpan(r.Context(), "proxy.SetResponseHeaders") - defer span.End() - for key, val := range headers { - r.Header.Set(key, val) - } - next.ServeHTTP(w, r.WithContext(ctx)) - }) - } -} - // jwtClaimMiddleware logs and propagates JWT claim information via request headers // // if returnJWTInfo is set to true, it will also return JWT claim information in the response diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go index 623236e5b..b7f79eb22 100644 --- a/proxy/middleware_test.go +++ b/proxy/middleware_test.go @@ -1,89 +1,17 @@ package proxy import ( - "errors" - "fmt" "net/http" "net/http/httptest" - "strings" "testing" "time" - "github.com/google/go-cmp/cmp" "gopkg.in/square/go-jose.v2/jwt" - "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/jws" - "github.com/pomerium/pomerium/internal/encoding/mock" - "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" - mstore "github.com/pomerium/pomerium/internal/sessions/mock" ) -func TestProxy_AuthenticateSession(t *testing.T) { - t.Parallel() - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - fmt.Fprint(w, http.StatusText(http.StatusOK)) - w.WriteHeader(http.StatusOK) - }) - - tests := []struct { - name string - refreshRespStatus int - errOnFailure bool - session sessions.SessionStore - ctxError error - provider identity.Authenticator - encoder encoding.MarshalUnmarshaler - refreshURL string - - wantStatus int - }{ - {"good", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, nil, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusOK}, - {"invalid session", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Second))}}, errors.New("hi"), identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound}, - {"expired", 200, false, &mstore.Store{Session: &sessions.State{Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{}, &mock.Encoder{}, "", http.StatusFound}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(tt.refreshRespStatus) - fmt.Fprintln(w, "REFRESH GOOD") - })) - defer ts.Close() - rURL := ts.URL - if tt.refreshURL != "" { - rURL = tt.refreshURL - } - - a := Proxy{ - state: newAtomicProxyState(&proxyState{ - sharedKey: "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=", - cookieSecret: []byte("80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ="), - authenticateURL: uriParseHelper("https://authenticate.corp.example"), - authenticateSigninURL: uriParseHelper("https://authenticate.corp.example/sign_in"), - authenticateRefreshURL: uriParseHelper(rURL), - sessionStore: tt.session, - encoder: tt.encoder, - }), - } - r := httptest.NewRequest(http.MethodGet, "/", nil) - state, _ := tt.session.LoadSession(r) - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, tt.ctxError) - r = r.WithContext(ctx) - r.Header.Set("Accept", "application/json") - w := httptest.NewRecorder() - got := a.jwtClaimMiddleware(false)(a.AuthenticateSession(fn)) - got.ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("AuthenticateSession() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) - } - - }) - } -} - func Test_jwtClaimMiddleware(t *testing.T) { claimHeaders := []string{"email", "groups", "missing"} sharedKey := "80ldlrU2d7w+wVpKNfevk6fmb8otEx6CqOfshj2LwhQ=" @@ -125,39 +53,3 @@ func Test_jwtClaimMiddleware(t *testing.T) { }) } - -func TestProxy_SetResponseHeaders(t *testing.T) { - t.Parallel() - fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - var sb strings.Builder - for k, v := range r.Header { - k = strings.ToLower(k) - for _, h := range v { - sb.WriteString(fmt.Sprintf("%v: %v\n", k, h)) - } - } - fmt.Fprint(w, sb.String()) - w.WriteHeader(http.StatusOK) - }) - tests := []struct { - name string - setHeaders map[string]string - wantHeaders string - }{ - {"good", map[string]string{"x-gonna": "give-it-to-ya"}, "x-gonna: give-it-to-ya\n"}, - {"nil", nil, ""}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - - r := httptest.NewRequest(http.MethodGet, "/", nil) - w := httptest.NewRecorder() - got := SetResponseHeaders(tt.setHeaders)(fn) - got.ServeHTTP(w, r) - if diff := cmp.Diff(w.Body.String(), tt.wantHeaders); diff != "" { - t.Errorf("SetResponseHeaders() :\n %s", diff) - } - }) - } -} diff --git a/scripts/programmatic_access.py b/scripts/programmatic_access.py index 358aa9aec..5b09d8586 100755 --- a/scripts/programmatic_access.py +++ b/scripts/programmatic_access.py @@ -25,9 +25,8 @@ args = parser.parse_args() class PomeriumSession: - def __init__(self, jwt, refresh_token): + def __init__(self, jwt): self.jwt = jwt - self.refresh_token = refresh_token def to_json(self): return json.dumps(self.__dict__, indent=2) @@ -55,7 +54,6 @@ class Callback(http.server.BaseHTTPRequestHandler): path_qp = urllib.parse.parse_qs(path) session = PomeriumSession( path_qp.get("pomerium_jwt")[0], - path_qp.get("pomerium_refresh_token")[0], ) done = True response = session.to_json().encode()