proxy: fix forward auth, request signing

Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
Bobby DeSimone 2019-11-25 14:29:52 -08:00
parent ec9607d1d5
commit 0f6a9d7f1d
No known key found for this signature in database
GPG key ID: AEE4CF12FE86D07E
32 changed files with 928 additions and 522 deletions

View file

@ -24,10 +24,10 @@ import (
// CSPHeaders are the content security headers added to the service's handlers // CSPHeaders are the content security headers added to the service's handlers
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
var CSPHeaders = map[string]string{ var CSPHeaders = map[string]string{
"Content-Security-Policy": "default-src 'none'; style-src 'self'" + "Content-Security-Policy": "default-src 'none'; style-src " +
" 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " + "'sha256-spMkVDoBBY86p0RC1fBYwdnGyMypJM8eG57+p3VASyk=' " +
" 'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " + "'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " +
" 'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0='; " + "'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0=';" +
"img-src 'self';", "img-src 'self';",
"Referrer-Policy": "Same-origin", "Referrer-Policy": "Same-origin",
} }
@ -54,7 +54,8 @@ func (a *Authenticate) Handler() http.Handler {
v := r.PathPrefix("/.pomerium").Subrouter() v := r.PathPrefix("/.pomerium").Subrouter()
c := cors.New(cors.Options{ c := cors.New(cors.Options{
AllowOriginRequestFunc: func(r *http.Request, _ string) bool { AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
return middleware.ValidateRedirectURI(r, a.sharedKey) err := middleware.ValidateRequestURL(r, a.sharedKey)
return err == nil
}, },
AllowCredentials: true, AllowCredentials: true,
AllowedHeaders: []string{"*"}, AllowedHeaders: []string{"*"},
@ -111,71 +112,84 @@ func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessio
// RobotsTxt handles the /robots.txt route. // RobotsTxt handles the /robots.txt route.
func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
// SignIn handles to authenticating a user. // SignIn handles to authenticating a user.
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) { func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
// grab and parse our redirect_uri redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return return
} }
// create a clone of the redirect URI, unless this is a programmatic request
// in which case we will redirect back to proxy's callback endpoint
callbackURL, _ := urlutil.DeepCopy(redirectURL)
q := redirectURL.Query() jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()}
if q.Get("pomerium_programmatic_destination_url") != "" { var callbackURL *url.URL
callbackURL, err = urlutil.ParseAndValidateURL(q.Get("pomerium_programmatic_destination_url")) // if the callback is explicitly set, set it and add an additional audience
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
callbackURL, err = urlutil.ParseAndValidateURL(callbackStr)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
jwtAudience = append(jwtAudience, callbackURL.Hostname())
} else {
// otherwise, assume callback is the same host as redirect
callbackURL, _ = urlutil.DeepCopy(redirectURL)
callbackURL.Path = "/.pomerium/callback/"
callbackURL.RawQuery = ""
} }
// add an additional claim for the forward-auth host, if set
if fwdAuth := r.FormValue(urlutil.QueryForwardAuth); fwdAuth != "" {
jwtAudience = append(jwtAudience, fwdAuth)
}
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
s.SetImpersonation(q.Get("impersonate_email"), q.Get("impersonate_group"))
newSession := s.NewSession(a.RedirectURL.Host, []string{a.RedirectURL.Host, callbackURL.Host}) s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
if q.Get("pomerium_programmatic_destination_url") != "" {
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
callbackParams := callbackURL.Query()
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
newSession.Programmatic = true newSession.Programmatic = true
encSession, err := a.encryptedEncoder.Marshal(newSession) encSession, err := a.encryptedEncoder.Marshal(newSession)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
q.Set("pomerium_refresh_token", string(encSession)) callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
} }
// sign the route session, as a JWT // sign the route session, as a JWT
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration)) signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
// encrypt our route-based token JWT avoiding any accidental logging // encrypt our route-based token JWT avoiding any accidental logging
encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil) encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil)
// base64 our encrypted payload for URL-friendlyness // base64 our encrypted payload for URL-friendlyness
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT) encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
// add our encoded and encrypted route-session JWT to a query param // add our encoded and encrypted route-session JWT to a query param
q.Set("pomerium_jwt", encodedJWT) callbackParams.Set(urlutil.QuerySessionEncrypted, encodedJWT)
callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
redirectURL.RawQuery = q.Encode() callbackURL.RawQuery = callbackParams.Encode()
callbackURL.Path = "/.pomerium/callback"
// build our hmac-d redirect URL with our session, pointing back to the // build our hmac-d redirect URL with our session, pointing back to the
// proxy's callback URL which is responsible for setting our new route-session // proxy's callback URL which is responsible for setting our new route-session
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL) uri := urlutil.NewSignedURL(a.sharedKey, callbackURL)
httputil.Redirect(w, r, uri.String(), http.StatusFound) httputil.Redirect(w, r, uri.String(), http.StatusFound)
} }
@ -193,7 +207,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
return return
} }
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
return return

View file

