diff --git a/authenticate/handlers.go b/authenticate/handlers.go index acfe51190..6081627c8 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -24,10 +24,10 @@ import ( // CSPHeaders are the content security headers added to the service's handlers // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src var CSPHeaders = map[string]string{ - "Content-Security-Policy": "default-src 'none'; style-src 'self'" + - " 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + - " 'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " + - " 'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0='; " + + "Content-Security-Policy": "default-src 'none'; style-src " + + "'sha256-spMkVDoBBY86p0RC1fBYwdnGyMypJM8eG57+p3VASyk=' " + + "'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " + + "'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0=';" + "img-src 'self';", "Referrer-Policy": "Same-origin", } @@ -54,7 +54,8 @@ func (a *Authenticate) Handler() http.Handler { v := r.PathPrefix("/.pomerium").Subrouter() c := cors.New(cors.Options{ AllowOriginRequestFunc: func(r *http.Request, _ string) bool { - return middleware.ValidateRedirectURI(r, a.sharedKey) + err := middleware.ValidateRequestURL(r, a.sharedKey) + return err == nil }, AllowCredentials: true, AllowedHeaders: []string{"*"}, @@ -111,71 +112,84 @@ func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessio // RobotsTxt handles the /robots.txt route. func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "User-agent: *\nDisallow: /") } // SignIn handles to authenticating a user. func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { - // grab and parse our redirect_uri - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return } - // create a clone of the redirect URI, unless this is a programmatic request - // in which case we will redirect back to proxy's callback endpoint - callbackURL, _ := urlutil.DeepCopy(redirectURL) - q := redirectURL.Query() + jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()} - if q.Get("pomerium_programmatic_destination_url") != "" { - callbackURL, err = urlutil.ParseAndValidateURL(q.Get("pomerium_programmatic_destination_url")) + var callbackURL *url.URL + // if the callback is explicitly set, set it and add an additional audience + if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" { + callbackURL, err = urlutil.ParseAndValidateURL(callbackStr) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } + jwtAudience = append(jwtAudience, callbackURL.Hostname()) + } else { + // otherwise, assume callback is the same host as redirect + callbackURL, _ = urlutil.DeepCopy(redirectURL) + callbackURL.Path = "/.pomerium/callback/" + callbackURL.RawQuery = "" } + + // add an additional claim for the forward-auth host, if set + if fwdAuth := r.FormValue(urlutil.QueryForwardAuth); fwdAuth != "" { + jwtAudience = append(jwtAudience, fwdAuth) + } + s, err := sessions.FromContext(r.Context()) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } - s.SetImpersonation(q.Get("impersonate_email"), q.Get("impersonate_group")) - newSession := s.NewSession(a.RedirectURL.Host, []string{a.RedirectURL.Host, callbackURL.Host}) - if q.Get("pomerium_programmatic_destination_url") != "" { + s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups)) + + newSession := s.NewSession(a.RedirectURL.Host, jwtAudience) + + callbackParams := callbackURL.Query() + + if r.FormValue(urlutil.QueryIsProgrammatic) == "true" { newSession.Programmatic = true encSession, err := a.encryptedEncoder.Marshal(newSession) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } - q.Set("pomerium_refresh_token", string(encSession)) + callbackParams.Set(urlutil.QueryRefreshToken, string(encSession)) + callbackParams.Set(urlutil.QueryIsProgrammatic, "true") } // sign the route session, as a JWT signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration)) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } + // encrypt our route-based token JWT avoiding any accidental logging encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil) // base64 our encrypted payload for URL-friendlyness encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) // add our encoded and encrypted route-session JWT to a query param - q.Set("pomerium_jwt", encodedJWT) - - redirectURL.RawQuery = q.Encode() - - callbackURL.Path = "/.pomerium/callback" + callbackParams.Set(urlutil.QuerySessionEncrypted, encodedJWT) + callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) + callbackURL.RawQuery = callbackParams.Encode() // build our hmac-d redirect URL with our session, pointing back to the // proxy's callback URL which is responsible for setting our new route-session - uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL) + uri := urlutil.NewSignedURL(a.sharedKey, callbackURL) httputil.Redirect(w, r, uri.String(), http.StatusFound) } @@ -193,7 +207,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) { httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err)) return } - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) + redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) return diff --git a/authenticate/handlers_test.go b/authenticate/handlers_test.go index c856b4f99..df51752d2 100644 --- a/authenticate/handlers_test.go +++ b/authenticate/handlers_test.go @@ -16,6 +16,7 @@ import ( "github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/templates" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/google/go-cmp/cmp" "golang.org/x/crypto/chacha20poly1305" @@ -108,14 +109,17 @@ func TestAuthenticate_SignIn(t *testing.T) { encoder encoding.MarshalUnmarshaler wantCode int }{ - {"good", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"session not valid", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"bad redirect uri query", "", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"bad marshal", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"session error", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, - {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, - {"encrypted encoder error", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, - {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, + {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, + {"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, + {"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -139,8 +143,7 @@ func TestAuthenticate_SignIn(t *testing.T) { queryString.Set(k, v) } uri.RawQuery = queryString.Encode() - - r := httptest.NewRequest(http.MethodGet, "/?redirect_uri="+uri.String(), nil) + r := httptest.NewRequest(http.MethodGet, uri.String(), nil) r.Header.Set("Accept", "application/json") state, err := tt.session.LoadSession(r) ctx := r.Context() @@ -195,7 +198,7 @@ func TestAuthenticate_SignOut(t *testing.T) { params, _ := url.ParseQuery(u.RawQuery) params.Add("sig", tt.sig) params.Add("ts", tt.ts) - params.Add("redirect_uri", tt.redirectURL) + params.Add(urlutil.QueryRedirectURI, tt.redirectURL) u.RawQuery = params.Encode() r := httptest.NewRequest(tt.method, u.String(), nil) state, _ := tt.sessionStore.LoadSession(r) @@ -307,24 +310,26 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprintln(w, "RVSI FILIVS CAISAR") w.WriteHeader(http.StatusOK) }) tests := []struct { - name string + name string + headers map[string]string + session sessions.SessionStore ctxError error provider identity.Authenticator wantStatus int }{ - {"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, - {"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, - {"good refresh expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, - {"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, - {"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, + {"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, + {"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, + {"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, + {"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, + {"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, + {"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -347,7 +352,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) { r = r.WithContext(ctx) r.Header.Set("Accept", "application/json") - + if len(tt.headers) != 0 { + for k, v := range tt.headers { + r.Header.Set(k, v) + } + } w := httptest.NewRecorder() got := a.VerifySession(fn) diff --git a/cmd/pomerium/main.go b/cmd/pomerium/main.go index 7b99a4c50..d6733a219 100644 --- a/cmd/pomerium/main.go +++ b/cmd/pomerium/main.go @@ -161,7 +161,7 @@ func newGlobalRouter(o *config.Options) *mux.Router { if len(o.Headers) != 0 { mux.Use(middleware.SetHeaders(o.Headers)) } - mux.Use(log.ForwardedAddrHandler("fwd_ip")) + mux.Use(log.HeadersHandler(httputil.HeadersXForwarded)) mux.Use(log.RemoteAddrHandler("ip")) mux.Use(log.UserAgentHandler("user_agent")) mux.Use(log.RefererHandler("referer")) diff --git a/config/options.go b/config/options.go index 5e3273b40..739e08f88 100644 --- a/config/options.go +++ b/config/options.go @@ -189,7 +189,6 @@ var defaultOptions = Options{ CookieName: "_pomerium", DefaultUpstreamTimeout: 30 * time.Second, Headers: map[string]string{ - "X-Content-Type-Options": "nosniff", "X-Frame-Options": "SAMEORIGIN", "X-XSS-Protection": "1; mode=block", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", diff --git a/config/options_test.go b/config/options_test.go index e4311a510..1eb3a87ac 100644 --- a/config/options_test.go +++ b/config/options_test.go @@ -226,7 +226,6 @@ func TestOptionsFromViper(t *testing.T) { CookieHTTPOnly: true, Headers: map[string]string{ "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", - "X-Content-Type-Options": "nosniff", "X-Frame-Options": "SAMEORIGIN", "X-XSS-Protection": "1; mode=block", }}, diff --git a/docs/docs/reference/programmatic-access.md b/docs/docs/reference/programmatic-access.md index d17278232..ff5579284 100644 --- a/docs/docs/reference/programmatic-access.md +++ b/docs/docs/reference/programmatic-access.md @@ -20,7 +20,7 @@ For example: ```bash $ curl "https://httpbin.example.com/.pomerium/api/v1/login?redirect_uri=http://localhost:8000" -https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_programmatic_destination_url%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981 +https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_callback_uri%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981 ``` ### Callback handler diff --git a/docs/recipes/yml/dashboard-fwdauth.ingress.yaml b/docs/recipes/yml/dashboard-fwdauth.ingress.yaml index 922966f6b..012f2ed6e 100644 --- a/docs/recipes/yml/dashboard-fwdauth.ingress.yaml +++ b/docs/recipes/yml/dashboard-fwdauth.ingress.yaml @@ -21,5 +21,5 @@ spec: paths: - path: / backend: - serviceName: dashboard-kubernetes-dashboard + serviceName: helm-dashboard-kubernetes-dashboard servicePort: https diff --git a/go.sum b/go.sum index 956d38bd8..3038bcbd8 100644 --- a/go.sum +++ b/go.sum @@ -176,8 +176,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= -github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0= -github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg= github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q= github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= @@ -221,8 +219,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf h1:fnPsqIDRbCSgumaMCRpoIoF2s4qxv0xSSS0BVZUE/ss= -golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY= golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= @@ -306,7 +302,6 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3 golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= diff --git a/internal/httputil/constants.go b/internal/httputil/constants.go deleted file mode 100644 index fa92b4adc..000000000 --- a/internal/httputil/constants.go +++ /dev/null @@ -1,9 +0,0 @@ -package httputil // import "github.com/pomerium/pomerium/internal/httputil" - -const ( - // HeaderPomeriumResponse is set when pomerium itself creates a response, - // as opposed to the downstream application and can be used to distinguish - // between an application error, and a pomerium related error when debugging. - // Especially useful when working with single page apps (SPA). - HeaderPomeriumResponse = "x-pomerium-intercepted-response" -) diff --git a/internal/httputil/errors.go b/internal/httputil/errors.go index cc4a9547b..4bf1397ab 100644 --- a/internal/httputil/errors.go +++ b/internal/httputil/errors.go @@ -79,6 +79,8 @@ func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) { writeJSONResponse(w, statusCode, response) } else { w.WriteHeader(statusCode) + w.Header().Set("Content-Type", "text/html") + t := struct { Code int Title string diff --git a/internal/httputil/headers.go b/internal/httputil/headers.go new file mode 100644 index 000000000..5b4de870b --- /dev/null +++ b/internal/httputil/headers.go @@ -0,0 +1,57 @@ +package httputil // import "github.com/pomerium/pomerium/internal/httputil" + +const ( + // HeaderPomeriumResponse is set when pomerium itself creates a response, + // as opposed to the downstream application and can be used to distinguish + // between an application error, and a pomerium related error when debugging. + // Especially useful when working with single page apps (SPA). + HeaderPomeriumResponse = "x-pomerium-intercepted-response" +) + +// HeadersContentSecurityPolicy are the content security headers added to the service's handlers +// by default includes profile photo exceptions for supported identity providers. +// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src +var HeadersContentSecurityPolicy = map[string]string{ + "Content-Security-Policy": "default-src 'none'; style-src 'self'; img-src *;", + "Referrer-Policy": "Same-origin", +} + +// Forward headers contains information from the client-facing side of proxy +// servers that is altered or lost when a proxy is involved in the path of the +// request. +// +// https://tools.ietf.org/html/rfc7239 +// https://en.wikipedia.org/wiki/X-Forwarded-For +const ( + HeaderForwardedFor = "X-Forwarded-For" + HeaderForwardedHost = "X-Forwarded-Host" + HeaderForwardedMethod = "X-Forwarded-Method" // traefik + HeaderForwardedPort = "X-Forwarded-Port" + HeaderForwardedProto = "X-Forwarded-Proto" + HeaderForwardedServer = "X-Forwarded-Server" + HeaderForwardedURI = "X-Forwarded-Uri" // traefik + HeaderOriginalMethod = "X-Original-Method" // nginx + HeaderOriginalURL = "X-Original-Url" // nginx + HeaderRealIP = "X-Real-Ip" + HeaderSentFrom = "X-Sent-From" +) + +// HeadersXForwarded is the slice of the header keys used to contain information +// from the client-facing side of proxy servers that is altered or lost when a +// proxy is involved in the path of the request. +// +// https://tools.ietf.org/html/rfc7239 +// https://en.wikipedia.org/wiki/X-Forwarded-For +var HeadersXForwarded = []string{ + HeaderForwardedFor, + HeaderForwardedHost, + HeaderForwardedMethod, + HeaderForwardedPort, + HeaderForwardedProto, + HeaderForwardedServer, + HeaderForwardedURI, + HeaderOriginalMethod, + HeaderOriginalURL, + HeaderRealIP, + HeaderSentFrom, +} diff --git a/internal/identity/providers.go b/internal/identity/providers.go index 31b1a6d40..aaca23ba5 100644 --- a/internal/identity/providers.go +++ b/internal/identity/providers.go @@ -113,7 +113,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Sta return nil, err } - s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Host) + s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Hostname()) if err != nil { return nil, err } diff --git a/internal/log/middleware.go b/internal/log/middleware.go index 2cc6867bb..e635a228d 100644 --- a/internal/log/middleware.go +++ b/internal/log/middleware.go @@ -6,7 +6,6 @@ import ( "fmt" "net" "net/http" - "strings" "time" "github.com/pomerium/pomerium/internal/middleware/responsewriter" @@ -172,20 +171,21 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat } } -// ForwardedAddrHandler returns the client IP address from a request. If a -// request goes through multiple proxies, the IP addresses of each successive -// proxy is listed. This means, the right-most IP address is the IP address of -// the most recent proxy and the left-most IP address is the IP address of the -// originating client. +// HeadersHandler adds the provided set of header keys to the log context. +// +// https://tools.ietf.org/html/rfc7239 +// https://en.wikipedia.org/wiki/X-Forwarded-For // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For -func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler { +func HeadersHandler(headers []string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if ra := r.Header.Get("X-Forwarded-For"); ra != "" { - log := zerolog.Ctx(r.Context()) - log.UpdateContext(func(c zerolog.Context) zerolog.Context { - return c.Strs(fieldKey, strings.Split(ra, ",")) - }) + for _, key := range headers { + if values := r.Header[key]; len(values) != 0 { + log := zerolog.Ctx(r.Context()) + log.UpdateContext(func(c zerolog.Context) zerolog.Context { + return c.Strs(key, values) + }) + } } next.ServeHTTP(w, r) }) diff --git a/internal/log/middleware_test.go b/internal/log/middleware_test.go index 8241f742b..539670695 100644 --- a/internal/log/middleware_test.go +++ b/internal/log/middleware_test.go @@ -253,20 +253,20 @@ func BenchmarkDataRace(b *testing.B) { }) } -func TestForwardedAddrHandler(t *testing.T) { +func TestLogHeadersHandler(t *testing.T) { out := &bytes.Buffer{} r := httptest.NewRequest(http.MethodGet, "/", nil) r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3") - h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + h := HeadersHandler([]string{"X-Forwarded-For"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { l := FromRequest(r) l.Log().Msg("") })) h = NewHandler(zerolog.New(out))(h) h.ServeHTTP(nil, r) - if want, got := `{"fwd_ip":["proxy1","proxy2","proxy3"]}`+"\n", decodeIfBinary(out); want != got { + if want, got := `{"X-Forwarded-For":["proxy1,proxy2,proxy3"]}`+"\n", decodeIfBinary(out); want != got { t.Errorf("Invalid log output, got: %s, want: %s", got, want) } } diff --git a/internal/middleware/cors_test.go b/internal/middleware/cors_test.go index a09a01866..4b1128c30 100644 --- a/internal/middleware/cors_test.go +++ b/internal/middleware/cors_test.go @@ -19,7 +19,6 @@ func TestCorsBypass(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprint(w, http.StatusText(http.StatusOK)) w.WriteHeader(http.StatusOK) }) diff --git a/internal/middleware/middleware.go b/internal/middleware/middleware.go index d6dcf8374..2628a4e0d 100644 --- a/internal/middleware/middleware.go +++ b/internal/middleware/middleware.go @@ -1,13 +1,10 @@ package middleware // import "github.com/pomerium/pomerium/internal/middleware" import ( - "encoding/base64" - "fmt" "net/http" "strings" "time" - "github.com/pomerium/pomerium/internal/cryptutil" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/urlutil" @@ -34,8 +31,8 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature") defer span.End() - if !ValidateRedirectURI(r, sharedSecret) { - httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) + if err := ValidateRequestURL(r, sharedSecret); err != nil { + httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, err)) return } next.ServeHTTP(w, r.WithContext(ctx)) @@ -43,27 +40,24 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler } } -// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts` -// and validates the supplied signature (`sig`)'s HMAC for validity. -func ValidateRedirectURI(r *http.Request, key string) bool { - return ValidSignature( - r.FormValue("redirect_uri"), - r.FormValue("sig"), - r.FormValue("ts"), - key) +// ValidateRequestURL validates the current absolute request URL was signed +// by a given shared key. +func ValidateRequestURL(r *http.Request, key string) error { + return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate() } // Healthcheck endpoint middleware useful to setting up a path like // `/ping` that load balancers or uptime testing external services // can make a request before hitting any routes. It's also convenient // to place this above ACL middlewares as well. +// +// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler { f := func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck") defer span.End() if strings.EqualFold(r.URL.Path, endpoint) { - // https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html if r.Method != http.MethodGet && r.Method != http.MethodHead { http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return @@ -82,26 +76,6 @@ func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler { return f } -// ValidSignature checks to see if a signature is valid. Compares hmac of -// redirect uri, timestamp, and secret and signature. -func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool { - if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" { - return false - } - _, err := urlutil.ParseAndValidateURL(redirectURI) - if err != nil { - return false - } - requestSig, err := base64.URLEncoding.DecodeString(sigVal) - if err != nil { - return false - } - if err := cryptutil.ValidTimestamp(timestamp); err != nil { - return false - } - return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret) -} - // StripCookie strips the cookie from the downstram request. func StripCookie(cookieName string) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { diff --git a/internal/middleware/middleware_test.go b/internal/middleware/middleware_test.go index 4a91ceb9b..b42bbc6e4 100644 --- a/internal/middleware/middleware_test.go +++ b/internal/middleware/middleware_test.go @@ -1,7 +1,6 @@ -package middleware // import "github.com/pomerium/pomerium/internal/middleware" +package middleware import ( - "encoding/base64" "fmt" "net/http" "net/http/httptest" @@ -9,47 +8,10 @@ import ( "testing" "time" - "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/internal/urlutil" ) -func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []byte { - data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix())) - return cryptutil.GenerateHMAC(data, secret) -} - -func Test_ValidSignature(t *testing.T) { - t.Parallel() - goodURL := "https://example.com/redirect" - secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" - now := fmt.Sprint(time.Now().Unix()) - rawSig := hmacHelperFunc(goodURL, time.Now(), secretA) - sig := base64.URLEncoding.EncodeToString(rawSig) - staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix()) - - tests := []struct { - name string - redirectURI string - sigVal string - timestamp string - secret string - want bool - }{ - {"good signature", goodURL, sig, now, secretA, true}, - {"empty redirect url", "", sig, now, secretA, false}, - {"bad redirect url", "https://google.com^", sig, now, secretA, false}, - {"malformed signature", goodURL, sig + "^", now, "&*&@**($&#(", false}, - {"malformed timestamp", goodURL, sig, now + "^", secretA, false}, - {"stale timestamp", goodURL, sig, staleTime, secretA, false}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want { - t.Errorf("ValidSignature() = %v, want %v", got, tt.want) - } - }) - } -} - func TestSetHeaders(t *testing.T) { tests := []struct { name string @@ -79,56 +41,6 @@ func TestSetHeaders(t *testing.T) { } } -func TestValidateSignature(t *testing.T) { - t.Parallel() - secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A=" - now := fmt.Sprint(time.Now().Unix()) - goodURL := "https://example.com/redirect" - rawSig := hmacHelperFunc(goodURL, time.Now(), secretA) - sig := base64.URLEncoding.EncodeToString(rawSig) - staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix()) - - tests := []struct { - name string - sharedSecret string - redirectURI string - sig string - ts string - status int - }{ - {"valid signature", secretA, goodURL, sig, now, http.StatusOK}, - {"stale signature", secretA, goodURL, sig, staleTime, http.StatusBadRequest}, - {"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest}, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - v := url.Values{} - v.Set("redirect_uri", tt.redirectURI) - v.Set("ts", tt.ts) - v.Set("sig", tt.sig) - - req := &http.Request{ - Method: http.MethodGet, - URL: &url.URL{RawQuery: v.Encode()}} - if tt.name == "malformed" { - req.URL.RawQuery = "sig=%zzzzz" - } - - testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("Hi")) - }) - rr := httptest.NewRecorder() - handler := ValidateSignature(tt.sharedSecret)(testHandler) - handler.ServeHTTP(rr, req) - if rr.Code != tt.status { - t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status) - t.Errorf("%s", rr.Body) - } - }) - } -} - func TestHealthCheck(t *testing.T) { t.Parallel() testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -212,7 +124,6 @@ func TestTimeoutHandlerFunc(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprint(w, http.StatusText(http.StatusOK)) w.WriteHeader(http.StatusOK) }) @@ -242,3 +153,42 @@ func TestTimeoutHandlerFunc(t *testing.T) { }) } } + +func TestValidateSignature(t *testing.T) { + t.Parallel() + fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + fmt.Fprint(w, http.StatusText(http.StatusOK)) + w.WriteHeader(http.StatusOK) + }) + + tests := []struct { + name string + secretA string + secretB string + wantStatus int + wantBody string + }{ + {"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)}, + {"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"error\":\"invalid signature\"}\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signedURL := urlutil.NewSignedURL(tt.secretB, &url.URL{Scheme: "https", Host: "pomerium.io"}) + + r := httptest.NewRequest(http.MethodGet, signedURL.String(), nil) + r.Header.Set("Accept", "application/json") + w := httptest.NewRecorder() + got := ValidateSignature(tt.secretA)(fn) + got.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String()) + } + body := w.Body.String() + if diff := cmp.Diff(body, tt.wantBody); diff != "" { + t.Errorf("SignRequest() %s", diff) + t.Errorf("%s", signedURL) + } + }) + } +} diff --git a/internal/sessions/middleware.go b/internal/sessions/middleware.go index 34cfae1f9..f7a619b90 100644 --- a/internal/sessions/middleware.go +++ b/internal/sessions/middleware.go @@ -4,6 +4,8 @@ import ( "context" "errors" "net/http" + + "github.com/pomerium/pomerium/internal/urlutil" ) // Context keys @@ -41,8 +43,8 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er return state, err } if state != nil { - err := state.Verify(r.Host) - return state, err // N.B.: state is _not nil_ + err := state.Verify(urlutil.StripPort(r.Host)) + return state, err // N.B.: state is _not_ nil_ } } diff --git a/internal/urlutil/errors.go b/internal/urlutil/errors.go new file mode 100644 index 000000000..1a569cf58 --- /dev/null +++ b/internal/urlutil/errors.go @@ -0,0 +1,14 @@ +package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" + +import "errors" + +var ( + // ErrExpired indicates that token is used after expiry time indicated in exp claim. + ErrExpired = errors.New("internal/urlutil: validation failed, url hmac is expired") + + // ErrIssuedInTheFuture indicates that the issued field is in the future. + ErrIssuedInTheFuture = errors.New("internal/urlutil: validation field, url hmac issued in the future") + + // ErrNumericDateMalformed indicates a malformed unix timestamp was found while parsing. + ErrNumericDateMalformed = errors.New("internal/urlutil: malformed unix timestamp field") +) diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go new file mode 100644 index 000000000..fb57de774 --- /dev/null +++ b/internal/urlutil/query_params.go @@ -0,0 +1,24 @@ +package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" + +// Common query parameters used to set and send data between Pomerium +// services over HTTP calls and redirects. They are typically used in +// conjunction with a HMAC to ensure authenticity. +const ( + QueryCallbackURI = "pomerium_callback_uri" + QueryImpersonateEmail = "pomerium_impersonate_email" + QueryImpersonateGroups = "pomerium_impersonate_groups" + QueryIsProgrammatic = "pomerium_programmatic" + QueryForwardAuth = "pomerium_forward_auth" + QueryPomeriumJWT = "pomerium_jwt" + QuerySessionEncrypted = "pomerium_session_encrypted" + QueryRedirectURI = "pomerium_redirect_uri" + QueryRefreshToken = "pomerium_refresh_token" +) + +// URL signature based query params used for verifying the authenticity of a URL. +const ( + QueryHmacExpiry = "pomerium_expiry" + QueryHmacIssued = "pomerium_issued" + QueryHmacSignature = "pomerium_signature" + QueryHmacURI = "pomerium_uri" +) diff --git a/internal/urlutil/signed.go b/internal/urlutil/signed.go new file mode 100644 index 000000000..039f5bbb4 --- /dev/null +++ b/internal/urlutil/signed.go @@ -0,0 +1,130 @@ +package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" + +import ( + "encoding/base64" + "fmt" + "net/url" + "strconv" + "time" + + "github.com/pomerium/pomerium/internal/cryptutil" +) + +// SignedURL is a shared-key HMAC wrapped URL. +type SignedURL struct { + uri url.URL + key string + signed bool + + // mockable time for testing + timeNow func() time.Time +} + +// NewSignedURL creates a new copy of a URL that can be signed with a shared key. +// +// N.B. It is the user's responsibility to make sure the key is 256 bits and +// the url is not nil. +func NewSignedURL(key string, uri *url.URL) *SignedURL { + return &SignedURL{uri: *uri, key: key, timeNow: time.Now} // uri is copied +} + +// Sign creates a shared-key HMAC signed URL. +func (su *SignedURL) Sign() *url.URL { + now := su.timeNow() + issued := newNumericDate(now) + expiry := newNumericDate(now.Add(5 * time.Minute)) + params := su.uri.Query() + params.Set(QueryHmacIssued, fmt.Sprint(issued)) + params.Set(QueryHmacExpiry, fmt.Sprint(expiry)) + su.uri.RawQuery = params.Encode() + params.Set(QueryHmacSignature, hmacURL(su.key, su.uri.String(), issued, expiry)) + su.uri.RawQuery = params.Encode() + su.signed = true + return &su.uri +} + +// String implements the stringer interface and returns a signed URL string. +func (su *SignedURL) String() string { + if !su.signed { + su.Sign() + su.signed = true + } + return su.uri.String() +} + +// Validate checks to see if a signed URL is valid. +func (su *SignedURL) Validate() error { + now := su.timeNow() + params := su.uri.Query() + sig, err := base64.URLEncoding.DecodeString(params.Get(QueryHmacSignature)) + if err != nil { + return fmt.Errorf("internal/urlutil: malformed signature %w", err) + } + params.Del(QueryHmacSignature) + su.uri.RawQuery = params.Encode() + + issued, err := newNumericDateFromString(params.Get(QueryHmacIssued)) + if err != nil { + return fmt.Errorf("internal/urlutil: issued %w", err) + } + + expiry, err := newNumericDateFromString(params.Get(QueryHmacExpiry)) + if err != nil { + return fmt.Errorf("internal/urlutil: expiry %w", err) + } + + if expiry != nil && now.Add(-DefaultLeeway).After(expiry.Time()) { + return ErrExpired + } + + if issued != nil && now.Add(DefaultLeeway).Before(issued.Time()) { + return ErrIssuedInTheFuture + } + + validHMAC := cryptutil.CheckHMAC( + []byte(fmt.Sprint(su.uri.String(), issued, expiry)), + sig, + su.key) + if !validHMAC { + return fmt.Errorf("internal/urlutil: hmac failed %s", su.uri.String()) + } + return nil +} + +// hmacURL takes a redirect url string and timestamp and returns the base64 +// encoded HMAC result. +func hmacURL(key string, data ...interface{}) string { + h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data...)), key) + return base64.URLEncoding.EncodeToString(h) +} + +// numericDate used because we don't need the precision of a typical time.Time. +type numericDate int64 + +func newNumericDate(t time.Time) *numericDate { + if t.IsZero() { + return nil + } + out := numericDate(t.Unix()) + return &out +} + +func newNumericDateFromString(s string) (*numericDate, error) { + i, err := strconv.ParseInt(s, 10, 64) + if err != nil { + return nil, ErrNumericDateMalformed + } + out := numericDate(i) + return &out, nil +} + +func (n *numericDate) Time() time.Time { + if n == nil { + return time.Time{} + } + return time.Unix(int64(*n), 0) +} + +func (n *numericDate) String() string { + return strconv.FormatInt(int64(*n), 10) +} diff --git a/internal/urlutil/signed_test.go b/internal/urlutil/signed_test.go new file mode 100644 index 000000000..761cadf03 --- /dev/null +++ b/internal/urlutil/signed_test.go @@ -0,0 +1,97 @@ +package urlutil + +import ( + "net/url" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestSignedURL(t *testing.T) { + original := time.Unix(1574117851, 0) // ;-) + tests := []struct { + name string + key string + uri url.URL + origTime func() time.Time + newTime func() time.Time + wantStr string + want url.URL + wantErr bool + }{ + {"good", "test-key", url.URL{Scheme: "https", Host: "pomerium.io"}, + func() time.Time { return original }, func() time.Time { return original }, + "https://pomerium.io?pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D", + url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, + false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + signedURL := NewSignedURL(tt.key, &tt.uri) + signedURL.timeNow = tt.origTime + + if diff := cmp.Diff(signedURL.String(), tt.wantStr); diff != "" { + t.Errorf("signedURL() = %v", diff) + } + + signedURL = NewSignedURL(tt.key, &tt.uri) + signedURL.timeNow = tt.origTime + got := signedURL.Sign() + + if diff := cmp.Diff(*got, tt.want); diff != "" { + t.Errorf("NewSignedURL() = %s", diff) + } + err := signedURL.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr) + } + + u, err := url.Parse(signedURL.String()) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(u, got); diff != "" { + t.Errorf("signedURL() = %v", diff) + } + // subsequent string calls shouldn't result in a change + u, err = url.Parse(signedURL.String()) + if err != nil { + t.Fatal(err) + } + if diff := cmp.Diff(u, got); diff != "" { + t.Errorf("signedURL() = %v", diff) + } + }) + } +} + +func TestSignedURL_Validate(t *testing.T) { + t.Parallel() + tests := []struct { + name string + uri url.URL + key string + timeNow func() time.Time + + wantErr bool + }{ + {"good", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, false}, + {"bad key", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "bad-key", func() time.Time { return time.Unix(1574117851, 0) }, true}, + {"bad no expiry", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true}, + {"bad issued", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true}, + {"bad signature body", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=^"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true}, + {"bad expired", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(time.Hour) }, true}, + {"bad not yet valid", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(-time.Hour) }, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out := NewSignedURL(tt.key, &tt.uri) + out.timeNow = tt.timeNow + + if err := out.Validate(); (err != nil) != tt.wantErr { + t.Errorf("SignedURL.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/urlutil/url.go b/internal/urlutil/url.go index 6c4d8a25f..32109b065 100644 --- a/internal/urlutil/url.go +++ b/internal/urlutil/url.go @@ -1,15 +1,16 @@ package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" import ( - "encoding/base64" "fmt" "net/http" "net/url" "strings" - "sync" "time" +) - "github.com/pomerium/pomerium/internal/cryptutil" +const ( + // DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims. + DefaultLeeway = 1.0 * time.Minute ) // StripPort returns a host, without any port number. @@ -66,57 +67,6 @@ func DeepCopy(u *url.URL) (*url.URL, error) { return ParseAndValidateURL(u.String()) } -var mockNow testTime - -// testTime is safe to use concurrently. -type testTime struct { - sync.Mutex - mockNow int64 -} - -func (tt *testTime) setNow(n int64) { - tt.Lock() - tt.mockNow = n - tt.Unlock() -} - -func (tt *testTime) now() int64 { - tt.Lock() - defer tt.Unlock() - return tt.mockNow -} - -// timestamp returns the current timestamp, in seconds. -// -// For testing purposes, the function that generates the timestamp can be -// overridden. If not set, it will return time.Now().UTC().Unix(). -func timestamp() int64 { - if mockNow.now() == 0 { - return time.Now().UTC().Unix() - } - return mockNow.now() -} - -// SignedRedirectURL takes a destination URL and adds redirect_uri to it's -// query params, along with a timestamp and an keyed signature. -func SignedRedirectURL(key string, destination, u *url.URL) *url.URL { - now := timestamp() - rawURL := u.String() - params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux - params.Set("redirect_uri", rawURL) - params.Set("ts", fmt.Sprint(now)) - params.Set("sig", hmacURL(key, rawURL, now)) - destination.RawQuery = params.Encode() - return destination -} - -// hmacURL takes a redirect url string and timestamp and returns the base64 -// encoded HMAC result. -func hmacURL(key, data string, timestamp int64) string { - h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data, timestamp)), key) - return base64.URLEncoding.EncodeToString(h) -} - // GetAbsoluteURL returns the current handler's absolute url. // https://stackoverflow.com/a/23152483 func GetAbsoluteURL(r *http.Request) *url.URL { diff --git a/internal/urlutil/url_test.go b/internal/urlutil/url_test.go index 5935fc86c..3f7c8db73 100644 --- a/internal/urlutil/url_test.go +++ b/internal/urlutil/url_test.go @@ -108,48 +108,6 @@ func TestValidateURL(t *testing.T) { } } -func TestSignedRedirectURL(t *testing.T) { - t.Parallel() - tests := []struct { - name string - mockedTime int64 - key string - destination *url.URL - urlToSign *url.URL - want *url.URL - }{ - {"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockNow.setNow(tt.mockedTime) - got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign) - if diff := cmp.Diff(got, tt.want); diff != "" { - t.Errorf("SignedRedirectURL() = diff %v", diff) - } - }) - } -} - -func Test_timestamp(t *testing.T) { - t.Parallel() - - tests := []struct { - name string - dontWant int64 - }{ - {"if unset should never return", 0}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mockNow.setNow(tt.dontWant) - if got := timestamp(); got == tt.dontWant { - t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant) - } - }) - } -} - func parseURLHelper(s string) *url.URL { u, _ := url.Parse(s) return u diff --git a/proxy/forward_auth.go b/proxy/forward_auth.go new file mode 100644 index 000000000..cbf2cef15 --- /dev/null +++ b/proxy/forward_auth.go @@ -0,0 +1,114 @@ +package proxy // import "github.com/pomerium/pomerium/proxy" + +import ( + "errors" + "fmt" + "net/http" + "net/url" + + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" +) + +func (p *Proxy) registerFwdAuthHandlers() http.Handler { + r := httputil.NewRouter() + r.StrictSlash(true) + r.Use(sessions.RetrieveSession(p.sessionStore)) + + r.Handle("/verify", http.HandlerFunc(p.nginxCallback)). + Queries("uri", "{uri}", urlutil.QuerySessionEncrypted, "", urlutil.QueryRedirectURI, "") + r.Handle("/", http.HandlerFunc(p.postSessionSetNOP)). + Queries("uri", "{uri}", + urlutil.QuerySessionEncrypted, "", + urlutil.QueryRedirectURI, "") + r.Handle("/", http.HandlerFunc(p.traefikCallback)). + HeadersRegexp(httputil.HeaderForwardedURI, urlutil.QuerySessionEncrypted) + r.Handle("/", p.Verify(false)).Queries("uri", "{uri}") + r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}") + + return r +} + +// postSessionSetNOP after successfully setting the +func (p *Proxy) postSessionSetNOP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + httputil.Redirect(w, r, r.FormValue(urlutil.QueryRedirectURI), http.StatusFound) +} + +func (p *Proxy) nginxCallback(w http.ResponseWriter, r *http.Request) { + encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) + if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusUnauthorized) +} + +func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) { + forwardedURL, err := url.Parse(r.Header.Get(httputil.HeaderForwardedURI)) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) + return + } + q := forwardedURL.Query() + redirectURLString := q.Get(urlutil.QueryRedirectURI) + encryptedSession := q.Get(urlutil.QuerySessionEncrypted) + + if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) + return + } + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + httputil.Redirect(w, r, redirectURLString, http.StatusFound) +} + +// Verify checks a user's credentials for an arbitrary host. If the user +// is properly authenticated and is authorized to access the supplied host, +// a `200` http status code is returned. If the user is not authenticated, they +// will be redirected to the authenticate service to sign in with their identity +// provider. If the user is unauthorized, a `401` error is returned. +func (p *Proxy) Verify(verifyOnly bool) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri")) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, err)) + return + } + + s, err := sessions.FromContext(r.Context()) + if errors.Is(err, sessions.ErrNoSessionFound) || errors.Is(err, sessions.ErrExpired) { + if verifyOnly { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) + return + } + authN := *p.authenticateSigninURL + q := authN.Query() + q.Set(urlutil.QueryCallbackURI, uri.String()) + q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination + q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience + authN.RawQuery = q.Encode() + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound) + return + } else if err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) + return + } + // depending on the configuration of the fronting proxy, the request Host + // and/or `X-Forwarded-Host` may be untrustd or change so we reverify + // the session's validity against the supplied uri + if err := s.Verify(uri.Hostname()); err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) + return + } + p.addPomeriumHeaders(w, r) + if err := p.authorize(uri.Host, w, r); err != nil { + return + } + + w.Header().Set("Content-Type", "text/plain; charset=utf-8") + w.WriteHeader(http.StatusOK) + fmt.Fprintf(w, "Access to %s is allowed.", uri.Host) + }) +} diff --git a/proxy/forward_auth_test.go b/proxy/forward_auth_test.go new file mode 100644 index 000000000..ef1b39b88 --- /dev/null +++ b/proxy/forward_auth_test.go @@ -0,0 +1,119 @@ +package proxy + +import ( + "errors" + "net/http" + "net/http/httptest" + "net/url" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/encoding" + "github.com/pomerium/pomerium/internal/encoding/mock" + "github.com/pomerium/pomerium/internal/httputil" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/proxy/clients" + "gopkg.in/square/go-jose.v2/jwt" +) + +func TestProxy_ForwardAuth(t *testing.T) { + t.Parallel() + opts := testOptions(t) + tests := []struct { + name string + options config.Options + ctxError error + method string + + headers map[string]string + qp map[string]string + + requestURI string + verifyURI string + + cipher encoding.MarshalUnmarshaler + sessionStore sessions.SessionStore + authorizer clients.Authorizer + wantStatus int + wantBody string + }{ + {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."}, + {"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, + {"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + {"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, + {"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, + {"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, + {"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, + {"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, + {"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, + {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, + {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, + {"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"}, + {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, + {"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + {"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"}, + // traefik + {"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + // nginx + {"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""}, + {"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, err := New(tt.options) + if err != nil { + t.Fatal(err) + } + p.encoder = tt.cipher + p.sessionStore = tt.sessionStore + p.AuthorizeClient = tt.authorizer + p.UpdateOptions(tt.options) + uri, err := url.Parse(tt.requestURI) + if err != nil { + t.Fatal(err) + } + queryString := uri.Query() + for k, v := range tt.qp { + queryString.Set(k, v) + } + if tt.verifyURI != "" { + queryString.Set("uri", tt.verifyURI) + } + + uri.RawQuery = queryString.Encode() + + r := httptest.NewRequest(tt.method, uri.String(), nil) + state, _ := tt.sessionStore.LoadSession(r) + + ctx := r.Context() + ctx = sessions.NewContext(ctx, state, tt.ctxError) + r = r.WithContext(ctx) + r.Header.Set("Accept", "application/json") + if len(tt.headers) != 0 { + for k, v := range tt.headers { + r.Header.Set(k, v) + } + } + w := httptest.NewRecorder() + router := p.registerFwdAuthHandlers() + router.ServeHTTP(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("status code: got %v want %v", status, tt.wantStatus) + t.Errorf("\n%+v", w.Body.String()) + } + + if tt.wantBody != "" { + body := w.Body.String() + if diff := cmp.Diff(body, tt.wantBody); diff != "" { + t.Errorf("wrong body\n%s", diff) + } + } + }) + } +} diff --git a/proxy/handlers.go b/proxy/handlers.go index 1fa296a5d..396a8ea50 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -19,7 +19,6 @@ import ( // registerDashboardHandlers returns the proxy service's ServeMux func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { - // dashboard subrouter h := r.PathPrefix(dashboardURL).Subrouter() // 1. Retrieve the user session and add it to the request context h.Use(sessions.RetrieveSession(p.sessionStore)) @@ -32,19 +31,27 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)), csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), )) + // dashboard endpoints can be used by user's to view, or modify their session h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet) h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost) h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) // Authenticate service callback handlers and middleware + // callback used to set route-scoped session and redirect back to destination + // only accept signed requests (hmac) from other trusted pomerium services c := r.PathPrefix(dashboardURL + "/callback").Subrouter() - // only accept payloads that have come from a trusted service (hmac) c.Use(middleware.ValidateSignature(p.SharedKey)) - c.HandleFunc("/", p.Callback).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet) + c.Path("/").HandlerFunc(p.ProgrammaticCallback).Methods(http.MethodGet). + Queries(urlutil.QueryIsProgrammatic, "true") + + c.Path("/").HandlerFunc(p.Callback).Methods(http.MethodGet) // Programmatic API handlers and middleware a := r.PathPrefix(dashboardURL + "/api").Subrouter() - a.HandleFunc("/v1/login", p.ProgrammaticLogin).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet) + // login api handler generates a user-navigable login url to authenticate + a.HandleFunc("/v1/login", p.ProgrammaticLogin). + Queries(urlutil.QueryRedirectURI, ""). + Methods(http.MethodGet) return r } @@ -52,7 +59,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { // RobotsTxt sets the User-Agent header in the response to be "Disallow" func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") w.WriteHeader(http.StatusOK) fmt.Fprintf(w, "User-agent: *\nDisallow: /") } @@ -62,12 +68,17 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { // the local session state. func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} - if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" { + if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" { redirectURL = uri } - uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL) + + signoutURL := *p.authenticateSignoutURL + q := signoutURL.Query() + q.Set(urlutil.QueryRedirectURI, redirectURL.String()) + signoutURL.RawQuery = q.Encode() + p.sessionStore.ClearSession(w, r) - httputil.Redirect(w, r, uri.String(), http.StatusFound) + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signoutURL).String(), http.StatusFound) } // UserDashboard lets users investigate, and refresh their current session. @@ -112,110 +123,95 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) { // OK to impersonation redirectURL := urlutil.GetAbsoluteURL(r) redirectURL.Path = dashboardURL // redirect back to the dashboard - q := redirectURL.Query() - q.Add("impersonate_email", r.FormValue("email")) - q.Add("impersonate_group", r.FormValue("group")) - redirectURL.RawQuery = q.Encode() - uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String() - httputil.Redirect(w, r, uri, http.StatusFound) + signinURL := *p.authenticateSigninURL + q := signinURL.Query() + q.Set(urlutil.QueryRedirectURI, redirectURL.String()) + q.Set(urlutil.QueryImpersonateEmail, r.FormValue("email")) + q.Set(urlutil.QueryImpersonateGroups, r.FormValue("group")) + signinURL.RawQuery = q.Encode() + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) } -func (p *Proxy) registerFwdAuthHandlers() http.Handler { - r := httputil.NewRouter() - r.StrictSlash(true) - r.Use(sessions.RetrieveSession(p.sessionStore)) - r.Handle("/", p.Verify(false)).Queries("uri", "{uri}").Methods(http.MethodGet) - r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}").Methods(http.MethodGet) - return r -} - -// Verify checks a user's credentials for an arbitrary host. If the user -// is properly authenticated and is authorized to access the supplied host, -// a `200` http status code is returned. If the user is not authenticated, they -// will be redirected to the authenticate service to sign in with their identity -// provider. If the user is unauthorized, a `401` error is returned. -func (p *Proxy) Verify(verifyOnly bool) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri")) - if err != nil || uri.String() == "" { - httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, nil)) - return - } - if err := p.authenticate(verifyOnly, w, r); err != nil { - return - } - if err := p.authorize(uri.Host, w, r); err != nil { - return - } - - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") - w.WriteHeader(http.StatusOK) - fmt.Fprintf(w, fmt.Sprintf("Access to %s is allowed.", uri.Host)) - }) - -} - -// Callback takes a `redirect_uri` query param that has been hmac'd by the -// authenticate service. Embedded in the `redirect_uri` are query-params -// that tell this handler how to set the per-route user session. -// Callback is responsible for redirecting the user back to the intended -// destination URL and path, as well as to clean up any additional query params -// added by the authenticate service. +// Callback handles the result of a successful call to the authenticate service +// and is responsible setting returned per-route session. func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) { - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) - if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) + redirectURLString := r.FormValue(urlutil.QueryRedirectURI) + encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) + + if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) return } - q := redirectURL.Query() - // 1. extract the base64 encoded and encrypted JWT from redirect_uri's query params - encryptedJWT, err := base64.URLEncoding.DecodeString(q.Get("pomerium_jwt")) - if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) - return - } - q.Del("pomerium_jwt") - q.Del("impersonate_email") - q.Del("impersonate_group") + httputil.Redirect(w, r, redirectURLString, http.StatusFound) +} +// saveCallbackSession takes an encrypted per-route session token, and decrypts +// it using the shared service key, then stores it the local session store. +func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) { + // 1. extract the base64 encoded and encrypted JWT from query params + encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken) + if err != nil { + return nil, fmt.Errorf("proxy: malfromed callback token: %w", err) + } // 2. decrypt the JWT using the cipher using the _shared_ secret key rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) - return + return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err) } // 3. Save the decrypted JWT to the session store directly as a string, without resigning if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil { - httputil.ErrorResponse(w, r, err) - return + return nil, fmt.Errorf("proxy: callback session save failure: %w", err) } - - // if this is a programmatic request, don't strip the tokens before redirect - if redirectURL.Query().Get("pomerium_programmatic_destination_url") != "" { - q.Set("pomerium_jwt", string(rawJWT)) - } - redirectURL.RawQuery = q.Encode() - - httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) + return rawJWT, nil } // ProgrammaticLogin returns a signed url that can be used to login // using the authenticate service. func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) { - redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) + redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)) if err != nil { - httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) + httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err)) return } - q := redirectURL.Query() - q.Add("pomerium_programmatic_destination_url", urlutil.GetAbsoluteURL(r).String()) - redirectURL.RawQuery = q.Encode() - response := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String() + signinURL := *p.authenticateSigninURL + callbackURI := urlutil.GetAbsoluteURL(r) + callbackURI.Path = dashboardURL + "/callback/" + q := signinURL.Query() + q.Set(urlutil.QueryCallbackURI, callbackURI.String()) + q.Set(urlutil.QueryRedirectURI, redirectURI.String()) + q.Set(urlutil.QueryIsProgrammatic, "true") + signinURL.RawQuery = q.Encode() + response := urlutil.NewSignedURL(p.SharedKey, &signinURL).String() w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") w.WriteHeader(http.StatusOK) w.Write([]byte(response)) } + +// ProgrammaticCallback handles a successful call to the authenticate service. +// In addition to returning the individual route session (JWT) it also returns +// the refresh token. +func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) { + redirectURLString := r.FormValue(urlutil.QueryRedirectURI) + encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted) + + redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err)) + return + } + + rawJWT, err := p.saveCallbackSession(w, r, encryptedSession) + if err != nil { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err)) + return + } + + q := redirectURL.Query() + q.Set(urlutil.QueryPomeriumJWT, string(rawJWT)) + q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken)) + redirectURL.RawQuery = q.Encode() + + httputil.Redirect(w, r, redirectURL.String(), http.StatusFound) +} diff --git a/proxy/handlers_test.go b/proxy/handlers_test.go index 0004fc5d9..c4a5201ad 100644 --- a/proxy/handlers_test.go +++ b/proxy/handlers_test.go @@ -11,17 +11,22 @@ import ( "testing" "time" + "github.com/pomerium/pomerium/internal/cryptutil" + "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/proxy/clients" "github.com/google/go-cmp/cmp" "gopkg.in/square/go-jose.v2/jwt" ) +const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg=" + func TestProxy_RobotsTxt(t *testing.T) { proxy := Proxy{} req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) @@ -189,12 +194,12 @@ func TestProxy_SignOut(t *testing.T) { t.Fatal(err) } postForm := url.Values{} - postForm.Add("redirect_uri", tt.redirectURL) + postForm.Add(urlutil.QueryRedirectURI, tt.redirectURL) uri := &url.URL{Path: "/"} query, _ := url.ParseQuery(uri.RawQuery) if tt.verb == http.MethodGet { - query.Add("redirect_uri", tt.redirectURL) + query.Add(urlutil.QueryRedirectURI, tt.redirectURL) uri.RawQuery = query.Encode() } r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode())) @@ -217,87 +222,7 @@ func uriParseHelper(s string) *url.URL { } return uri } -func TestProxy_VerifyWithMiddleware(t *testing.T) { - t.Parallel() - opts := testOptions(t) - tests := []struct { - name string - options config.Options - ctxError error - method string - qp string - path string - verifyURI string - cipher encoding.MarshalUnmarshaler - sessionStore sessions.SessionStore - authorizer clients.Authorizer - wantStatus int - wantBody string - }{ - {"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, - {"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""}, - {"bad naked domain uri", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, - {"bad naked domain uri verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, - {"bad empty verification uri", opts, nil, http.MethodGet, "", "/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, - {"bad empty verification uri verify only", opts, nil, http.MethodGet, "", "/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"}, - {"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, - {"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"}, - {"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""}, - {"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, - {"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"}, - {"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p, err := New(tt.options) - if err != nil { - t.Fatal(err) - } - p.encoder = tt.cipher - p.sessionStore = tt.sessionStore - p.AuthorizeClient = tt.authorizer - p.UpdateOptions(tt.options) - uri := &url.URL{Path: tt.path} - queryString := uri.Query() - queryString.Set("donstrip", "ok") - if tt.qp != "" { - queryString.Set(tt.qp, "true") - } - if tt.verifyURI != "" { - queryString.Set("uri", tt.verifyURI) - } - - uri.RawQuery = queryString.Encode() - - r := httptest.NewRequest(tt.method, uri.String(), nil) - state, err := tt.sessionStore.LoadSession(r) - if err != nil { - t.Fatal(err) - } - ctx := r.Context() - ctx = sessions.NewContext(ctx, state, tt.ctxError) - r = r.WithContext(ctx) - r.Header.Set("Authorization", "Bearer blah") - r.Header.Set("Accept", "application/json") - - w := httptest.NewRecorder() - router := p.registerFwdAuthHandlers() - router.ServeHTTP(w, r) - if status := w.Code; status != tt.wantStatus { - t.Errorf("status code: got %v want %v", status, tt.wantStatus) - t.Errorf("\n%+v", w.Body.String()) - } - - if tt.wantBody != "" { - body := w.Body.String() - if diff := cmp.Diff(body, tt.wantBody); diff != "" { - t.Errorf("wrong body\n%s", diff) - } - } - }) - } -} func TestProxy_Callback(t *testing.T) { t.Parallel() opts := testOptions(t) @@ -311,7 +236,8 @@ func TestProxy_Callback(t *testing.T) { host string path string - qp map[string]string + headers map[string]string + qp map[string]string cipher encoding.MarshalUnmarshaler sessionStore sessions.SessionStore @@ -319,11 +245,12 @@ func TestProxy_Callback(t *testing.T) { wantStatus int wantBody string }{ - {"good", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_programmatic_destination_url": "ok", "pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, - {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"bad save session", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError, ""}, - {"bad base64", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, - {"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -345,13 +272,21 @@ func TestProxy_Callback(t *testing.T) { uri := &url.URL{Path: "/"} if tt.qp != nil { qu := uri.Query() - qu.Set("redirect_uri", redirectURI.String()) + for k, v := range tt.qp { + qu.Set(k, v) + } + qu.Set(urlutil.QueryRedirectURI, redirectURI.String()) uri.RawQuery = qu.Encode() } r := httptest.NewRequest(tt.method, uri.String(), nil) r.Header.Set("Accept", "application/json") + if len(tt.headers) != 0 { + for k, v := range tt.headers { + r.Header.Set(k, v) + } + } w := httptest.NewRecorder() p.Callback(w, r) @@ -379,18 +314,20 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { method string - scheme string - host string - path string - qp map[string]string + scheme string + host string + path string + headers map[string]string + qp map[string]string wantStatus int wantBody string }{ - {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusOK, ""}, - {"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""}, - {"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"}, - {"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusMethodNotAllowed, ""}, + {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""}, + {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""}, + {"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""}, + {"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect uri\"}\n"}, + {"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, ""}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -428,3 +365,83 @@ func TestProxy_ProgrammaticLogin(t *testing.T) { }) } } + +func TestProxy_ProgrammaticCallback(t *testing.T) { + t.Parallel() + opts := testOptions(t) + tests := []struct { + name string + options config.Options + + method string + + redirectURI string + + headers map[string]string + qp map[string]string + + cipher encoding.MarshalUnmarshaler + sessionStore sessions.SessionStore + authorizer clients.Authorizer + wantStatus int + wantBody string + }{ + {"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, + {"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + {"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + p, err := New(tt.options) + if err != nil { + t.Fatal(err) + } + p.encoder = tt.cipher + p.sessionStore = tt.sessionStore + p.AuthorizeClient = tt.authorizer + p.UpdateOptions(tt.options) + redirectURI, _ := url.Parse(tt.redirectURI) + queryString := redirectURI.Query() + for k, v := range tt.qp { + queryString.Set(k, v) + } + redirectURI.RawQuery = queryString.Encode() + + uri := &url.URL{Path: "/"} + if tt.qp != nil { + qu := uri.Query() + for k, v := range tt.qp { + qu.Set(k, v) + } + qu.Set(urlutil.QueryRedirectURI, redirectURI.String()) + uri.RawQuery = qu.Encode() + } + + r := httptest.NewRequest(tt.method, uri.String(), nil) + + r.Header.Set("Accept", "application/json") + if len(tt.headers) != 0 { + for k, v := range tt.headers { + r.Header.Set(k, v) + } + } + + w := httptest.NewRecorder() + p.ProgrammaticCallback(w, r) + if status := w.Code; status != tt.wantStatus { + t.Errorf("status code: got %v want %v", status, tt.wantStatus) + t.Errorf("\n%+v", w.Body.String()) + } + + if tt.wantBody != "" { + body := w.Body.String() + if diff := cmp.Diff(body, tt.wantBody); diff != "" { + t.Errorf("wrong body\n%s", diff) + } + } + }) + } +} diff --git a/proxy/middleware.go b/proxy/middleware.go index 6b569cb9a..0aa7aa1e2 100644 --- a/proxy/middleware.go +++ b/proxy/middleware.go @@ -30,39 +30,37 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") defer span.End() - if err := p.authenticate(false, w, r.WithContext(ctx)); err != nil { + + if s, err := sessions.FromContext(ctx); err != nil { log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session") + p.sessionStore.ClearSession(w, r) + if s != nil && s.Programmatic { + httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) + return + } + signinURL := *p.authenticateSigninURL + q := signinURL.Query() + q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String()) + signinURL.RawQuery = q.Encode() + httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound) return } + p.addPomeriumHeaders(w, r) next.ServeHTTP(w, r.WithContext(ctx)) }) } -// authenticate authenticates a user and sets an appropriate response type, -// redirect to authenticate or error handler depending on if err on failure is set. -func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.Request) error { +func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) { s, err := sessions.FromContext(r.Context()) - if err != nil { - p.sessionStore.ClearSession(w, r) - - if errOnFailure || (s != nil && s.Programmatic) { - httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) - return err - } - uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r)) - httputil.Redirect(w, r, uri.String(), http.StatusFound) - return err + if err == nil && s != nil { + r.Header.Set(HeaderUserID, s.Subject) + r.Header.Set(HeaderEmail, s.RequestEmail()) + r.Header.Set(HeaderGroups, s.RequestGroups()) + w.Header().Set(HeaderUserID, s.Subject) + w.Header().Set(HeaderEmail, s.RequestEmail()) + w.Header().Set(HeaderGroups, s.RequestGroups()) } - // add pomerium's headers to the downstream request - r.Header.Set(HeaderUserID, s.Subject) - r.Header.Set(HeaderEmail, s.RequestEmail()) - r.Header.Set(HeaderGroups, s.RequestGroups()) - // and upstream - w.Header().Set(HeaderUserID, s.Subject) - w.Header().Set(HeaderEmail, s.RequestEmail()) - w.Header().Set(HeaderGroups, s.RequestGroups()) - return nil } // AuthorizeSession is middleware to enforce a user is authorized for a request diff --git a/proxy/middleware_test.go b/proxy/middleware_test.go index a8a8268e8..5d72335ef 100644 --- a/proxy/middleware_test.go +++ b/proxy/middleware_test.go @@ -20,7 +20,6 @@ func TestProxy_AuthenticateSession(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprint(w, http.StatusText(http.StatusOK)) w.WriteHeader(http.StatusOK) }) @@ -70,7 +69,6 @@ func TestProxy_AuthorizeSession(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprint(w, http.StatusText(http.StatusOK)) w.WriteHeader(http.StatusOK) }) @@ -131,7 +129,6 @@ func TestProxy_SignRequest(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") fmt.Fprint(w, http.StatusText(http.StatusOK)) w.WriteHeader(http.StatusOK) }) @@ -184,7 +181,6 @@ func TestProxy_SetResponseHeaders(t *testing.T) { t.Parallel() fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.Header().Set("X-Content-Type-Options", "nosniff") var sb strings.Builder for k, v := range r.Header { k = strings.ToLower(k) diff --git a/proxy/proxy.go b/proxy/proxy.go index cebc9d6ad..dc9c65fb2 100755 --- a/proxy/proxy.go +++ b/proxy/proxy.go @@ -187,7 +187,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error { if opts.ForwardAuthURL != nil { // if a forward auth endpoint is set, register its handlers - h := r.Host(opts.ForwardAuthURL.Host).Subrouter() + h := r.Host(opts.ForwardAuthURL.Hostname()).Subrouter() h.PathPrefix("/").Handler(p.registerFwdAuthHandlers()) } diff --git a/scripts/programmatic_access.py b/scripts/programmatic_access.py index 7d2d12d98..e091dea82 100755 --- a/scripts/programmatic_access.py +++ b/scripts/programmatic_access.py @@ -82,7 +82,9 @@ def main(): # initial login to make sure we have our credential if args.login: dst = urllib.parse.urlparse(args.dst) - query_params = {"redirect_uri": "http://{}:{}".format(args.server, args.port)} + query_params = { + "pomerium_redirect_uri": "http://{}:{}".format(args.server, args.port) + } enc_query_params = urllib.parse.urlencode(query_params) dst_login = "{}://{}{}?{}".format( dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,