@ -16,6 +16,7 @@ import (
"github.com/pomerium/pomerium/internal/identity" "github.com/pomerium/pomerium/internal/identity"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/templates" "github.com/pomerium/pomerium/internal/templates"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
@ -108,14 +109,17 @@ func TestAuthenticate_SignIn(t *testing.T) {
encoder encoding.MarshalUnmarshaler encoder encoding.MarshalUnmarshaler
wantCode int wantCode int
}{ }{
{"good", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"session not valid", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad redirect uri query", "", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"bad marshal", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"session error", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound}, {"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"encrypted encoder error", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest}, {"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest}, {"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -139,8 +143,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
queryString.Set(k, v) queryString.Set(k, v)
} }
uri.RawQuery = queryString.Encode() uri.RawQuery = queryString.Encode()
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
r := httptest.NewRequest(http.MethodGet, "/?redirect_uri="+uri.String(), nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
state, err := tt.session.LoadSession(r) state, err := tt.session.LoadSession(r)
ctx := r.Context() ctx := r.Context()
@ -195,7 +198,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
params, _ := url.ParseQuery(u.RawQuery) params, _ := url.ParseQuery(u.RawQuery)
params.Add("sig", tt.sig) params.Add("sig", tt.sig)
params.Add("ts", tt.ts) params.Add("ts", tt.ts)
params.Add("redirect_uri", tt.redirectURL) params.Add(urlutil.QueryRedirectURI, tt.redirectURL)
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
r := httptest.NewRequest(tt.method, u.String(), nil) r := httptest.NewRequest(tt.method, u.String(), nil)
state, _ := tt.sessionStore.LoadSession(r) state, _ := tt.sessionStore.LoadSession(r)
@ -307,24 +310,26 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprintln(w, "RVSI FILIVS CAISAR") fmt.Fprintln(w, "RVSI FILIVS CAISAR")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
tests := []struct { tests := []struct {
name string name string
headers map[string]string
session sessions.SessionStore session sessions.SessionStore
ctxError error ctxError error
provider identity.Authenticator provider identity.Authenticator
wantStatus int wantStatus int
}{ }{
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK}, {"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound}, {"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
{"good refresh expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, {"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound}, {"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound}, {"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -347,7 +352,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
r = r.WithContext(ctx) r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
got := a.VerifySession(fn) got := a.VerifySession(fn)

View file

@ -161,7 +161,7 @@ func newGlobalRouter(o *config.Options) *mux.Router {
if len(o.Headers) != 0 { if len(o.Headers) != 0 {
mux.Use(middleware.SetHeaders(o.Headers)) mux.Use(middleware.SetHeaders(o.Headers))
} }
mux.Use(log.ForwardedAddrHandler("fwd_ip")) mux.Use(log.HeadersHandler(httputil.HeadersXForwarded))
mux.Use(log.RemoteAddrHandler("ip")) mux.Use(log.RemoteAddrHandler("ip"))
mux.Use(log.UserAgentHandler("user_agent")) mux.Use(log.UserAgentHandler("user_agent"))
mux.Use(log.RefererHandler("referer")) mux.Use(log.RefererHandler("referer"))

View file

@ -189,7 +189,6 @@ var defaultOptions = Options{
CookieName: "_pomerium", CookieName: "_pomerium",
DefaultUpstreamTimeout: 30 * time.Second, DefaultUpstreamTimeout: 30 * time.Second,
Headers: map[string]string{ Headers: map[string]string{
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN", "X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block", "X-XSS-Protection": "1; mode=block",
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",

View file

@ -226,7 +226,6 @@ func TestOptionsFromViper(t *testing.T) {
CookieHTTPOnly: true, CookieHTTPOnly: true,
Headers: map[string]string{ Headers: map[string]string{
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload", "Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
"X-Content-Type-Options": "nosniff",
"X-Frame-Options": "SAMEORIGIN", "X-Frame-Options": "SAMEORIGIN",
"X-XSS-Protection": "1; mode=block", "X-XSS-Protection": "1; mode=block",
}}, }},

View file

@ -20,7 +20,7 @@ For example:
```bash ```bash
$ curl "https://httpbin.example.com/.pomerium/api/v1/login?redirect_uri=http://localhost:8000" $ curl "https://httpbin.example.com/.pomerium/api/v1/login?redirect_uri=http://localhost:8000"
https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_programmatic_destination_url%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981 https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_callback_uri%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981
``` ```
### Callback handler ### Callback handler

View file

@ -21,5 +21,5 @@ spec:
paths: paths:
- path: / - path: /
backend: backend:
serviceName: dashboard-kubernetes-dashboard serviceName: helm-dashboard-kubernetes-dashboard
servicePort: https servicePort: https

5
go.sum
View file

@ -176,8 +176,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik= github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU= github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q= github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q=
github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I= github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
@ -221,8 +219,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf h1:fnPsqIDRbCSgumaMCRpoIoF2s4qxv0xSSS0BVZUE/ss=
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY= golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
@ -306,7 +302,6 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=

View file

@ -1,9 +0,0 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
const (
// HeaderPomeriumResponse is set when pomerium itself creates a response,
// as opposed to the downstream application and can be used to distinguish
// between an application error, and a pomerium related error when debugging.
// Especially useful when working with single page apps (SPA).
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
)

View file

@ -79,6 +79,8 @@ func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
writeJSONResponse(w, statusCode, response) writeJSONResponse(w, statusCode, response)
} else { } else {
w.WriteHeader(statusCode) w.WriteHeader(statusCode)
w.Header().Set("Content-Type", "text/html")
t := struct { t := struct {
Code int Code int
Title string Title string

View file

@ -0,0 +1,57 @@
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
const (
// HeaderPomeriumResponse is set when pomerium itself creates a response,
// as opposed to the downstream application and can be used to distinguish
// between an application error, and a pomerium related error when debugging.
// Especially useful when working with single page apps (SPA).
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
)
// HeadersContentSecurityPolicy are the content security headers added to the service's handlers
// by default includes profile photo exceptions for supported identity providers.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
var HeadersContentSecurityPolicy = map[string]string{
"Content-Security-Policy": "default-src 'none'; style-src 'self'; img-src *;",
"Referrer-Policy": "Same-origin",
}
// Forward headers contains information from the client-facing side of proxy
// servers that is altered or lost when a proxy is involved in the path of the
// request.
//
// https://tools.ietf.org/html/rfc7239
// https://en.wikipedia.org/wiki/X-Forwarded-For
const (
HeaderForwardedFor = "X-Forwarded-For"
HeaderForwardedHost = "X-Forwarded-Host"
HeaderForwardedMethod = "X-Forwarded-Method" // traefik
HeaderForwardedPort = "X-Forwarded-Port"
HeaderForwardedProto = "X-Forwarded-Proto"
HeaderForwardedServer = "X-Forwarded-Server"
HeaderForwardedURI = "X-Forwarded-Uri" // traefik
HeaderOriginalMethod = "X-Original-Method" // nginx
HeaderOriginalURL = "X-Original-Url" // nginx
HeaderRealIP = "X-Real-Ip"
HeaderSentFrom = "X-Sent-From"
)
// HeadersXForwarded is the slice of the header keys used to contain information
// from the client-facing side of proxy servers that is altered or lost when a
// proxy is involved in the path of the request.
//
// https://tools.ietf.org/html/rfc7239
// https://en.wikipedia.org/wiki/X-Forwarded-For
var HeadersXForwarded = []string{
HeaderForwardedFor,
HeaderForwardedHost,
HeaderForwardedMethod,
HeaderForwardedPort,
HeaderForwardedProto,
HeaderForwardedServer,
HeaderForwardedURI,
HeaderOriginalMethod,
HeaderOriginalURL,
HeaderRealIP,
HeaderSentFrom,
}

View file

@ -113,7 +113,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Sta
return nil, err return nil, err
} }
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Host) s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Hostname())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View file

@ -6,7 +6,6 @@ import (
"fmt" "fmt"
"net" "net"
"net/http" "net/http"
"strings"
"time" "time"
"github.com/pomerium/pomerium/internal/middleware/responsewriter" "github.com/pomerium/pomerium/internal/middleware/responsewriter"
@ -172,20 +171,21 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
} }
} }
// ForwardedAddrHandler returns the client IP address from a request. If a // HeadersHandler adds the provided set of header keys to the log context.
// request goes through multiple proxies, the IP addresses of each successive //
// proxy is listed. This means, the right-most IP address is the IP address of // https://tools.ietf.org/html/rfc7239
// the most recent proxy and the left-most IP address is the IP address of the // https://en.wikipedia.org/wiki/X-Forwarded-For
// originating client.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler { func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if ra := r.Header.Get("X-Forwarded-For"); ra != "" { for _, key := range headers {
log := zerolog.Ctx(r.Context()) if values := r.Header[key]; len(values) != 0 {
log.UpdateContext(func(c zerolog.Context) zerolog.Context { log := zerolog.Ctx(r.Context())
return c.Strs(fieldKey, strings.Split(ra, ",")) log.UpdateContext(func(c zerolog.Context) zerolog.Context {
}) return c.Strs(key, values)
})
}
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
}) })

View file

@ -253,20 +253,20 @@ func BenchmarkDataRace(b *testing.B) {
}) })
} }
func TestForwardedAddrHandler(t *testing.T) { func TestLogHeadersHandler(t *testing.T) {
out := &bytes.Buffer{} out := &bytes.Buffer{}
r := httptest.NewRequest(http.MethodGet, "/", nil) r := httptest.NewRequest(http.MethodGet, "/", nil)
r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3") r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3")
h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := HeadersHandler([]string{"X-Forwarded-For"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
l := FromRequest(r) l := FromRequest(r)
l.Log().Msg("") l.Log().Msg("")
})) }))
h = NewHandler(zerolog.New(out))(h) h = NewHandler(zerolog.New(out))(h)
h.ServeHTTP(nil, r) h.ServeHTTP(nil, r)
if want, got := `{"fwd_ip":["proxy1","proxy2","proxy3"]}`+"\n", decodeIfBinary(out); want != got { if want, got := `{"X-Forwarded-For":["proxy1,proxy2,proxy3"]}`+"\n", decodeIfBinary(out); want != got {
t.Errorf("Invalid log output, got: %s, want: %s", got, want) t.Errorf("Invalid log output, got: %s, want: %s", got, want)
} }
} }

View file

@ -19,7 +19,6 @@ func TestCorsBypass(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK)) fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })

View file

@ -1,13 +1,10 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware // import "github.com/pomerium/pomerium/internal/middleware"
import ( import (
"encoding/base64"
"fmt"
"net/http" "net/http"
"strings" "strings"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/telemetry/trace" "github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/internal/urlutil"
@ -34,8 +31,8 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature") ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
defer span.End() defer span.End()
if !ValidateRedirectURI(r, sharedSecret) { if err := ValidateRequestURL(r, sharedSecret); err != nil {
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil)) httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, err))
return return
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
@ -43,27 +40,24 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
} }
} }
// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts` // ValidateRequestURL validates the current absolute request URL was signed
// and validates the supplied signature (`sig`)'s HMAC for validity. // by a given shared key.
func ValidateRedirectURI(r *http.Request, key string) bool { func ValidateRequestURL(r *http.Request, key string) error {
return ValidSignature( return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
r.FormValue("redirect_uri"),
r.FormValue("sig"),
r.FormValue("ts"),
key)
} }
// Healthcheck endpoint middleware useful to setting up a path like // Healthcheck endpoint middleware useful to setting up a path like
// `/ping` that load balancers or uptime testing external services // `/ping` that load balancers or uptime testing external services
// can make a request before hitting any routes. It's also convenient // can make a request before hitting any routes. It's also convenient
// to place this above ACL middlewares as well. // to place this above ACL middlewares as well.
//
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler { func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
f := func(next http.Handler) http.Handler { f := func(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) { fn := func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck") ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
defer span.End() defer span.End()
if strings.EqualFold(r.URL.Path, endpoint) { if strings.EqualFold(r.URL.Path, endpoint) {
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html
if r.Method != http.MethodGet && r.Method != http.MethodHead { if r.Method != http.MethodGet && r.Method != http.MethodHead {
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
return return
@ -82,26 +76,6 @@ func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
return f return f
} }
// ValidSignature checks to see if a signature is valid. Compares hmac of
// redirect uri, timestamp, and secret and signature.
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
return false
}
_, err := urlutil.ParseAndValidateURL(redirectURI)
if err != nil {
return false
}
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
if err != nil {
return false
}
if err := cryptutil.ValidTimestamp(timestamp); err != nil {
return false
}
return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
}
// StripCookie strips the cookie from the downstram request. // StripCookie strips the cookie from the downstram request.
func StripCookie(cookieName string) func(next http.Handler) http.Handler { func StripCookie(cookieName string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {

View file

@ -1,7 +1,6 @@
package middleware // import "github.com/pomerium/pomerium/internal/middleware" package middleware
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -9,47 +8,10 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil" "github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/internal/urlutil"
) )
func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []byte {
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
return cryptutil.GenerateHMAC(data, secret)
}
func Test_ValidSignature(t *testing.T) {
t.Parallel()
goodURL := "https://example.com/redirect"
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix())
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
tests := []struct {
name string
redirectURI string
sigVal string
timestamp string
secret string
want bool
}{
{"good signature", goodURL, sig, now, secretA, true},
{"empty redirect url", "", sig, now, secretA, false},
{"bad redirect url", "https://google.com^", sig, now, secretA, false},
{"malformed signature", goodURL, sig + "^", now, "&*&@**($&#(", false},
{"malformed timestamp", goodURL, sig, now + "^", secretA, false},
{"stale timestamp", goodURL, sig, staleTime, secretA, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
t.Errorf("ValidSignature() = %v, want %v", got, tt.want)
}
})
}
}
func TestSetHeaders(t *testing.T) { func TestSetHeaders(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
@ -79,56 +41,6 @@ func TestSetHeaders(t *testing.T) {
} }
} }
func TestValidateSignature(t *testing.T) {
t.Parallel()
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
now := fmt.Sprint(time.Now().Unix())
goodURL := "https://example.com/redirect"
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
sig := base64.URLEncoding.EncodeToString(rawSig)
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
tests := []struct {
name string
sharedSecret string
redirectURI string
sig string
ts string
status int
}{
{"valid signature", secretA, goodURL, sig, now, http.StatusOK},
{"stale signature", secretA, goodURL, sig, staleTime, http.StatusBadRequest},
{"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := url.Values{}
v.Set("redirect_uri", tt.redirectURI)
v.Set("ts", tt.ts)
v.Set("sig", tt.sig)
req := &http.Request{
Method: http.MethodGet,
URL: &url.URL{RawQuery: v.Encode()}}
if tt.name == "malformed" {
req.URL.RawQuery = "sig=%zzzzz"
}
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("Hi"))
})
rr := httptest.NewRecorder()
handler := ValidateSignature(tt.sharedSecret)(testHandler)
handler.ServeHTTP(rr, req)
if rr.Code != tt.status {
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
t.Errorf("%s", rr.Body)
}
})
}
}
func TestHealthCheck(t *testing.T) { func TestHealthCheck(t *testing.T) {
t.Parallel() t.Parallel()
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -212,7 +124,6 @@ func TestTimeoutHandlerFunc(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK)) fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
@ -242,3 +153,42 @@ func TestTimeoutHandlerFunc(t *testing.T) {
}) })
} }
} }
func TestValidateSignature(t *testing.T) {
t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK)
})
tests := []struct {
name string
secretA string
secretB string
wantStatus int
wantBody string
}{
{"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)},
{"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"error\":\"invalid signature\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
signedURL := urlutil.NewSignedURL(tt.secretB, &url.URL{Scheme: "https", Host: "pomerium.io"})
r := httptest.NewRequest(http.MethodGet, signedURL.String(), nil)
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
got := ValidateSignature(tt.secretA)(fn)
got.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
}
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("SignRequest() %s", diff)
t.Errorf("%s", signedURL)
}
})
}
}

View file

@ -4,6 +4,8 @@ import (
"context" "context"
"errors" "errors"
"net/http" "net/http"
"github.com/pomerium/pomerium/internal/urlutil"
) )
// Context keys // Context keys
@ -41,8 +43,8 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
return state, err return state, err
} }
if state != nil { if state != nil {
err := state.Verify(r.Host) err := state.Verify(urlutil.StripPort(r.Host))
return state, err // N.B.: state is _not nil_ return state, err // N.B.: state is _not_ nil_
} }
} }

View file

@ -0,0 +1,14 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import "errors"
var (
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
ErrExpired = errors.New("internal/urlutil: validation failed, url hmac is expired")
// ErrIssuedInTheFuture indicates that the issued field is in the future.
ErrIssuedInTheFuture = errors.New("internal/urlutil: validation field, url hmac issued in the future")
// ErrNumericDateMalformed indicates a malformed unix timestamp was found while parsing.
ErrNumericDateMalformed = errors.New("internal/urlutil: malformed unix timestamp field")
)

View file

@ -0,0 +1,24 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
// Common query parameters used to set and send data between Pomerium
// services over HTTP calls and redirects. They are typically used in
// conjunction with a HMAC to ensure authenticity.
const (
QueryCallbackURI = "pomerium_callback_uri"
QueryImpersonateEmail = "pomerium_impersonate_email"
QueryImpersonateGroups = "pomerium_impersonate_groups"
QueryIsProgrammatic = "pomerium_programmatic"
QueryForwardAuth = "pomerium_forward_auth"
QueryPomeriumJWT = "pomerium_jwt"
QuerySessionEncrypted = "pomerium_session_encrypted"
QueryRedirectURI = "pomerium_redirect_uri"
QueryRefreshToken = "pomerium_refresh_token"
)
// URL signature based query params used for verifying the authenticity of a URL.
const (
QueryHmacExpiry = "pomerium_expiry"
QueryHmacIssued = "pomerium_issued"
QueryHmacSignature = "pomerium_signature"
QueryHmacURI = "pomerium_uri"
)

130
internal/urlutil/signed.go Normal file
View file

@ -0,0 +1,130 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import (
"encoding/base64"
"fmt"
"net/url"
"strconv"
"time"
"github.com/pomerium/pomerium/internal/cryptutil"
)
// SignedURL is a shared-key HMAC wrapped URL.
type SignedURL struct {
uri url.URL
key string
signed bool
// mockable time for testing
timeNow func() time.Time
}
// NewSignedURL creates a new copy of a URL that can be signed with a shared key.
//
// N.B. It is the user's responsibility to make sure the key is 256 bits and
// the url is not nil.
func NewSignedURL(key string, uri *url.URL) *SignedURL {
return &SignedURL{uri: *uri, key: key, timeNow: time.Now} // uri is copied
}
// Sign creates a shared-key HMAC signed URL.
func (su *SignedURL) Sign() *url.URL {
now := su.timeNow()
issued := newNumericDate(now)
expiry := newNumericDate(now.Add(5 * time.Minute))
params := su.uri.Query()
params.Set(QueryHmacIssued, fmt.Sprint(issued))
params.Set(QueryHmacExpiry, fmt.Sprint(expiry))
su.uri.RawQuery = params.Encode()
params.Set(QueryHmacSignature, hmacURL(su.key, su.uri.String(), issued, expiry))
su.uri.RawQuery = params.Encode()
su.signed = true
return &su.uri
}
// String implements the stringer interface and returns a signed URL string.
func (su *SignedURL) String() string {
if !su.signed {
su.Sign()
su.signed = true
}
return su.uri.String()
}
// Validate checks to see if a signed URL is valid.
func (su *SignedURL) Validate() error {
now := su.timeNow()
params := su.uri.Query()
sig, err := base64.URLEncoding.DecodeString(params.Get(QueryHmacSignature))
if err != nil {
return fmt.Errorf("internal/urlutil: malformed signature %w", err)
}
params.Del(QueryHmacSignature)
su.uri.RawQuery = params.Encode()
issued, err := newNumericDateFromString(params.Get(QueryHmacIssued))
if err != nil {
return fmt.Errorf("internal/urlutil: issued %w", err)
}
expiry, err := newNumericDateFromString(params.Get(QueryHmacExpiry))
if err != nil {
return fmt.Errorf("internal/urlutil: expiry %w", err)
}
if expiry != nil && now.Add(-DefaultLeeway).After(expiry.Time()) {
return ErrExpired
}
if issued != nil && now.Add(DefaultLeeway).Before(issued.Time()) {
return ErrIssuedInTheFuture
}
validHMAC := cryptutil.CheckHMAC(
[]byte(fmt.Sprint(su.uri.String(), issued, expiry)),
sig,
su.key)
if !validHMAC {
return fmt.Errorf("internal/urlutil: hmac failed %s", su.uri.String())
}
return nil
}
// hmacURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func hmacURL(key string, data ...interface{}) string {
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data...)), key)
return base64.URLEncoding.EncodeToString(h)
}
// numericDate used because we don't need the precision of a typical time.Time.
type numericDate int64
func newNumericDate(t time.Time) *numericDate {
if t.IsZero() {
return nil
}
out := numericDate(t.Unix())
return &out
}
func newNumericDateFromString(s string) (*numericDate, error) {
i, err := strconv.ParseInt(s, 10, 64)
if err != nil {
return nil, ErrNumericDateMalformed
}
out := numericDate(i)
return &out, nil
}
func (n *numericDate) Time() time.Time {
if n == nil {
return time.Time{}
}
return time.Unix(int64(*n), 0)
}
func (n *numericDate) String() string {
return strconv.FormatInt(int64(*n), 10)
}

View file

@ -0,0 +1,97 @@
package urlutil
import (
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
)
func TestSignedURL(t *testing.T) {
original := time.Unix(1574117851, 0) // ;-)
tests := []struct {
name string
key string
uri url.URL
origTime func() time.Time
newTime func() time.Time
wantStr string
want url.URL
wantErr bool
}{
{"good", "test-key", url.URL{Scheme: "https", Host: "pomerium.io"},
func() time.Time { return original }, func() time.Time { return original },
"https://pomerium.io?pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D",
url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"},
false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
signedURL := NewSignedURL(tt.key, &tt.uri)
signedURL.timeNow = tt.origTime
if diff := cmp.Diff(signedURL.String(), tt.wantStr); diff != "" {
t.Errorf("signedURL() = %v", diff)
}
signedURL = NewSignedURL(tt.key, &tt.uri)
signedURL.timeNow = tt.origTime
got := signedURL.Sign()
if diff := cmp.Diff(*got, tt.want); diff != "" {
t.Errorf("NewSignedURL() = %s", diff)
}
err := signedURL.Validate()
if (err != nil) != tt.wantErr {
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
}
u, err := url.Parse(signedURL.String())
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(u, got); diff != "" {
t.Errorf("signedURL() = %v", diff)
}
// subsequent string calls shouldn't result in a change
u, err = url.Parse(signedURL.String())
if err != nil {
t.Fatal(err)
}
if diff := cmp.Diff(u, got); diff != "" {
t.Errorf("signedURL() = %v", diff)
}
})
}
}
func TestSignedURL_Validate(t *testing.T) {
t.Parallel()
tests := []struct {
name string
uri url.URL
key string
timeNow func() time.Time
wantErr bool
}{
{"good", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, false},
{"bad key", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "bad-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
{"bad no expiry", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
{"bad issued", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
{"bad signature body", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=^"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
{"bad expired", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(time.Hour) }, true},
{"bad not yet valid", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(-time.Hour) }, true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
out := NewSignedURL(tt.key, &tt.uri)
out.timeNow = tt.timeNow
if err := out.Validate(); (err != nil) != tt.wantErr {
t.Errorf("SignedURL.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

View file

@ -1,15 +1,16 @@
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil" package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
import ( import (
"encoding/base64"
"fmt" "fmt"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
)
"github.com/pomerium/pomerium/internal/cryptutil" const (
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
DefaultLeeway = 1.0 * time.Minute
) )
// StripPort returns a host, without any port number. // StripPort returns a host, without any port number.
@ -66,57 +67,6 @@ func DeepCopy(u *url.URL) (*url.URL, error) {
return ParseAndValidateURL(u.String()) return ParseAndValidateURL(u.String())
} }
var mockNow testTime
// testTime is safe to use concurrently.
type testTime struct {
sync.Mutex
mockNow int64
}
func (tt *testTime) setNow(n int64) {
tt.Lock()
tt.mockNow = n
tt.Unlock()
}
func (tt *testTime) now() int64 {
tt.Lock()
defer tt.Unlock()
return tt.mockNow
}
// timestamp returns the current timestamp, in seconds.
//
// For testing purposes, the function that generates the timestamp can be
// overridden. If not set, it will return time.Now().UTC().Unix().
func timestamp() int64 {
if mockNow.now() == 0 {
return time.Now().UTC().Unix()
}
return mockNow.now()
}
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
// query params, along with a timestamp and an keyed signature.
func SignedRedirectURL(key string, destination, u *url.URL) *url.URL {
now := timestamp()
rawURL := u.String()
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
params.Set("redirect_uri", rawURL)
params.Set("ts", fmt.Sprint(now))
params.Set("sig", hmacURL(key, rawURL, now))
destination.RawQuery = params.Encode()
return destination
}
// hmacURL takes a redirect url string and timestamp and returns the base64
// encoded HMAC result.
func hmacURL(key, data string, timestamp int64) string {
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data, timestamp)), key)
return base64.URLEncoding.EncodeToString(h)
}
// GetAbsoluteURL returns the current handler's absolute url. // GetAbsoluteURL returns the current handler's absolute url.
// https://stackoverflow.com/a/23152483 // https://stackoverflow.com/a/23152483
func GetAbsoluteURL(r *http.Request) *url.URL { func GetAbsoluteURL(r *http.Request) *url.URL {

View file

@ -108,48 +108,6 @@ func TestValidateURL(t *testing.T) {
} }
} }
func TestSignedRedirectURL(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mockedTime int64
key string
destination *url.URL
urlToSign *url.URL
want *url.URL
}{
{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockNow.setNow(tt.mockedTime)
got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("SignedRedirectURL() = diff %v", diff)
}
})
}
}
func Test_timestamp(t *testing.T) {
t.Parallel()
tests := []struct {
name string
dontWant int64
}{
{"if unset should never return", 0},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
mockNow.setNow(tt.dontWant)
if got := timestamp(); got == tt.dontWant {
t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
}
})
}
}
func parseURLHelper(s string) *url.URL { func parseURLHelper(s string) *url.URL {
u, _ := url.Parse(s) u, _ := url.Parse(s)
return u return u

114
proxy/forward_auth.go Normal file
View file

@ -0,0 +1,114 @@
package proxy // import "github.com/pomerium/pomerium/proxy"
import (
"errors"
"fmt"
"net/http"
"net/url"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
)
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
r := httputil.NewRouter()
r.StrictSlash(true)
r.Use(sessions.RetrieveSession(p.sessionStore))
r.Handle("/verify", http.HandlerFunc(p.nginxCallback)).
Queries("uri", "{uri}", urlutil.QuerySessionEncrypted, "", urlutil.QueryRedirectURI, "")
r.Handle("/", http.HandlerFunc(p.postSessionSetNOP)).
Queries("uri", "{uri}",
urlutil.QuerySessionEncrypted, "",
urlutil.QueryRedirectURI, "")
r.Handle("/", http.HandlerFunc(p.traefikCallback)).
HeadersRegexp(httputil.HeaderForwardedURI, urlutil.QuerySessionEncrypted)
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}")
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}")
return r
}
// postSessionSetNOP after successfully setting the
func (p *Proxy) postSessionSetNOP(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
httputil.Redirect(w, r, r.FormValue(urlutil.QueryRedirectURI), http.StatusFound)
}
func (p *Proxy) nginxCallback(w http.ResponseWriter, r *http.Request) {
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusUnauthorized)
}
func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) {
forwardedURL, err := url.Parse(r.Header.Get(httputil.HeaderForwardedURI))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
}
q := forwardedURL.Query()
redirectURLString := q.Get(urlutil.QueryRedirectURI)
encryptedSession := q.Get(urlutil.QuerySessionEncrypted)
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
httputil.Redirect(w, r, redirectURLString, http.StatusFound)
}
// Verify checks a user's credentials for an arbitrary host. If the user
// is properly authenticated and is authorized to access the supplied host,
// a `200` http status code is returned. If the user is not authenticated, they
// will be redirected to the authenticate service to sign in with their identity
// provider. If the user is unauthorized, a `401` error is returned.
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, err))
return
}
s, err := sessions.FromContext(r.Context())
if errors.Is(err, sessions.ErrNoSessionFound) || errors.Is(err, sessions.ErrExpired) {
if verifyOnly {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return
}
authN := *p.authenticateSigninURL
q := authN.Query()
q.Set(urlutil.QueryCallbackURI, uri.String())
q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination
q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience
authN.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
return
} else if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return
}
// depending on the configuration of the fronting proxy, the request Host
// and/or `X-Forwarded-Host` may be untrustd or change so we reverify
// the session's validity against the supplied uri
if err := s.Verify(uri.Hostname()); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return
}
p.addPomeriumHeaders(w, r)
if err := p.authorize(uri.Host, w, r); err != nil {
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "Access to %s is allowed.", uri.Host)
})
}

119
proxy/forward_auth_test.go Normal file
View file

@ -0,0 +1,119 @@
package proxy
import (
"errors"
"net/http"
"net/http/httptest"
"net/url"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/proxy/clients"
"gopkg.in/square/go-jose.v2/jwt"
)
func TestProxy_ForwardAuth(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options config.Options
ctxError error
method string
headers map[string]string
qp map[string]string
requestURI string
verifyURI string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
wantBody string
}{
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
// traefik
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
// nginx
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = tt.cipher
p.sessionStore = tt.sessionStore
p.AuthorizeClient = tt.authorizer
p.UpdateOptions(tt.options)
uri, err := url.Parse(tt.requestURI)
if err != nil {
t.Fatal(err)
}
queryString := uri.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
if tt.verifyURI != "" {
queryString.Set("uri", tt.verifyURI)
}
uri.RawQuery = queryString.Encode()
r := httptest.NewRequest(tt.method, uri.String(), nil)
state, _ := tt.sessionStore.LoadSession(r)
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
router := p.registerFwdAuthHandlers()
router.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}

View file

@ -19,7 +19,6 @@ import (
// registerDashboardHandlers returns the proxy service's ServeMux // registerDashboardHandlers returns the proxy service's ServeMux
func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router { func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
// dashboard subrouter
h := r.PathPrefix(dashboardURL).Subrouter() h := r.PathPrefix(dashboardURL).Subrouter()
// 1. Retrieve the user session and add it to the request context // 1. Retrieve the user session and add it to the request context
h.Use(sessions.RetrieveSession(p.sessionStore)) h.Use(sessions.RetrieveSession(p.sessionStore))
@ -32,19 +31,27 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)), csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)), csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
)) ))
// dashboard endpoints can be used by user's to view, or modify their session
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet) h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost) h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost) h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
// Authenticate service callback handlers and middleware // Authenticate service callback handlers and middleware
// callback used to set route-scoped session and redirect back to destination
// only accept signed requests (hmac) from other trusted pomerium services
c := r.PathPrefix(dashboardURL + "/callback").Subrouter() c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
// only accept payloads that have come from a trusted service (hmac)
c.Use(middleware.ValidateSignature(p.SharedKey)) c.Use(middleware.ValidateSignature(p.SharedKey))
c.HandleFunc("/", p.Callback).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
c.Path("/").HandlerFunc(p.ProgrammaticCallback).Methods(http.MethodGet).
Queries(urlutil.QueryIsProgrammatic, "true")
c.Path("/").HandlerFunc(p.Callback).Methods(http.MethodGet)
// Programmatic API handlers and middleware // Programmatic API handlers and middleware
a := r.PathPrefix(dashboardURL + "/api").Subrouter() a := r.PathPrefix(dashboardURL + "/api").Subrouter()
a.HandleFunc("/v1/login", p.ProgrammaticLogin).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet) // login api handler generates a user-navigable login url to authenticate
a.HandleFunc("/v1/login", p.ProgrammaticLogin).
Queries(urlutil.QueryRedirectURI, "").
Methods(http.MethodGet)
return r return r
} }
@ -52,7 +59,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
// RobotsTxt sets the User-Agent header in the response to be "Disallow" // RobotsTxt sets the User-Agent header in the response to be "Disallow"
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) { func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, "User-agent: *\nDisallow: /") fmt.Fprintf(w, "User-agent: *\nDisallow: /")
} }
@ -62,12 +68,17 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
// the local session state. // the local session state.
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) { func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"} redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" { if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" {
redirectURL = uri redirectURL = uri
} }
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
signoutURL := *p.authenticateSignoutURL
q := signoutURL.Query()
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
signoutURL.RawQuery = q.Encode()
p.sessionStore.ClearSession(w, r) p.sessionStore.ClearSession(w, r)
httputil.Redirect(w, r, uri.String(), http.StatusFound) httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signoutURL).String(), http.StatusFound)
} }
// UserDashboard lets users investigate, and refresh their current session. // UserDashboard lets users investigate, and refresh their current session.
@ -112,110 +123,95 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
// OK to impersonation // OK to impersonation
redirectURL := urlutil.GetAbsoluteURL(r) redirectURL := urlutil.GetAbsoluteURL(r)
redirectURL.Path = dashboardURL // redirect back to the dashboard redirectURL.Path = dashboardURL // redirect back to the dashboard
q := redirectURL.Query() signinURL := *p.authenticateSigninURL
q.Add("impersonate_email", r.FormValue("email")) q := signinURL.Query()
q.Add("impersonate_group", r.FormValue("group")) q.Set(urlutil.QueryRedirectURI, redirectURL.String())
redirectURL.RawQuery = q.Encode() q.Set(urlutil.QueryImpersonateEmail, r.FormValue("email"))
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String() q.Set(urlutil.QueryImpersonateGroups, r.FormValue("group"))
httputil.Redirect(w, r, uri, http.StatusFound) signinURL.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
} }
func (p *Proxy) registerFwdAuthHandlers() http.Handler { // Callback handles the result of a successful call to the authenticate service
r := httputil.NewRouter() // and is responsible setting returned per-route session.
r.StrictSlash(true)
r.Use(sessions.RetrieveSession(p.sessionStore))
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}").Methods(http.MethodGet)
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}").Methods(http.MethodGet)
return r
}
// Verify checks a user's credentials for an arbitrary host. If the user
// is properly authenticated and is authorized to access the supplied host,
// a `200` http status code is returned. If the user is not authenticated, they
// will be redirected to the authenticate service to sign in with their identity
// provider. If the user is unauthorized, a `401` error is returned.
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
if err != nil || uri.String() == "" {
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, nil))
return
}
if err := p.authenticate(verifyOnly, w, r); err != nil {
return
}
if err := p.authorize(uri.Host, w, r); err != nil {
return
}
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK)
fmt.Fprintf(w, fmt.Sprintf("Access to %s is allowed.", uri.Host))
})
}
// Callback takes a `redirect_uri` query param that has been hmac'd by the
// authenticate service. Embedded in the `redirect_uri` are query-params
// that tell this handler how to set the per-route user session.
// Callback is responsible for redirecting the user back to the intended
// destination URL and path, as well as to clean up any additional query params
// added by the authenticate service.
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) { func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
if err != nil { encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return return
} }
q := redirectURL.Query() httputil.Redirect(w, r, redirectURLString, http.StatusFound)
// 1. extract the base64 encoded and encrypted JWT from redirect_uri's query params }
encryptedJWT, err := base64.URLEncoding.DecodeString(q.Get("pomerium_jwt"))
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
return
}
q.Del("pomerium_jwt")
q.Del("impersonate_email")
q.Del("impersonate_group")
// saveCallbackSession takes an encrypted per-route session token, and decrypts
// it using the shared service key, then stores it the local session store.
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
// 1. extract the base64 encoded and encrypted JWT from query params
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
if err != nil {
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
}
// 2. decrypt the JWT using the cipher using the _shared_ secret key // 2. decrypt the JWT using the cipher using the _shared_ secret key
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil) rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err)) return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
return
} }
// 3. Save the decrypted JWT to the session store directly as a string, without resigning // 3. Save the decrypted JWT to the session store directly as a string, without resigning
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil { if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
httputil.ErrorResponse(w, r, err) return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
return
} }
return rawJWT, nil
// if this is a programmatic request, don't strip the tokens before redirect
if redirectURL.Query().Get("pomerium_programmatic_destination_url") != "" {
q.Set("pomerium_jwt", string(rawJWT))
}
redirectURL.RawQuery = q.Encode()
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
} }
// ProgrammaticLogin returns a signed url that can be used to login // ProgrammaticLogin returns a signed url that can be used to login
// using the authenticate service. // using the authenticate service.
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) { func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")) redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
if err != nil { if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err)) httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err))
return return
} }
q := redirectURL.Query() signinURL := *p.authenticateSigninURL
q.Add("pomerium_programmatic_destination_url", urlutil.GetAbsoluteURL(r).String()) callbackURI := urlutil.GetAbsoluteURL(r)
redirectURL.RawQuery = q.Encode() callbackURI.Path = dashboardURL + "/callback/"
response := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String() q := signinURL.Query()
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
q.Set(urlutil.QueryIsProgrammatic, "true")
signinURL.RawQuery = q.Encode()
response := urlutil.NewSignedURL(p.SharedKey, &signinURL).String()
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
w.Write([]byte(response)) w.Write([]byte(response))
} }
// ProgrammaticCallback handles a successful call to the authenticate service.
// In addition to returning the individual route session (JWT) it also returns
// the refresh token.
func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) {
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err))
return
}
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
if err != nil {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
return
}
q := redirectURL.Query()
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken))
redirectURL.RawQuery = q.Encode()
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
}

View file

@ -11,17 +11,22 @@ import (
"testing" "testing"
"time" "time"
"github.com/pomerium/pomerium/internal/cryptutil"
"github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/encoding" "github.com/pomerium/pomerium/internal/encoding"
"github.com/pomerium/pomerium/internal/encoding/mock" "github.com/pomerium/pomerium/internal/encoding/mock"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/sessions" "github.com/pomerium/pomerium/internal/sessions"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/proxy/clients" "github.com/pomerium/pomerium/proxy/clients"
"github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp"
"gopkg.in/square/go-jose.v2/jwt" "gopkg.in/square/go-jose.v2/jwt"
) )
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
func TestProxy_RobotsTxt(t *testing.T) { func TestProxy_RobotsTxt(t *testing.T) {
proxy := Proxy{} proxy := Proxy{}
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil) req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
@ -189,12 +194,12 @@ func TestProxy_SignOut(t *testing.T) {
t.Fatal(err) t.Fatal(err)
} }
postForm := url.Values{} postForm := url.Values{}
postForm.Add("redirect_uri", tt.redirectURL) postForm.Add(urlutil.QueryRedirectURI, tt.redirectURL)
uri := &url.URL{Path: "/"} uri := &url.URL{Path: "/"}
query, _ := url.ParseQuery(uri.RawQuery) query, _ := url.ParseQuery(uri.RawQuery)
if tt.verb == http.MethodGet { if tt.verb == http.MethodGet {
query.Add("redirect_uri", tt.redirectURL) query.Add(urlutil.QueryRedirectURI, tt.redirectURL)
uri.RawQuery = query.Encode() uri.RawQuery = query.Encode()
} }
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode())) r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
@ -217,87 +222,7 @@ func uriParseHelper(s string) *url.URL {
} }
return uri return uri
} }
func TestProxy_VerifyWithMiddleware(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options config.Options
ctxError error
method string
qp string
path string
verifyURI string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
wantBody string
}{
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
{"bad naked domain uri", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad naked domain uri verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri", opts, nil, http.MethodGet, "", "/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"bad empty verification uri verify only", opts, nil, http.MethodGet, "", "/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = tt.cipher
p.sessionStore = tt.sessionStore
p.AuthorizeClient = tt.authorizer
p.UpdateOptions(tt.options)
uri := &url.URL{Path: tt.path}
queryString := uri.Query()
queryString.Set("donstrip", "ok")
if tt.qp != "" {
queryString.Set(tt.qp, "true")
}
if tt.verifyURI != "" {
queryString.Set("uri", tt.verifyURI)
}
uri.RawQuery = queryString.Encode()
r := httptest.NewRequest(tt.method, uri.String(), nil)
state, err := tt.sessionStore.LoadSession(r)
if err != nil {
t.Fatal(err)
}
ctx := r.Context()
ctx = sessions.NewContext(ctx, state, tt.ctxError)
r = r.WithContext(ctx)
r.Header.Set("Authorization", "Bearer blah")
r.Header.Set("Accept", "application/json")
w := httptest.NewRecorder()
router := p.registerFwdAuthHandlers()
router.ServeHTTP(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}
func TestProxy_Callback(t *testing.T) { func TestProxy_Callback(t *testing.T) {
t.Parallel() t.Parallel()
opts := testOptions(t) opts := testOptions(t)
@ -311,7 +236,8 @@ func TestProxy_Callback(t *testing.T) {
host string host string
path string path string
qp map[string]string headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore sessionStore sessions.SessionStore
@ -319,11 +245,12 @@ func TestProxy_Callback(t *testing.T) {
wantStatus int wantStatus int
wantBody string wantBody string
}{ }{
{"good", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_programmatic_destination_url": "ok", "pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""}, {"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError, ""}, {"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""}, {"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -345,13 +272,21 @@ func TestProxy_Callback(t *testing.T) {
uri := &url.URL{Path: "/"} uri := &url.URL{Path: "/"}
if tt.qp != nil { if tt.qp != nil {
qu := uri.Query() qu := uri.Query()
qu.Set("redirect_uri", redirectURI.String()) for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode() uri.RawQuery = qu.Encode()
} }
r := httptest.NewRequest(tt.method, uri.String(), nil) r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json") r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder() w := httptest.NewRecorder()
p.Callback(w, r) p.Callback(w, r)
@ -379,18 +314,20 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
method string method string
scheme string scheme string
host string host string
path string path string
qp map[string]string headers map[string]string
qp map[string]string
wantStatus int wantStatus int
wantBody string wantBody string
}{ }{
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusOK, ""}, {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""}, {"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"}, {"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusMethodNotAllowed, ""}, {"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect uri\"}\n"},
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, ""},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -428,3 +365,83 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
}) })
} }
} }
func TestProxy_ProgrammaticCallback(t *testing.T) {
t.Parallel()
opts := testOptions(t)
tests := []struct {
name string
options config.Options
method string
redirectURI string
headers map[string]string
qp map[string]string
cipher encoding.MarshalUnmarshaler
sessionStore sessions.SessionStore
authorizer clients.Authorizer
wantStatus int
wantBody string
}{
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
p, err := New(tt.options)
if err != nil {
t.Fatal(err)
}
p.encoder = tt.cipher
p.sessionStore = tt.sessionStore
p.AuthorizeClient = tt.authorizer
p.UpdateOptions(tt.options)
redirectURI, _ := url.Parse(tt.redirectURI)
queryString := redirectURI.Query()
for k, v := range tt.qp {
queryString.Set(k, v)
}
redirectURI.RawQuery = queryString.Encode()
uri := &url.URL{Path: "/"}
if tt.qp != nil {
qu := uri.Query()
for k, v := range tt.qp {
qu.Set(k, v)
}
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
uri.RawQuery = qu.Encode()
}
r := httptest.NewRequest(tt.method, uri.String(), nil)
r.Header.Set("Accept", "application/json")
if len(tt.headers) != 0 {
for k, v := range tt.headers {
r.Header.Set(k, v)
}
}
w := httptest.NewRecorder()
p.ProgrammaticCallback(w, r)
if status := w.Code; status != tt.wantStatus {
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
t.Errorf("\n%+v", w.Body.String())
}
if tt.wantBody != "" {
body := w.Body.String()
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
t.Errorf("wrong body\n%s", diff)
}
}
})
}
}

View file

@ -30,39 +30,37 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession") ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
defer span.End() defer span.End()
if err := p.authenticate(false, w, r.WithContext(ctx)); err != nil {
if s, err := sessions.FromContext(ctx); err != nil {
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session") log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
p.sessionStore.ClearSession(w, r)
if s != nil && s.Programmatic {
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
return
}
signinURL := *p.authenticateSigninURL
q := signinURL.Query()
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
signinURL.RawQuery = q.Encode()
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
return return
} }
p.addPomeriumHeaders(w, r)
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
}) })
} }
// authenticate authenticates a user and sets an appropriate response type, func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
// redirect to authenticate or error handler depending on if err on failure is set.
func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.Request) error {
s, err := sessions.FromContext(r.Context()) s, err := sessions.FromContext(r.Context())
if err != nil { if err == nil && s != nil {
p.sessionStore.ClearSession(w, r) r.Header.Set(HeaderUserID, s.Subject)
r.Header.Set(HeaderEmail, s.RequestEmail())
if errOnFailure || (s != nil && s.Programmatic) { r.Header.Set(HeaderGroups, s.RequestGroups())
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err)) w.Header().Set(HeaderUserID, s.Subject)
return err w.Header().Set(HeaderEmail, s.RequestEmail())
} w.Header().Set(HeaderGroups, s.RequestGroups())
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
httputil.Redirect(w, r, uri.String(), http.StatusFound)
return err
} }
// add pomerium's headers to the downstream request
r.Header.Set(HeaderUserID, s.Subject)
r.Header.Set(HeaderEmail, s.RequestEmail())
r.Header.Set(HeaderGroups, s.RequestGroups())
// and upstream
w.Header().Set(HeaderUserID, s.Subject)
w.Header().Set(HeaderEmail, s.RequestEmail())
w.Header().Set(HeaderGroups, s.RequestGroups())
return nil
} }
// AuthorizeSession is middleware to enforce a user is authorized for a request // AuthorizeSession is middleware to enforce a user is authorized for a request

View file

@ -20,7 +20,6 @@ func TestProxy_AuthenticateSession(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK)) fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
@ -70,7 +69,6 @@ func TestProxy_AuthorizeSession(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK)) fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
@ -131,7 +129,6 @@ func TestProxy_SignRequest(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
fmt.Fprint(w, http.StatusText(http.StatusOK)) fmt.Fprint(w, http.StatusText(http.StatusOK))
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
}) })
@ -184,7 +181,6 @@ func TestProxy_SetResponseHeaders(t *testing.T) {
t.Parallel() t.Parallel()
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/plain; charset=utf-8") w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("X-Content-Type-Options", "nosniff")
var sb strings.Builder var sb strings.Builder
for k, v := range r.Header { for k, v := range r.Header {
k = strings.ToLower(k) k = strings.ToLower(k)

View file

@ -187,7 +187,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
if opts.ForwardAuthURL != nil { if opts.ForwardAuthURL != nil {
// if a forward auth endpoint is set, register its handlers // if a forward auth endpoint is set, register its handlers
h := r.Host(opts.ForwardAuthURL.Host).Subrouter() h := r.Host(opts.ForwardAuthURL.Hostname()).Subrouter()
h.PathPrefix("/").Handler(p.registerFwdAuthHandlers()) h.PathPrefix("/").Handler(p.registerFwdAuthHandlers())
} }

View file

@ -82,7 +82,9 @@ def main():
# initial login to make sure we have our credential # initial login to make sure we have our credential
if args.login: if args.login:
dst = urllib.parse.urlparse(args.dst) dst = urllib.parse.urlparse(args.dst)
query_params = {"redirect_uri": "http://{}:{}".format(args.server, args.port)} query_params = {
"pomerium_redirect_uri": "http://{}:{}".format(args.server, args.port)
}
enc_query_params = urllib.parse.urlencode(query_params) enc_query_params = urllib.parse.urlencode(query_params)
dst_login = "{}://{}{}?{}".format( dst_login = "{}://{}{}?{}".format(
dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params, dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,