mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-12 16:47:41 +02:00
proxy: fix forward auth, request signing
Signed-off-by: Bobby DeSimone <bobbydesimone@gmail.com>
This commit is contained in:
parent
ec9607d1d5
commit
0f6a9d7f1d
32 changed files with 928 additions and 522 deletions
|
@ -24,10 +24,10 @@ import (
|
||||||
// CSPHeaders are the content security headers added to the service's handlers
|
// CSPHeaders are the content security headers added to the service's handlers
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
|
||||||
var CSPHeaders = map[string]string{
|
var CSPHeaders = map[string]string{
|
||||||
"Content-Security-Policy": "default-src 'none'; style-src 'self'" +
|
"Content-Security-Policy": "default-src 'none'; style-src " +
|
||||||
" 'sha256-z9MsgkMbQjRSLxzAfN55jB3a9pP0PQ4OHFH8b4iDP6s=' " +
|
"'sha256-spMkVDoBBY86p0RC1fBYwdnGyMypJM8eG57+p3VASyk=' " +
|
||||||
" 'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " +
|
"'sha256-qnVkQSG7pWu17hBhIw0kCpfEB3XGvt0mNRa6+uM6OUU=' " +
|
||||||
" 'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0='; " +
|
"'sha256-qOdRsNZhtR+htazbcy7guQl3Cn1cqOw1FcE4d3llae0=';" +
|
||||||
"img-src 'self';",
|
"img-src 'self';",
|
||||||
"Referrer-Policy": "Same-origin",
|
"Referrer-Policy": "Same-origin",
|
||||||
}
|
}
|
||||||
|
@ -54,7 +54,8 @@ func (a *Authenticate) Handler() http.Handler {
|
||||||
v := r.PathPrefix("/.pomerium").Subrouter()
|
v := r.PathPrefix("/.pomerium").Subrouter()
|
||||||
c := cors.New(cors.Options{
|
c := cors.New(cors.Options{
|
||||||
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
AllowOriginRequestFunc: func(r *http.Request, _ string) bool {
|
||||||
return middleware.ValidateRedirectURI(r, a.sharedKey)
|
err := middleware.ValidateRequestURL(r, a.sharedKey)
|
||||||
|
return err == nil
|
||||||
},
|
},
|
||||||
AllowCredentials: true,
|
AllowCredentials: true,
|
||||||
AllowedHeaders: []string{"*"},
|
AllowedHeaders: []string{"*"},
|
||||||
|
@ -111,71 +112,84 @@ func (a *Authenticate) refresh(w http.ResponseWriter, r *http.Request, s *sessio
|
||||||
// RobotsTxt handles the /robots.txt route.
|
// RobotsTxt handles the /robots.txt route.
|
||||||
func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) RobotsTxt(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SignIn handles to authenticating a user.
|
// SignIn handles to authenticating a user.
|
||||||
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) {
|
||||||
// grab and parse our redirect_uri
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// create a clone of the redirect URI, unless this is a programmatic request
|
|
||||||
// in which case we will redirect back to proxy's callback endpoint
|
|
||||||
callbackURL, _ := urlutil.DeepCopy(redirectURL)
|
|
||||||
|
|
||||||
q := redirectURL.Query()
|
jwtAudience := []string{a.RedirectURL.Hostname(), redirectURL.Hostname()}
|
||||||
|
|
||||||
if q.Get("pomerium_programmatic_destination_url") != "" {
|
var callbackURL *url.URL
|
||||||
callbackURL, err = urlutil.ParseAndValidateURL(q.Get("pomerium_programmatic_destination_url"))
|
// if the callback is explicitly set, set it and add an additional audience
|
||||||
|
if callbackStr := r.FormValue(urlutil.QueryCallbackURI); callbackStr != "" {
|
||||||
|
callbackURL, err = urlutil.ParseAndValidateURL(callbackStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
jwtAudience = append(jwtAudience, callbackURL.Hostname())
|
||||||
|
} else {
|
||||||
|
// otherwise, assume callback is the same host as redirect
|
||||||
|
callbackURL, _ = urlutil.DeepCopy(redirectURL)
|
||||||
|
callbackURL.Path = "/.pomerium/callback/"
|
||||||
|
callbackURL.RawQuery = ""
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// add an additional claim for the forward-auth host, if set
|
||||||
|
if fwdAuth := r.FormValue(urlutil.QueryForwardAuth); fwdAuth != "" {
|
||||||
|
jwtAudience = append(jwtAudience, fwdAuth)
|
||||||
|
}
|
||||||
|
|
||||||
s, err := sessions.FromContext(r.Context())
|
s, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.SetImpersonation(q.Get("impersonate_email"), q.Get("impersonate_group"))
|
|
||||||
|
|
||||||
newSession := s.NewSession(a.RedirectURL.Host, []string{a.RedirectURL.Host, callbackURL.Host})
|
s.SetImpersonation(r.FormValue(urlutil.QueryImpersonateEmail), r.FormValue(urlutil.QueryImpersonateGroups))
|
||||||
if q.Get("pomerium_programmatic_destination_url") != "" {
|
|
||||||
|
newSession := s.NewSession(a.RedirectURL.Host, jwtAudience)
|
||||||
|
|
||||||
|
callbackParams := callbackURL.Query()
|
||||||
|
|
||||||
|
if r.FormValue(urlutil.QueryIsProgrammatic) == "true" {
|
||||||
newSession.Programmatic = true
|
newSession.Programmatic = true
|
||||||
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
encSession, err := a.encryptedEncoder.Marshal(newSession)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
q.Set("pomerium_refresh_token", string(encSession))
|
callbackParams.Set(urlutil.QueryRefreshToken, string(encSession))
|
||||||
|
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// sign the route session, as a JWT
|
// sign the route session, as a JWT
|
||||||
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
|
signedJWT, err := a.sharedEncoder.Marshal(newSession.RouteSession(DefaultSessionDuration))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// encrypt our route-based token JWT avoiding any accidental logging
|
// encrypt our route-based token JWT avoiding any accidental logging
|
||||||
encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil)
|
encryptedJWT := cryptutil.Encrypt(a.sharedCipher, signedJWT, nil)
|
||||||
// base64 our encrypted payload for URL-friendlyness
|
// base64 our encrypted payload for URL-friendlyness
|
||||||
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
encodedJWT := base64.URLEncoding.EncodeToString(encryptedJWT)
|
||||||
|
|
||||||
// add our encoded and encrypted route-session JWT to a query param
|
// add our encoded and encrypted route-session JWT to a query param
|
||||||
q.Set("pomerium_jwt", encodedJWT)
|
callbackParams.Set(urlutil.QuerySessionEncrypted, encodedJWT)
|
||||||
|
callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||||
redirectURL.RawQuery = q.Encode()
|
callbackURL.RawQuery = callbackParams.Encode()
|
||||||
|
|
||||||
callbackURL.Path = "/.pomerium/callback"
|
|
||||||
|
|
||||||
// build our hmac-d redirect URL with our session, pointing back to the
|
// build our hmac-d redirect URL with our session, pointing back to the
|
||||||
// proxy's callback URL which is responsible for setting our new route-session
|
// proxy's callback URL which is responsible for setting our new route-session
|
||||||
uri := urlutil.SignedRedirectURL(a.sharedKey, callbackURL, redirectURL)
|
uri := urlutil.NewSignedURL(a.sharedKey, callbackURL)
|
||||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -193,7 +207,7 @@ func (a *Authenticate) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("could not revoke user session", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
|
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"github.com/pomerium/pomerium/internal/identity"
|
"github.com/pomerium/pomerium/internal/identity"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
"github.com/pomerium/pomerium/internal/templates"
|
"github.com/pomerium/pomerium/internal/templates"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"golang.org/x/crypto/chacha20poly1305"
|
"golang.org/x/crypto/chacha20poly1305"
|
||||||
|
@ -108,14 +109,17 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
encoder encoding.MarshalUnmarshaler
|
encoder encoding.MarshalUnmarshaler
|
||||||
wantCode int
|
wantCode int
|
||||||
}{
|
}{
|
||||||
{"good", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"session not valid", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"session not valid", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(-10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"bad redirect uri query", "", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"bad redirect uri query", "", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "^^^"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"bad marshal", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"bad marshal", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"session error", "https", "corp.example.example", map[string]string{"state": "example"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"session error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{LoadError: errors.New("error")}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
{"encrypted encoder error", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
{"encrypted encoder error", "https", "corp.example.example", map[string]string{urlutil.QueryRedirectURI: "https://dst.some.example/", urlutil.QueryCallbackURI: "https://some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{MarshalError: errors.New("error")}, http.StatusBadRequest},
|
||||||
{"good with different programmatic redirect", "https", "corp.example.example", map[string]string{"state": "example", "pomerium_programmatic_destination_url": "some.example"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
{"good with callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "https://some.example/", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
|
{"bad callback uri set", "https", "corp.example.example", map[string]string{urlutil.QueryCallbackURI: "^", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusBadRequest},
|
||||||
|
{"good programmatic request", "https", "corp.example.example", map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
|
{"good additional audience", "https", "corp.example.example", map[string]string{urlutil.QueryForwardAuth: "x.y.z", urlutil.QueryRedirectURI: "https://dst.some.example/"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@pomerium.io", AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Second)}}}, identity.MockProvider{}, &mock.Encoder{}, http.StatusFound},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -139,8 +143,7 @@ func TestAuthenticate_SignIn(t *testing.T) {
|
||||||
queryString.Set(k, v)
|
queryString.Set(k, v)
|
||||||
}
|
}
|
||||||
uri.RawQuery = queryString.Encode()
|
uri.RawQuery = queryString.Encode()
|
||||||
|
r := httptest.NewRequest(http.MethodGet, uri.String(), nil)
|
||||||
r := httptest.NewRequest(http.MethodGet, "/?redirect_uri="+uri.String(), nil)
|
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
state, err := tt.session.LoadSession(r)
|
state, err := tt.session.LoadSession(r)
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
@ -195,7 +198,7 @@ func TestAuthenticate_SignOut(t *testing.T) {
|
||||||
params, _ := url.ParseQuery(u.RawQuery)
|
params, _ := url.ParseQuery(u.RawQuery)
|
||||||
params.Add("sig", tt.sig)
|
params.Add("sig", tt.sig)
|
||||||
params.Add("ts", tt.ts)
|
params.Add("ts", tt.ts)
|
||||||
params.Add("redirect_uri", tt.redirectURL)
|
params.Add(urlutil.QueryRedirectURI, tt.redirectURL)
|
||||||
u.RawQuery = params.Encode()
|
u.RawQuery = params.Encode()
|
||||||
r := httptest.NewRequest(tt.method, u.String(), nil)
|
r := httptest.NewRequest(tt.method, u.String(), nil)
|
||||||
state, _ := tt.sessionStore.LoadSession(r)
|
state, _ := tt.sessionStore.LoadSession(r)
|
||||||
|
@ -307,24 +310,26 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
fmt.Fprintln(w, "RVSI FILIVS CAISAR")
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
headers map[string]string
|
||||||
|
|
||||||
session sessions.SessionStore
|
session sessions.SessionStore
|
||||||
ctxError error
|
ctxError error
|
||||||
provider identity.Authenticator
|
provider identity.Authenticator
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
}{
|
}{
|
||||||
{"good", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
{"good", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, nil, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusOK},
|
||||||
{"invalid session", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
{"invalid session", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, errors.New("hi"), identity.MockProvider{}, http.StatusFound},
|
||||||
{"good refresh expired", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
{"good refresh expired", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||||
{"expired,refresh error", &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
{"expired,refresh error", nil, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusFound},
|
||||||
{"expired,save error", &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
{"expired,save error", nil, &sessions.MockSessionStore{SaveError: errors.New("error"), Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshResponse: sessions.State{AccessToken: &oauth2.Token{Expiry: time.Now().Add(10 * time.Minute)}}}, http.StatusFound},
|
||||||
|
{"expired XHR,refresh error", map[string]string{"X-Requested-With": "XmlHttpRequest"}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, sessions.ErrExpired, identity.MockProvider{RefreshError: errors.New("error")}, http.StatusUnauthorized},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -347,7 +352,11 @@ func TestAuthenticate_SessionValidatorMiddleware(t *testing.T) {
|
||||||
r = r.WithContext(ctx)
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
if len(tt.headers) != 0 {
|
||||||
|
for k, v := range tt.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
|
|
||||||
got := a.VerifySession(fn)
|
got := a.VerifySession(fn)
|
||||||
|
|
|
@ -161,7 +161,7 @@ func newGlobalRouter(o *config.Options) *mux.Router {
|
||||||
if len(o.Headers) != 0 {
|
if len(o.Headers) != 0 {
|
||||||
mux.Use(middleware.SetHeaders(o.Headers))
|
mux.Use(middleware.SetHeaders(o.Headers))
|
||||||
}
|
}
|
||||||
mux.Use(log.ForwardedAddrHandler("fwd_ip"))
|
mux.Use(log.HeadersHandler(httputil.HeadersXForwarded))
|
||||||
mux.Use(log.RemoteAddrHandler("ip"))
|
mux.Use(log.RemoteAddrHandler("ip"))
|
||||||
mux.Use(log.UserAgentHandler("user_agent"))
|
mux.Use(log.UserAgentHandler("user_agent"))
|
||||||
mux.Use(log.RefererHandler("referer"))
|
mux.Use(log.RefererHandler("referer"))
|
||||||
|
|
|
@ -189,7 +189,6 @@ var defaultOptions = Options{
|
||||||
CookieName: "_pomerium",
|
CookieName: "_pomerium",
|
||||||
DefaultUpstreamTimeout: 30 * time.Second,
|
DefaultUpstreamTimeout: 30 * time.Second,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"X-Content-Type-Options": "nosniff",
|
|
||||||
"X-Frame-Options": "SAMEORIGIN",
|
"X-Frame-Options": "SAMEORIGIN",
|
||||||
"X-XSS-Protection": "1; mode=block",
|
"X-XSS-Protection": "1; mode=block",
|
||||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||||
|
|
|
@ -226,7 +226,6 @@ func TestOptionsFromViper(t *testing.T) {
|
||||||
CookieHTTPOnly: true,
|
CookieHTTPOnly: true,
|
||||||
Headers: map[string]string{
|
Headers: map[string]string{
|
||||||
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
"Strict-Transport-Security": "max-age=31536000; includeSubDomains; preload",
|
||||||
"X-Content-Type-Options": "nosniff",
|
|
||||||
"X-Frame-Options": "SAMEORIGIN",
|
"X-Frame-Options": "SAMEORIGIN",
|
||||||
"X-XSS-Protection": "1; mode=block",
|
"X-XSS-Protection": "1; mode=block",
|
||||||
}},
|
}},
|
||||||
|
|
|
@ -20,7 +20,7 @@ For example:
|
||||||
```bash
|
```bash
|
||||||
$ curl "https://httpbin.example.com/.pomerium/api/v1/login?redirect_uri=http://localhost:8000"
|
$ curl "https://httpbin.example.com/.pomerium/api/v1/login?redirect_uri=http://localhost:8000"
|
||||||
|
|
||||||
https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_programmatic_destination_url%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981
|
https://authenticate.example.com/.pomerium/sign_in?redirect_uri=http%3A%2F%2Flocalhost%3Fpomerium_callback_uri%3Dhttps%253A%252F%252Fhttpbin.corp.example%252F.pomerium%252Fapi%252Fv1%252Flogin%253Fredirect_uri%253Dhttp%253A%252F%252Flocalhost&sig=hsLuzJctmgsN4kbMeQL16fe_FahjDBEcX0_kPYfg8bs%3D&ts=1573262981
|
||||||
```
|
```
|
||||||
|
|
||||||
### Callback handler
|
### Callback handler
|
||||||
|
|
|
@ -21,5 +21,5 @@ spec:
|
||||||
paths:
|
paths:
|
||||||
- path: /
|
- path: /
|
||||||
backend:
|
backend:
|
||||||
serviceName: dashboard-kubernetes-dashboard
|
serviceName: helm-dashboard-kubernetes-dashboard
|
||||||
servicePort: https
|
servicePort: https
|
||||||
|
|
5
go.sum
5
go.sum
|
@ -176,8 +176,6 @@ github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFR
|
||||||
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
github.com/rs/cors v1.7.0 h1:+88SsELBHx5r+hZ8TCkggzSstaWNbDvThkVK8H6f9ik=
|
||||||
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
github.com/rs/cors v1.7.0/go.mod h1:gFx+x8UowdsKA9AchylcLynDq+nNFfI8FkUZdN/jGCU=
|
||||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||||
github.com/rs/zerolog v1.14.3 h1:4EGfSkR2hJDB0s3oFfrlPqjU1e4WLncergLil3nEKW0=
|
|
||||||
github.com/rs/zerolog v1.14.3/go.mod h1:3WXPzbXEEliJ+a6UFE4vhIxV8qR1EML6ngzP9ug4eYg=
|
|
||||||
github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q=
|
github.com/rs/zerolog v1.16.0 h1:AaELmZdcJHT8m6oZ5py4213cdFK8XGXkB3dFdAQ+P7Q=
|
||||||
github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
|
github.com/rs/zerolog v1.16.0/go.mod h1:9nvC1axdVrAHcu/s9taAVfBuIdTZLVQmKQyvrUjF5+I=
|
||||||
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo=
|
||||||
|
@ -221,8 +219,6 @@ golang.org/x/crypto v0.0.0-20180904163835-0709b304e793/go.mod h1:6SG95UA2DQfeDnf
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf h1:fnPsqIDRbCSgumaMCRpoIoF2s4qxv0xSSS0BVZUE/ss=
|
|
||||||
golang.org/x/crypto v0.0.0-20191029031824-8986dd9e96cf/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
|
||||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
|
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4 h1:PDpCLFAH/YIX0QpHPf2eO7L4rC2OOirBrKtXTLLiNTY=
|
||||||
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
golang.org/x/crypto v0.0.0-20191106202628-ed6320f186d4/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
|
@ -306,7 +302,6 @@ golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3
|
||||||
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
golang.org/x/tools v0.0.0-20190312151545-0bb0c0a6e846/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
golang.org/x/tools v0.0.0-20190312170243-e65039ee4138/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
|
||||||
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
golang.org/x/tools v0.0.0-20190506145303-2d16b83fe98c/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
golang.org/x/tools v0.0.0-20190606124116-d0a3d012864b/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
|
||||||
|
|
||||||
const (
|
|
||||||
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
|
||||||
// as opposed to the downstream application and can be used to distinguish
|
|
||||||
// between an application error, and a pomerium related error when debugging.
|
|
||||||
// Especially useful when working with single page apps (SPA).
|
|
||||||
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
|
|
||||||
)
|
|
|
@ -79,6 +79,8 @@ func ErrorResponse(w http.ResponseWriter, r *http.Request, e error) {
|
||||||
writeJSONResponse(w, statusCode, response)
|
writeJSONResponse(w, statusCode, response)
|
||||||
} else {
|
} else {
|
||||||
w.WriteHeader(statusCode)
|
w.WriteHeader(statusCode)
|
||||||
|
w.Header().Set("Content-Type", "text/html")
|
||||||
|
|
||||||
t := struct {
|
t := struct {
|
||||||
Code int
|
Code int
|
||||||
Title string
|
Title string
|
||||||
|
|
57
internal/httputil/headers.go
Normal file
57
internal/httputil/headers.go
Normal file
|
@ -0,0 +1,57 @@
|
||||||
|
package httputil // import "github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// HeaderPomeriumResponse is set when pomerium itself creates a response,
|
||||||
|
// as opposed to the downstream application and can be used to distinguish
|
||||||
|
// between an application error, and a pomerium related error when debugging.
|
||||||
|
// Especially useful when working with single page apps (SPA).
|
||||||
|
HeaderPomeriumResponse = "x-pomerium-intercepted-response"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeadersContentSecurityPolicy are the content security headers added to the service's handlers
|
||||||
|
// by default includes profile photo exceptions for supported identity providers.
|
||||||
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Security-Policy/script-src
|
||||||
|
var HeadersContentSecurityPolicy = map[string]string{
|
||||||
|
"Content-Security-Policy": "default-src 'none'; style-src 'self'; img-src *;",
|
||||||
|
"Referrer-Policy": "Same-origin",
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward headers contains information from the client-facing side of proxy
|
||||||
|
// servers that is altered or lost when a proxy is involved in the path of the
|
||||||
|
// request.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc7239
|
||||||
|
// https://en.wikipedia.org/wiki/X-Forwarded-For
|
||||||
|
const (
|
||||||
|
HeaderForwardedFor = "X-Forwarded-For"
|
||||||
|
HeaderForwardedHost = "X-Forwarded-Host"
|
||||||
|
HeaderForwardedMethod = "X-Forwarded-Method" // traefik
|
||||||
|
HeaderForwardedPort = "X-Forwarded-Port"
|
||||||
|
HeaderForwardedProto = "X-Forwarded-Proto"
|
||||||
|
HeaderForwardedServer = "X-Forwarded-Server"
|
||||||
|
HeaderForwardedURI = "X-Forwarded-Uri" // traefik
|
||||||
|
HeaderOriginalMethod = "X-Original-Method" // nginx
|
||||||
|
HeaderOriginalURL = "X-Original-Url" // nginx
|
||||||
|
HeaderRealIP = "X-Real-Ip"
|
||||||
|
HeaderSentFrom = "X-Sent-From"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HeadersXForwarded is the slice of the header keys used to contain information
|
||||||
|
// from the client-facing side of proxy servers that is altered or lost when a
|
||||||
|
// proxy is involved in the path of the request.
|
||||||
|
//
|
||||||
|
// https://tools.ietf.org/html/rfc7239
|
||||||
|
// https://en.wikipedia.org/wiki/X-Forwarded-For
|
||||||
|
var HeadersXForwarded = []string{
|
||||||
|
HeaderForwardedFor,
|
||||||
|
HeaderForwardedHost,
|
||||||
|
HeaderForwardedMethod,
|
||||||
|
HeaderForwardedPort,
|
||||||
|
HeaderForwardedProto,
|
||||||
|
HeaderForwardedServer,
|
||||||
|
HeaderForwardedURI,
|
||||||
|
HeaderOriginalMethod,
|
||||||
|
HeaderOriginalURL,
|
||||||
|
HeaderRealIP,
|
||||||
|
HeaderSentFrom,
|
||||||
|
}
|
|
@ -113,7 +113,7 @@ func (p *Provider) Authenticate(ctx context.Context, code string) (*sessions.Sta
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Host)
|
s, err := sessions.NewStateFromTokens(idToken, oauth2Token, p.RedirectURL.Hostname())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,6 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
"github.com/pomerium/pomerium/internal/middleware/responsewriter"
|
||||||
|
@ -172,20 +171,21 @@ func AccessHandler(f func(r *http.Request, status, size int, duration time.Durat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ForwardedAddrHandler returns the client IP address from a request. If a
|
// HeadersHandler adds the provided set of header keys to the log context.
|
||||||
// request goes through multiple proxies, the IP addresses of each successive
|
//
|
||||||
// proxy is listed. This means, the right-most IP address is the IP address of
|
// https://tools.ietf.org/html/rfc7239
|
||||||
// the most recent proxy and the left-most IP address is the IP address of the
|
// https://en.wikipedia.org/wiki/X-Forwarded-For
|
||||||
// originating client.
|
|
||||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
|
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/X-Forwarded-For
|
||||||
func ForwardedAddrHandler(fieldKey string) func(next http.Handler) http.Handler {
|
func HeadersHandler(headers []string) func(next http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
if ra := r.Header.Get("X-Forwarded-For"); ra != "" {
|
for _, key := range headers {
|
||||||
log := zerolog.Ctx(r.Context())
|
if values := r.Header[key]; len(values) != 0 {
|
||||||
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
log := zerolog.Ctx(r.Context())
|
||||||
return c.Strs(fieldKey, strings.Split(ra, ","))
|
log.UpdateContext(func(c zerolog.Context) zerolog.Context {
|
||||||
})
|
return c.Strs(key, values)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
|
|
|
@ -253,20 +253,20 @@ func BenchmarkDataRace(b *testing.B) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestForwardedAddrHandler(t *testing.T) {
|
func TestLogHeadersHandler(t *testing.T) {
|
||||||
out := &bytes.Buffer{}
|
out := &bytes.Buffer{}
|
||||||
|
|
||||||
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
r := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||||
|
|
||||||
r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3")
|
r.Header.Set("X-Forwarded-For", "proxy1,proxy2,proxy3")
|
||||||
|
|
||||||
h := ForwardedAddrHandler("fwd_ip")(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
h := HeadersHandler([]string{"X-Forwarded-For"})(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
l := FromRequest(r)
|
l := FromRequest(r)
|
||||||
l.Log().Msg("")
|
l.Log().Msg("")
|
||||||
}))
|
}))
|
||||||
h = NewHandler(zerolog.New(out))(h)
|
h = NewHandler(zerolog.New(out))(h)
|
||||||
h.ServeHTTP(nil, r)
|
h.ServeHTTP(nil, r)
|
||||||
if want, got := `{"fwd_ip":["proxy1","proxy2","proxy3"]}`+"\n", decodeIfBinary(out); want != got {
|
if want, got := `{"X-Forwarded-For":["proxy1,proxy2,proxy3"]}`+"\n", decodeIfBinary(out); want != got {
|
||||||
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
t.Errorf("Invalid log output, got: %s, want: %s", got, want)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,6 @@ func TestCorsBypass(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
|
|
@ -1,13 +1,10 @@
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
"github.com/pomerium/pomerium/internal/telemetry/trace"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
@ -34,8 +31,8 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
ctx, span := trace.StartSpan(r.Context(), "middleware.ValidateSignature")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
if !ValidateRedirectURI(r, sharedSecret) {
|
if err := ValidateRequestURL(r, sharedSecret); err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, nil))
|
httputil.ErrorResponse(w, r, httputil.Error("invalid signature", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
|
@ -43,27 +40,24 @@ func ValidateSignature(sharedSecret string) func(next http.Handler) http.Handler
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidateRedirectURI takes a request and parses `redirect_uri`, `sig`, `ts`
|
// ValidateRequestURL validates the current absolute request URL was signed
|
||||||
// and validates the supplied signature (`sig`)'s HMAC for validity.
|
// by a given shared key.
|
||||||
func ValidateRedirectURI(r *http.Request, key string) bool {
|
func ValidateRequestURL(r *http.Request, key string) error {
|
||||||
return ValidSignature(
|
return urlutil.NewSignedURL(key, urlutil.GetAbsoluteURL(r)).Validate()
|
||||||
r.FormValue("redirect_uri"),
|
|
||||||
r.FormValue("sig"),
|
|
||||||
r.FormValue("ts"),
|
|
||||||
key)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Healthcheck endpoint middleware useful to setting up a path like
|
// Healthcheck endpoint middleware useful to setting up a path like
|
||||||
// `/ping` that load balancers or uptime testing external services
|
// `/ping` that load balancers or uptime testing external services
|
||||||
// can make a request before hitting any routes. It's also convenient
|
// can make a request before hitting any routes. It's also convenient
|
||||||
// to place this above ACL middlewares as well.
|
// to place this above ACL middlewares as well.
|
||||||
|
//
|
||||||
|
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html
|
||||||
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
||||||
f := func(next http.Handler) http.Handler {
|
f := func(next http.Handler) http.Handler {
|
||||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
|
ctx, span := trace.StartSpan(r.Context(), "middleware.Healthcheck")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
if strings.EqualFold(r.URL.Path, endpoint) {
|
if strings.EqualFold(r.URL.Path, endpoint) {
|
||||||
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html
|
|
||||||
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
if r.Method != http.MethodGet && r.Method != http.MethodHead {
|
||||||
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
||||||
return
|
return
|
||||||
|
@ -82,26 +76,6 @@ func Healthcheck(endpoint, msg string) func(http.Handler) http.Handler {
|
||||||
return f
|
return f
|
||||||
}
|
}
|
||||||
|
|
||||||
// ValidSignature checks to see if a signature is valid. Compares hmac of
|
|
||||||
// redirect uri, timestamp, and secret and signature.
|
|
||||||
func ValidSignature(redirectURI, sigVal, timestamp, secret string) bool {
|
|
||||||
if redirectURI == "" || sigVal == "" || timestamp == "" || secret == "" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
_, err := urlutil.ParseAndValidateURL(redirectURI)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
requestSig, err := base64.URLEncoding.DecodeString(sigVal)
|
|
||||||
if err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if err := cryptutil.ValidTimestamp(timestamp); err != nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return cryptutil.CheckHMAC([]byte(fmt.Sprint(redirectURI, timestamp)), requestSig, secret)
|
|
||||||
}
|
|
||||||
|
|
||||||
// StripCookie strips the cookie from the downstram request.
|
// StripCookie strips the cookie from the downstram request.
|
||||||
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
|
func StripCookie(cookieName string) func(next http.Handler) http.Handler {
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package middleware // import "github.com/pomerium/pomerium/internal/middleware"
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httptest"
|
"net/http/httptest"
|
||||||
|
@ -9,47 +8,10 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
func hmacHelperFunc(rawRedirect string, timestamp time.Time, secret string) []byte {
|
|
||||||
data := []byte(fmt.Sprint(rawRedirect, timestamp.Unix()))
|
|
||||||
return cryptutil.GenerateHMAC(data, secret)
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_ValidSignature(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
goodURL := "https://example.com/redirect"
|
|
||||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
|
||||||
now := fmt.Sprint(time.Now().Unix())
|
|
||||||
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
|
|
||||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
|
||||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
redirectURI string
|
|
||||||
sigVal string
|
|
||||||
timestamp string
|
|
||||||
secret string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{"good signature", goodURL, sig, now, secretA, true},
|
|
||||||
{"empty redirect url", "", sig, now, secretA, false},
|
|
||||||
{"bad redirect url", "https://google.com^", sig, now, secretA, false},
|
|
||||||
{"malformed signature", goodURL, sig + "^", now, "&*&@**($&#(", false},
|
|
||||||
{"malformed timestamp", goodURL, sig, now + "^", secretA, false},
|
|
||||||
{"stale timestamp", goodURL, sig, staleTime, secretA, false},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := ValidSignature(tt.redirectURI, tt.sigVal, tt.timestamp, tt.secret); got != tt.want {
|
|
||||||
t.Errorf("ValidSignature() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetHeaders(t *testing.T) {
|
func TestSetHeaders(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
@ -79,56 +41,6 @@ func TestSetHeaders(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestValidateSignature(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
secretA := "41aOD7VNtQ1/KZDCGrkYpaHwB50JC1y6BDs2KPRVd2A="
|
|
||||||
now := fmt.Sprint(time.Now().Unix())
|
|
||||||
goodURL := "https://example.com/redirect"
|
|
||||||
rawSig := hmacHelperFunc(goodURL, time.Now(), secretA)
|
|
||||||
sig := base64.URLEncoding.EncodeToString(rawSig)
|
|
||||||
staleTime := fmt.Sprint(time.Now().Add(-6 * time.Minute).Unix())
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
sharedSecret string
|
|
||||||
redirectURI string
|
|
||||||
sig string
|
|
||||||
ts string
|
|
||||||
status int
|
|
||||||
}{
|
|
||||||
{"valid signature", secretA, goodURL, sig, now, http.StatusOK},
|
|
||||||
{"stale signature", secretA, goodURL, sig, staleTime, http.StatusBadRequest},
|
|
||||||
{"malformed", secretA, goodURL, "%zzzzz", now, http.StatusBadRequest},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
v := url.Values{}
|
|
||||||
v.Set("redirect_uri", tt.redirectURI)
|
|
||||||
v.Set("ts", tt.ts)
|
|
||||||
v.Set("sig", tt.sig)
|
|
||||||
|
|
||||||
req := &http.Request{
|
|
||||||
Method: http.MethodGet,
|
|
||||||
URL: &url.URL{RawQuery: v.Encode()}}
|
|
||||||
if tt.name == "malformed" {
|
|
||||||
req.URL.RawQuery = "sig=%zzzzz"
|
|
||||||
}
|
|
||||||
|
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.Write([]byte("Hi"))
|
|
||||||
})
|
|
||||||
rr := httptest.NewRecorder()
|
|
||||||
handler := ValidateSignature(tt.sharedSecret)(testHandler)
|
|
||||||
handler.ServeHTTP(rr, req)
|
|
||||||
if rr.Code != tt.status {
|
|
||||||
t.Errorf("Status code differs. got %d want %d", rr.Code, tt.status)
|
|
||||||
t.Errorf("%s", rr.Body)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHealthCheck(t *testing.T) {
|
func TestHealthCheck(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
testHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -212,7 +124,6 @@ func TestTimeoutHandlerFunc(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
@ -242,3 +153,42 @@ func TestTimeoutHandlerFunc(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestValidateSignature(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
})
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
secretA string
|
||||||
|
secretB string
|
||||||
|
wantStatus int
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{"good", "secret", "secret", http.StatusOK, http.StatusText(http.StatusOK)},
|
||||||
|
{"secret mistmatch", "secret", "hunter42", http.StatusBadRequest, "{\"error\":\"invalid signature\"}\n"},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
signedURL := urlutil.NewSignedURL(tt.secretB, &url.URL{Scheme: "https", Host: "pomerium.io"})
|
||||||
|
|
||||||
|
r := httptest.NewRequest(http.MethodGet, signedURL.String(), nil)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
got := ValidateSignature(tt.secretA)(fn)
|
||||||
|
got.ServeHTTP(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("SignRequest() error = %v, wantErr %v\n%v", w.Result().StatusCode, tt.wantStatus, w.Body.String())
|
||||||
|
}
|
||||||
|
body := w.Body.String()
|
||||||
|
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("SignRequest() %s", diff)
|
||||||
|
t.Errorf("%s", signedURL)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -4,6 +4,8 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Context keys
|
// Context keys
|
||||||
|
@ -41,8 +43,8 @@ func retrieveFromRequest(r *http.Request, sessions ...SessionLoader) (*State, er
|
||||||
return state, err
|
return state, err
|
||||||
}
|
}
|
||||||
if state != nil {
|
if state != nil {
|
||||||
err := state.Verify(r.Host)
|
err := state.Verify(urlutil.StripPort(r.Host))
|
||||||
return state, err // N.B.: state is _not nil_
|
return state, err // N.B.: state is _not_ nil_
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
14
internal/urlutil/errors.go
Normal file
14
internal/urlutil/errors.go
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrExpired indicates that token is used after expiry time indicated in exp claim.
|
||||||
|
ErrExpired = errors.New("internal/urlutil: validation failed, url hmac is expired")
|
||||||
|
|
||||||
|
// ErrIssuedInTheFuture indicates that the issued field is in the future.
|
||||||
|
ErrIssuedInTheFuture = errors.New("internal/urlutil: validation field, url hmac issued in the future")
|
||||||
|
|
||||||
|
// ErrNumericDateMalformed indicates a malformed unix timestamp was found while parsing.
|
||||||
|
ErrNumericDateMalformed = errors.New("internal/urlutil: malformed unix timestamp field")
|
||||||
|
)
|
24
internal/urlutil/query_params.go
Normal file
24
internal/urlutil/query_params.go
Normal file
|
@ -0,0 +1,24 @@
|
||||||
|
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
// Common query parameters used to set and send data between Pomerium
|
||||||
|
// services over HTTP calls and redirects. They are typically used in
|
||||||
|
// conjunction with a HMAC to ensure authenticity.
|
||||||
|
const (
|
||||||
|
QueryCallbackURI = "pomerium_callback_uri"
|
||||||
|
QueryImpersonateEmail = "pomerium_impersonate_email"
|
||||||
|
QueryImpersonateGroups = "pomerium_impersonate_groups"
|
||||||
|
QueryIsProgrammatic = "pomerium_programmatic"
|
||||||
|
QueryForwardAuth = "pomerium_forward_auth"
|
||||||
|
QueryPomeriumJWT = "pomerium_jwt"
|
||||||
|
QuerySessionEncrypted = "pomerium_session_encrypted"
|
||||||
|
QueryRedirectURI = "pomerium_redirect_uri"
|
||||||
|
QueryRefreshToken = "pomerium_refresh_token"
|
||||||
|
)
|
||||||
|
|
||||||
|
// URL signature based query params used for verifying the authenticity of a URL.
|
||||||
|
const (
|
||||||
|
QueryHmacExpiry = "pomerium_expiry"
|
||||||
|
QueryHmacIssued = "pomerium_issued"
|
||||||
|
QueryHmacSignature = "pomerium_signature"
|
||||||
|
QueryHmacURI = "pomerium_uri"
|
||||||
|
)
|
130
internal/urlutil/signed.go
Normal file
130
internal/urlutil/signed.go
Normal file
|
@ -0,0 +1,130 @@
|
||||||
|
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/base64"
|
||||||
|
"fmt"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SignedURL is a shared-key HMAC wrapped URL.
|
||||||
|
type SignedURL struct {
|
||||||
|
uri url.URL
|
||||||
|
key string
|
||||||
|
signed bool
|
||||||
|
|
||||||
|
// mockable time for testing
|
||||||
|
timeNow func() time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewSignedURL creates a new copy of a URL that can be signed with a shared key.
|
||||||
|
//
|
||||||
|
// N.B. It is the user's responsibility to make sure the key is 256 bits and
|
||||||
|
// the url is not nil.
|
||||||
|
func NewSignedURL(key string, uri *url.URL) *SignedURL {
|
||||||
|
return &SignedURL{uri: *uri, key: key, timeNow: time.Now} // uri is copied
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sign creates a shared-key HMAC signed URL.
|
||||||
|
func (su *SignedURL) Sign() *url.URL {
|
||||||
|
now := su.timeNow()
|
||||||
|
issued := newNumericDate(now)
|
||||||
|
expiry := newNumericDate(now.Add(5 * time.Minute))
|
||||||
|
params := su.uri.Query()
|
||||||
|
params.Set(QueryHmacIssued, fmt.Sprint(issued))
|
||||||
|
params.Set(QueryHmacExpiry, fmt.Sprint(expiry))
|
||||||
|
su.uri.RawQuery = params.Encode()
|
||||||
|
params.Set(QueryHmacSignature, hmacURL(su.key, su.uri.String(), issued, expiry))
|
||||||
|
su.uri.RawQuery = params.Encode()
|
||||||
|
su.signed = true
|
||||||
|
return &su.uri
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements the stringer interface and returns a signed URL string.
|
||||||
|
func (su *SignedURL) String() string {
|
||||||
|
if !su.signed {
|
||||||
|
su.Sign()
|
||||||
|
su.signed = true
|
||||||
|
}
|
||||||
|
return su.uri.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate checks to see if a signed URL is valid.
|
||||||
|
func (su *SignedURL) Validate() error {
|
||||||
|
now := su.timeNow()
|
||||||
|
params := su.uri.Query()
|
||||||
|
sig, err := base64.URLEncoding.DecodeString(params.Get(QueryHmacSignature))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("internal/urlutil: malformed signature %w", err)
|
||||||
|
}
|
||||||
|
params.Del(QueryHmacSignature)
|
||||||
|
su.uri.RawQuery = params.Encode()
|
||||||
|
|
||||||
|
issued, err := newNumericDateFromString(params.Get(QueryHmacIssued))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("internal/urlutil: issued %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
expiry, err := newNumericDateFromString(params.Get(QueryHmacExpiry))
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("internal/urlutil: expiry %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if expiry != nil && now.Add(-DefaultLeeway).After(expiry.Time()) {
|
||||||
|
return ErrExpired
|
||||||
|
}
|
||||||
|
|
||||||
|
if issued != nil && now.Add(DefaultLeeway).Before(issued.Time()) {
|
||||||
|
return ErrIssuedInTheFuture
|
||||||
|
}
|
||||||
|
|
||||||
|
validHMAC := cryptutil.CheckHMAC(
|
||||||
|
[]byte(fmt.Sprint(su.uri.String(), issued, expiry)),
|
||||||
|
sig,
|
||||||
|
su.key)
|
||||||
|
if !validHMAC {
|
||||||
|
return fmt.Errorf("internal/urlutil: hmac failed %s", su.uri.String())
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// hmacURL takes a redirect url string and timestamp and returns the base64
|
||||||
|
// encoded HMAC result.
|
||||||
|
func hmacURL(key string, data ...interface{}) string {
|
||||||
|
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data...)), key)
|
||||||
|
return base64.URLEncoding.EncodeToString(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// numericDate used because we don't need the precision of a typical time.Time.
|
||||||
|
type numericDate int64
|
||||||
|
|
||||||
|
func newNumericDate(t time.Time) *numericDate {
|
||||||
|
if t.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
out := numericDate(t.Unix())
|
||||||
|
return &out
|
||||||
|
}
|
||||||
|
|
||||||
|
func newNumericDateFromString(s string) (*numericDate, error) {
|
||||||
|
i, err := strconv.ParseInt(s, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return nil, ErrNumericDateMalformed
|
||||||
|
}
|
||||||
|
out := numericDate(i)
|
||||||
|
return &out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *numericDate) Time() time.Time {
|
||||||
|
if n == nil {
|
||||||
|
return time.Time{}
|
||||||
|
}
|
||||||
|
return time.Unix(int64(*n), 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (n *numericDate) String() string {
|
||||||
|
return strconv.FormatInt(int64(*n), 10)
|
||||||
|
}
|
97
internal/urlutil/signed_test.go
Normal file
97
internal/urlutil/signed_test.go
Normal file
|
@ -0,0 +1,97 @@
|
||||||
|
package urlutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestSignedURL(t *testing.T) {
|
||||||
|
original := time.Unix(1574117851, 0) // ;-)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
key string
|
||||||
|
uri url.URL
|
||||||
|
origTime func() time.Time
|
||||||
|
newTime func() time.Time
|
||||||
|
wantStr string
|
||||||
|
want url.URL
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good", "test-key", url.URL{Scheme: "https", Host: "pomerium.io"},
|
||||||
|
func() time.Time { return original }, func() time.Time { return original },
|
||||||
|
"https://pomerium.io?pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D",
|
||||||
|
url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"},
|
||||||
|
false},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
signedURL := NewSignedURL(tt.key, &tt.uri)
|
||||||
|
signedURL.timeNow = tt.origTime
|
||||||
|
|
||||||
|
if diff := cmp.Diff(signedURL.String(), tt.wantStr); diff != "" {
|
||||||
|
t.Errorf("signedURL() = %v", diff)
|
||||||
|
}
|
||||||
|
|
||||||
|
signedURL = NewSignedURL(tt.key, &tt.uri)
|
||||||
|
signedURL.timeNow = tt.origTime
|
||||||
|
got := signedURL.Sign()
|
||||||
|
|
||||||
|
if diff := cmp.Diff(*got, tt.want); diff != "" {
|
||||||
|
t.Errorf("NewSignedURL() = %s", diff)
|
||||||
|
}
|
||||||
|
err := signedURL.Validate()
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(signedURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(u, got); diff != "" {
|
||||||
|
t.Errorf("signedURL() = %v", diff)
|
||||||
|
}
|
||||||
|
// subsequent string calls shouldn't result in a change
|
||||||
|
u, err = url.Parse(signedURL.String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
if diff := cmp.Diff(u, got); diff != "" {
|
||||||
|
t.Errorf("signedURL() = %v", diff)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSignedURL_Validate(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
uri url.URL
|
||||||
|
key string
|
||||||
|
timeNow func() time.Time
|
||||||
|
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{"good", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, false},
|
||||||
|
{"bad key", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "bad-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
|
||||||
|
{"bad no expiry", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
|
||||||
|
{"bad issued", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
|
||||||
|
{"bad signature body", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=^"}, "test-key", func() time.Time { return time.Unix(1574117851, 0) }, true},
|
||||||
|
{"bad expired", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(time.Hour) }, true},
|
||||||
|
{"bad not yet valid", url.URL{Scheme: "https", Host: "pomerium.io", RawQuery: "pomerium_expiry=1574118151&pomerium_issued=1574117851&pomerium_signature=XtvM-Y-oPvoGGV2Q5G0vrQ_CgNeYhVyTG5dHIqLsBOU%3D"}, "test-key", func() time.Time { return time.Unix(1574117851, 0).Add(-time.Hour) }, true},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
out := NewSignedURL(tt.key, &tt.uri)
|
||||||
|
out.timeNow = tt.timeNow
|
||||||
|
|
||||||
|
if err := out.Validate(); (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("SignedURL.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,15 +1,16 @@
|
||||||
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
package urlutil // import "github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/base64"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/cryptutil"
|
const (
|
||||||
|
// DefaultLeeway defines the default leeway for matching NotBefore/Expiry claims.
|
||||||
|
DefaultLeeway = 1.0 * time.Minute
|
||||||
)
|
)
|
||||||
|
|
||||||
// StripPort returns a host, without any port number.
|
// StripPort returns a host, without any port number.
|
||||||
|
@ -66,57 +67,6 @@ func DeepCopy(u *url.URL) (*url.URL, error) {
|
||||||
return ParseAndValidateURL(u.String())
|
return ParseAndValidateURL(u.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
var mockNow testTime
|
|
||||||
|
|
||||||
// testTime is safe to use concurrently.
|
|
||||||
type testTime struct {
|
|
||||||
sync.Mutex
|
|
||||||
mockNow int64
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *testTime) setNow(n int64) {
|
|
||||||
tt.Lock()
|
|
||||||
tt.mockNow = n
|
|
||||||
tt.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tt *testTime) now() int64 {
|
|
||||||
tt.Lock()
|
|
||||||
defer tt.Unlock()
|
|
||||||
return tt.mockNow
|
|
||||||
}
|
|
||||||
|
|
||||||
// timestamp returns the current timestamp, in seconds.
|
|
||||||
//
|
|
||||||
// For testing purposes, the function that generates the timestamp can be
|
|
||||||
// overridden. If not set, it will return time.Now().UTC().Unix().
|
|
||||||
func timestamp() int64 {
|
|
||||||
if mockNow.now() == 0 {
|
|
||||||
return time.Now().UTC().Unix()
|
|
||||||
}
|
|
||||||
return mockNow.now()
|
|
||||||
}
|
|
||||||
|
|
||||||
// SignedRedirectURL takes a destination URL and adds redirect_uri to it's
|
|
||||||
// query params, along with a timestamp and an keyed signature.
|
|
||||||
func SignedRedirectURL(key string, destination, u *url.URL) *url.URL {
|
|
||||||
now := timestamp()
|
|
||||||
rawURL := u.String()
|
|
||||||
params, _ := url.ParseQuery(destination.RawQuery) // handled by incoming mux
|
|
||||||
params.Set("redirect_uri", rawURL)
|
|
||||||
params.Set("ts", fmt.Sprint(now))
|
|
||||||
params.Set("sig", hmacURL(key, rawURL, now))
|
|
||||||
destination.RawQuery = params.Encode()
|
|
||||||
return destination
|
|
||||||
}
|
|
||||||
|
|
||||||
// hmacURL takes a redirect url string and timestamp and returns the base64
|
|
||||||
// encoded HMAC result.
|
|
||||||
func hmacURL(key, data string, timestamp int64) string {
|
|
||||||
h := cryptutil.GenerateHMAC([]byte(fmt.Sprint(data, timestamp)), key)
|
|
||||||
return base64.URLEncoding.EncodeToString(h)
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetAbsoluteURL returns the current handler's absolute url.
|
// GetAbsoluteURL returns the current handler's absolute url.
|
||||||
// https://stackoverflow.com/a/23152483
|
// https://stackoverflow.com/a/23152483
|
||||||
func GetAbsoluteURL(r *http.Request) *url.URL {
|
func GetAbsoluteURL(r *http.Request) *url.URL {
|
||||||
|
|
|
@ -108,48 +108,6 @@ func TestValidateURL(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSignedRedirectURL(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
mockedTime int64
|
|
||||||
key string
|
|
||||||
destination *url.URL
|
|
||||||
urlToSign *url.URL
|
|
||||||
want *url.URL
|
|
||||||
}{
|
|
||||||
{"good", 2, "hunter42", &url.URL{Host: "pomerium.io", Scheme: "https://"}, &url.URL{Host: "pomerium.io", Scheme: "https://", Path: "/ok"}, &url.URL{Host: "pomerium.io", Scheme: "https://", RawQuery: "redirect_uri=https%3A%2F%2F%3A%2F%2Fpomerium.io%2Fok&sig=7jdo1XFcmuhjBHnpfVhll5cXflYByeMnbp5kRz87CVQ%3D&ts=2"}},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockNow.setNow(tt.mockedTime)
|
|
||||||
got := SignedRedirectURL(tt.key, tt.destination, tt.urlToSign)
|
|
||||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
|
||||||
t.Errorf("SignedRedirectURL() = diff %v", diff)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_timestamp(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
dontWant int64
|
|
||||||
}{
|
|
||||||
{"if unset should never return", 0},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
mockNow.setNow(tt.dontWant)
|
|
||||||
if got := timestamp(); got == tt.dontWant {
|
|
||||||
t.Errorf("timestamp() = %v, dontWant %v", got, tt.dontWant)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseURLHelper(s string) *url.URL {
|
func parseURLHelper(s string) *url.URL {
|
||||||
u, _ := url.Parse(s)
|
u, _ := url.Parse(s)
|
||||||
return u
|
return u
|
||||||
|
|
114
proxy/forward_auth.go
Normal file
114
proxy/forward_auth.go
Normal file
|
@ -0,0 +1,114 @@
|
||||||
|
package proxy // import "github.com/pomerium/pomerium/proxy"
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
||||||
|
r := httputil.NewRouter()
|
||||||
|
r.StrictSlash(true)
|
||||||
|
r.Use(sessions.RetrieveSession(p.sessionStore))
|
||||||
|
|
||||||
|
r.Handle("/verify", http.HandlerFunc(p.nginxCallback)).
|
||||||
|
Queries("uri", "{uri}", urlutil.QuerySessionEncrypted, "", urlutil.QueryRedirectURI, "")
|
||||||
|
r.Handle("/", http.HandlerFunc(p.postSessionSetNOP)).
|
||||||
|
Queries("uri", "{uri}",
|
||||||
|
urlutil.QuerySessionEncrypted, "",
|
||||||
|
urlutil.QueryRedirectURI, "")
|
||||||
|
r.Handle("/", http.HandlerFunc(p.traefikCallback)).
|
||||||
|
HeadersRegexp(httputil.HeaderForwardedURI, urlutil.QuerySessionEncrypted)
|
||||||
|
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}")
|
||||||
|
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}")
|
||||||
|
|
||||||
|
return r
|
||||||
|
}
|
||||||
|
|
||||||
|
// postSessionSetNOP after successfully setting the
|
||||||
|
func (p *Proxy) postSessionSetNOP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
httputil.Redirect(w, r, r.FormValue(urlutil.QueryRedirectURI), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) nginxCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
||||||
|
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusUnauthorized)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Proxy) traefikCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
forwardedURL, err := url.Parse(r.Header.Get(httputil.HeaderForwardedURI))
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
q := forwardedURL.Query()
|
||||||
|
redirectURLString := q.Get(urlutil.QueryRedirectURI)
|
||||||
|
encryptedSession := q.Get(urlutil.QuerySessionEncrypted)
|
||||||
|
|
||||||
|
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
httputil.Redirect(w, r, redirectURLString, http.StatusFound)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify checks a user's credentials for an arbitrary host. If the user
|
||||||
|
// is properly authenticated and is authorized to access the supplied host,
|
||||||
|
// a `200` http status code is returned. If the user is not authenticated, they
|
||||||
|
// will be redirected to the authenticate service to sign in with their identity
|
||||||
|
// provider. If the user is unauthorized, a `401` error is returned.
|
||||||
|
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := sessions.FromContext(r.Context())
|
||||||
|
if errors.Is(err, sessions.ErrNoSessionFound) || errors.Is(err, sessions.ErrExpired) {
|
||||||
|
if verifyOnly {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authN := *p.authenticateSigninURL
|
||||||
|
q := authN.Query()
|
||||||
|
q.Set(urlutil.QueryCallbackURI, uri.String())
|
||||||
|
q.Set(urlutil.QueryRedirectURI, uri.String()) // final destination
|
||||||
|
q.Set(urlutil.QueryForwardAuth, urlutil.StripPort(r.Host)) // add fwd auth to trusted audience
|
||||||
|
authN.RawQuery = q.Encode()
|
||||||
|
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &authN).String(), http.StatusFound)
|
||||||
|
return
|
||||||
|
} else if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// depending on the configuration of the fronting proxy, the request Host
|
||||||
|
// and/or `X-Forwarded-Host` may be untrustd or change so we reverify
|
||||||
|
// the session's validity against the supplied uri
|
||||||
|
if err := s.Verify(uri.Hostname()); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
p.addPomeriumHeaders(w, r)
|
||||||
|
if err := p.authorize(uri.Host, w, r); err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
|
w.WriteHeader(http.StatusOK)
|
||||||
|
fmt.Fprintf(w, "Access to %s is allowed.", uri.Host)
|
||||||
|
})
|
||||||
|
}
|
119
proxy/forward_auth_test.go
Normal file
119
proxy/forward_auth_test.go
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/url"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/go-cmp/cmp"
|
||||||
|
"github.com/pomerium/pomerium/config"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
|
"github.com/pomerium/pomerium/proxy/clients"
|
||||||
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxy_ForwardAuth(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
opts := testOptions(t)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
options config.Options
|
||||||
|
ctxError error
|
||||||
|
method string
|
||||||
|
|
||||||
|
headers map[string]string
|
||||||
|
qp map[string]string
|
||||||
|
|
||||||
|
requestURI string
|
||||||
|
verifyURI string
|
||||||
|
|
||||||
|
cipher encoding.MarshalUnmarshaler
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
authorizer clients.Authorizer
|
||||||
|
wantStatus int
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, "Access to some.domain.example is allowed."},
|
||||||
|
{"good verify only, no redirect", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
||||||
|
{"good redirect not required", opts, nil, http.MethodGet, nil, nil, "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{LoadError: sessions.ErrInvalidAudience}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
|
{"bad naked domain uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||||
|
{"bad naked domain uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||||
|
{"bad empty verification uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||||
|
{"bad empty verification uri verify only", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
||||||
|
{"not authorized", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
||||||
|
{"not authorized verify endpoint", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
||||||
|
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
||||||
|
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||||
|
{"not authorized because of error", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
|
||||||
|
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
||||||
|
{"not authorized, bad audience request uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"not.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
|
{"not authorized, bad audience verify uri", opts, nil, http.MethodGet, nil, nil, "https://some.domain.example/", "https://fwdauth.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Audience: []string{"some.domain.example"}, Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, invalid audience claim (aud)\"}\n"},
|
||||||
|
// traefik
|
||||||
|
{"good traefik callback", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
|
{"bad traefik callback bad session", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: "https://some.domain.example?" + urlutil.QuerySessionEncrypted + "=" + goodEncryptionString + "garbage"}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
{"bad traefik callback bad url", opts, nil, http.MethodGet, map[string]string{httputil.HeaderForwardedURI: urlutil.QuerySessionEncrypted + ""}, nil, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
// nginx
|
||||||
|
{"good nginx callback redirect", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
|
{"good nginx callback set session okay but return unauthorized", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusUnauthorized, ""},
|
||||||
|
{"bad nginx callback failed to set sesion", opts, nil, http.MethodGet, nil, map[string]string{urlutil.QueryRedirectURI: "https://some.domain.example/", urlutil.QuerySessionEncrypted: goodEncryptionString + "nope"}, "https://some.domain.example/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p, err := New(tt.options)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p.encoder = tt.cipher
|
||||||
|
p.sessionStore = tt.sessionStore
|
||||||
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
p.UpdateOptions(tt.options)
|
||||||
|
uri, err := url.Parse(tt.requestURI)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
queryString := uri.Query()
|
||||||
|
for k, v := range tt.qp {
|
||||||
|
queryString.Set(k, v)
|
||||||
|
}
|
||||||
|
if tt.verifyURI != "" {
|
||||||
|
queryString.Set("uri", tt.verifyURI)
|
||||||
|
}
|
||||||
|
|
||||||
|
uri.RawQuery = queryString.Encode()
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||||
|
state, _ := tt.sessionStore.LoadSession(r)
|
||||||
|
|
||||||
|
ctx := r.Context()
|
||||||
|
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
if len(tt.headers) != 0 {
|
||||||
|
for k, v := range tt.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
router := p.registerFwdAuthHandlers()
|
||||||
|
router.ServeHTTP(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||||
|
t.Errorf("\n%+v", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantBody != "" {
|
||||||
|
body := w.Body.String()
|
||||||
|
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("wrong body\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,7 +19,6 @@ import (
|
||||||
|
|
||||||
// registerDashboardHandlers returns the proxy service's ServeMux
|
// registerDashboardHandlers returns the proxy service's ServeMux
|
||||||
func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
// dashboard subrouter
|
|
||||||
h := r.PathPrefix(dashboardURL).Subrouter()
|
h := r.PathPrefix(dashboardURL).Subrouter()
|
||||||
// 1. Retrieve the user session and add it to the request context
|
// 1. Retrieve the user session and add it to the request context
|
||||||
h.Use(sessions.RetrieveSession(p.sessionStore))
|
h.Use(sessions.RetrieveSession(p.sessionStore))
|
||||||
|
@ -32,19 +31,27 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
|
csrf.CookieName(fmt.Sprintf("%s_csrf", p.cookieOptions.Name)),
|
||||||
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
csrf.ErrorHandler(http.HandlerFunc(httputil.CSRFFailureHandler)),
|
||||||
))
|
))
|
||||||
|
// dashboard endpoints can be used by user's to view, or modify their session
|
||||||
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
|
h.HandleFunc("/", p.UserDashboard).Methods(http.MethodGet)
|
||||||
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
|
h.HandleFunc("/impersonate", p.Impersonate).Methods(http.MethodPost)
|
||||||
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
h.HandleFunc("/sign_out", p.SignOut).Methods(http.MethodGet, http.MethodPost)
|
||||||
|
|
||||||
// Authenticate service callback handlers and middleware
|
// Authenticate service callback handlers and middleware
|
||||||
|
// callback used to set route-scoped session and redirect back to destination
|
||||||
|
// only accept signed requests (hmac) from other trusted pomerium services
|
||||||
c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
|
c := r.PathPrefix(dashboardURL + "/callback").Subrouter()
|
||||||
// only accept payloads that have come from a trusted service (hmac)
|
|
||||||
c.Use(middleware.ValidateSignature(p.SharedKey))
|
c.Use(middleware.ValidateSignature(p.SharedKey))
|
||||||
c.HandleFunc("/", p.Callback).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
|
|
||||||
|
|
||||||
|
c.Path("/").HandlerFunc(p.ProgrammaticCallback).Methods(http.MethodGet).
|
||||||
|
Queries(urlutil.QueryIsProgrammatic, "true")
|
||||||
|
|
||||||
|
c.Path("/").HandlerFunc(p.Callback).Methods(http.MethodGet)
|
||||||
// Programmatic API handlers and middleware
|
// Programmatic API handlers and middleware
|
||||||
a := r.PathPrefix(dashboardURL + "/api").Subrouter()
|
a := r.PathPrefix(dashboardURL + "/api").Subrouter()
|
||||||
a.HandleFunc("/v1/login", p.ProgrammaticLogin).Queries("redirect_uri", "{redirect_uri}").Methods(http.MethodGet)
|
// login api handler generates a user-navigable login url to authenticate
|
||||||
|
a.HandleFunc("/v1/login", p.ProgrammaticLogin).
|
||||||
|
Queries(urlutil.QueryRedirectURI, "").
|
||||||
|
Methods(http.MethodGet)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
@ -52,7 +59,6 @@ func (p *Proxy) registerDashboardHandlers(r *mux.Router) *mux.Router {
|
||||||
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
// RobotsTxt sets the User-Agent header in the response to be "Disallow"
|
||||||
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
fmt.Fprintf(w, "User-agent: *\nDisallow: /")
|
||||||
}
|
}
|
||||||
|
@ -62,12 +68,17 @@ func (p *Proxy) RobotsTxt(w http.ResponseWriter, _ *http.Request) {
|
||||||
// the local session state.
|
// the local session state.
|
||||||
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) SignOut(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
redirectURL := &url.URL{Scheme: "https", Host: r.Host, Path: "/"}
|
||||||
if uri, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri")); err == nil && uri.String() != "" {
|
if uri, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI)); err == nil && uri.String() != "" {
|
||||||
redirectURL = uri
|
redirectURL = uri
|
||||||
}
|
}
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSignoutURL, redirectURL)
|
|
||||||
|
signoutURL := *p.authenticateSignoutURL
|
||||||
|
q := signoutURL.Query()
|
||||||
|
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||||
|
signoutURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
p.sessionStore.ClearSession(w, r)
|
p.sessionStore.ClearSession(w, r)
|
||||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signoutURL).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UserDashboard lets users investigate, and refresh their current session.
|
// UserDashboard lets users investigate, and refresh their current session.
|
||||||
|
@ -112,110 +123,95 @@ func (p *Proxy) Impersonate(w http.ResponseWriter, r *http.Request) {
|
||||||
// OK to impersonation
|
// OK to impersonation
|
||||||
redirectURL := urlutil.GetAbsoluteURL(r)
|
redirectURL := urlutil.GetAbsoluteURL(r)
|
||||||
redirectURL.Path = dashboardURL // redirect back to the dashboard
|
redirectURL.Path = dashboardURL // redirect back to the dashboard
|
||||||
q := redirectURL.Query()
|
signinURL := *p.authenticateSigninURL
|
||||||
q.Add("impersonate_email", r.FormValue("email"))
|
q := signinURL.Query()
|
||||||
q.Add("impersonate_group", r.FormValue("group"))
|
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
|
||||||
redirectURL.RawQuery = q.Encode()
|
q.Set(urlutil.QueryImpersonateEmail, r.FormValue("email"))
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
q.Set(urlutil.QueryImpersonateGroups, r.FormValue("group"))
|
||||||
httputil.Redirect(w, r, uri, http.StatusFound)
|
signinURL.RawQuery = q.Encode()
|
||||||
|
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (p *Proxy) registerFwdAuthHandlers() http.Handler {
|
// Callback handles the result of a successful call to the authenticate service
|
||||||
r := httputil.NewRouter()
|
// and is responsible setting returned per-route session.
|
||||||
r.StrictSlash(true)
|
|
||||||
r.Use(sessions.RetrieveSession(p.sessionStore))
|
|
||||||
r.Handle("/", p.Verify(false)).Queries("uri", "{uri}").Methods(http.MethodGet)
|
|
||||||
r.Handle("/verify", p.Verify(true)).Queries("uri", "{uri}").Methods(http.MethodGet)
|
|
||||||
return r
|
|
||||||
}
|
|
||||||
|
|
||||||
// Verify checks a user's credentials for an arbitrary host. If the user
|
|
||||||
// is properly authenticated and is authorized to access the supplied host,
|
|
||||||
// a `200` http status code is returned. If the user is not authenticated, they
|
|
||||||
// will be redirected to the authenticate service to sign in with their identity
|
|
||||||
// provider. If the user is unauthorized, a `401` error is returned.
|
|
||||||
func (p *Proxy) Verify(verifyOnly bool) http.Handler {
|
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
uri, err := urlutil.ParseAndValidateURL(r.FormValue("uri"))
|
|
||||||
if err != nil || uri.String() == "" {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("bad verification uri", http.StatusBadRequest, nil))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := p.authenticate(verifyOnly, w, r); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := p.authorize(uri.Host, w, r); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
|
||||||
fmt.Fprintf(w, fmt.Sprintf("Access to %s is allowed.", uri.Host))
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Callback takes a `redirect_uri` query param that has been hmac'd by the
|
|
||||||
// authenticate service. Embedded in the `redirect_uri` are query-params
|
|
||||||
// that tell this handler how to set the per-route user session.
|
|
||||||
// Callback is responsible for redirecting the user back to the intended
|
|
||||||
// destination URL and path, as well as to clean up any additional query params
|
|
||||||
// added by the authenticate service.
|
|
||||||
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) Callback(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
||||||
if err != nil {
|
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
|
||||||
|
if _, err := p.saveCallbackSession(w, r, encryptedSession); err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
q := redirectURL.Query()
|
httputil.Redirect(w, r, redirectURLString, http.StatusFound)
|
||||||
// 1. extract the base64 encoded and encrypted JWT from redirect_uri's query params
|
}
|
||||||
encryptedJWT, err := base64.URLEncoding.DecodeString(q.Get("pomerium_jwt"))
|
|
||||||
if err != nil {
|
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
q.Del("pomerium_jwt")
|
|
||||||
q.Del("impersonate_email")
|
|
||||||
q.Del("impersonate_group")
|
|
||||||
|
|
||||||
|
// saveCallbackSession takes an encrypted per-route session token, and decrypts
|
||||||
|
// it using the shared service key, then stores it the local session store.
|
||||||
|
func (p *Proxy) saveCallbackSession(w http.ResponseWriter, r *http.Request, enctoken string) ([]byte, error) {
|
||||||
|
// 1. extract the base64 encoded and encrypted JWT from query params
|
||||||
|
encryptedJWT, err := base64.URLEncoding.DecodeString(enctoken)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("proxy: malfromed callback token: %w", err)
|
||||||
|
}
|
||||||
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
// 2. decrypt the JWT using the cipher using the _shared_ secret key
|
||||||
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
|
rawJWT, err := cryptutil.Decrypt(p.sharedCipher, encryptedJWT, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("", http.StatusBadRequest, err))
|
return nil, fmt.Errorf("proxy: callback token decrypt error: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
// 3. Save the decrypted JWT to the session store directly as a string, without resigning
|
||||||
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
if err = p.sessionStore.SaveSession(w, r, rawJWT); err != nil {
|
||||||
httputil.ErrorResponse(w, r, err)
|
return nil, fmt.Errorf("proxy: callback session save failure: %w", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
return rawJWT, nil
|
||||||
// if this is a programmatic request, don't strip the tokens before redirect
|
|
||||||
if redirectURL.Query().Get("pomerium_programmatic_destination_url") != "" {
|
|
||||||
q.Set("pomerium_jwt", string(rawJWT))
|
|
||||||
}
|
|
||||||
redirectURL.RawQuery = q.Encode()
|
|
||||||
|
|
||||||
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ProgrammaticLogin returns a signed url that can be used to login
|
// ProgrammaticLogin returns a signed url that can be used to login
|
||||||
// using the authenticate service.
|
// using the authenticate service.
|
||||||
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
|
func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
redirectURL, err := urlutil.ParseAndValidateURL(r.FormValue("redirect_uri"))
|
redirectURI, err := urlutil.ParseAndValidateURL(r.FormValue(urlutil.QueryRedirectURI))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect_uri", http.StatusBadRequest, err))
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
q := redirectURL.Query()
|
signinURL := *p.authenticateSigninURL
|
||||||
q.Add("pomerium_programmatic_destination_url", urlutil.GetAbsoluteURL(r).String())
|
callbackURI := urlutil.GetAbsoluteURL(r)
|
||||||
redirectURL.RawQuery = q.Encode()
|
callbackURI.Path = dashboardURL + "/callback/"
|
||||||
response := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, redirectURL).String()
|
q := signinURL.Query()
|
||||||
|
q.Set(urlutil.QueryCallbackURI, callbackURI.String())
|
||||||
|
q.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||||
|
q.Set(urlutil.QueryIsProgrammatic, "true")
|
||||||
|
signinURL.RawQuery = q.Encode()
|
||||||
|
response := urlutil.NewSignedURL(p.SharedKey, &signinURL).String()
|
||||||
|
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
w.Write([]byte(response))
|
w.Write([]byte(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ProgrammaticCallback handles a successful call to the authenticate service.
|
||||||
|
// In addition to returning the individual route session (JWT) it also returns
|
||||||
|
// the refresh token.
|
||||||
|
func (p *Proxy) ProgrammaticCallback(w http.ResponseWriter, r *http.Request) {
|
||||||
|
redirectURLString := r.FormValue(urlutil.QueryRedirectURI)
|
||||||
|
encryptedSession := r.FormValue(urlutil.QuerySessionEncrypted)
|
||||||
|
|
||||||
|
redirectURL, err := urlutil.ParseAndValidateURL(redirectURLString)
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error("malformed redirect uri", http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawJWT, err := p.saveCallbackSession(w, r, encryptedSession)
|
||||||
|
if err != nil {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusBadRequest, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
q := redirectURL.Query()
|
||||||
|
q.Set(urlutil.QueryPomeriumJWT, string(rawJWT))
|
||||||
|
q.Set(urlutil.QueryRefreshToken, r.FormValue(urlutil.QueryRefreshToken))
|
||||||
|
redirectURL.RawQuery = q.Encode()
|
||||||
|
|
||||||
|
httputil.Redirect(w, r, redirectURL.String(), http.StatusFound)
|
||||||
|
}
|
||||||
|
|
|
@ -11,17 +11,22 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/internal/cryptutil"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/config"
|
"github.com/pomerium/pomerium/config"
|
||||||
"github.com/pomerium/pomerium/internal/encoding"
|
"github.com/pomerium/pomerium/internal/encoding"
|
||||||
"github.com/pomerium/pomerium/internal/encoding/mock"
|
"github.com/pomerium/pomerium/internal/encoding/mock"
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/sessions"
|
"github.com/pomerium/pomerium/internal/sessions"
|
||||||
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
"github.com/pomerium/pomerium/proxy/clients"
|
"github.com/pomerium/pomerium/proxy/clients"
|
||||||
|
|
||||||
"github.com/google/go-cmp/cmp"
|
"github.com/google/go-cmp/cmp"
|
||||||
"gopkg.in/square/go-jose.v2/jwt"
|
"gopkg.in/square/go-jose.v2/jwt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const goodEncryptionString = "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="
|
||||||
|
|
||||||
func TestProxy_RobotsTxt(t *testing.T) {
|
func TestProxy_RobotsTxt(t *testing.T) {
|
||||||
proxy := Proxy{}
|
proxy := Proxy{}
|
||||||
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
req := httptest.NewRequest(http.MethodGet, "/robots.txt", nil)
|
||||||
|
@ -189,12 +194,12 @@ func TestProxy_SignOut(t *testing.T) {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
postForm := url.Values{}
|
postForm := url.Values{}
|
||||||
postForm.Add("redirect_uri", tt.redirectURL)
|
postForm.Add(urlutil.QueryRedirectURI, tt.redirectURL)
|
||||||
uri := &url.URL{Path: "/"}
|
uri := &url.URL{Path: "/"}
|
||||||
|
|
||||||
query, _ := url.ParseQuery(uri.RawQuery)
|
query, _ := url.ParseQuery(uri.RawQuery)
|
||||||
if tt.verb == http.MethodGet {
|
if tt.verb == http.MethodGet {
|
||||||
query.Add("redirect_uri", tt.redirectURL)
|
query.Add(urlutil.QueryRedirectURI, tt.redirectURL)
|
||||||
uri.RawQuery = query.Encode()
|
uri.RawQuery = query.Encode()
|
||||||
}
|
}
|
||||||
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
r := httptest.NewRequest(tt.verb, uri.String(), bytes.NewBufferString(postForm.Encode()))
|
||||||
|
@ -217,87 +222,7 @@ func uriParseHelper(s string) *url.URL {
|
||||||
}
|
}
|
||||||
return uri
|
return uri
|
||||||
}
|
}
|
||||||
func TestProxy_VerifyWithMiddleware(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
opts := testOptions(t)
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
options config.Options
|
|
||||||
ctxError error
|
|
||||||
method string
|
|
||||||
qp string
|
|
||||||
path string
|
|
||||||
verifyURI string
|
|
||||||
|
|
||||||
cipher encoding.MarshalUnmarshaler
|
|
||||||
sessionStore sessions.SessionStore
|
|
||||||
authorizer clients.Authorizer
|
|
||||||
wantStatus int
|
|
||||||
wantBody string
|
|
||||||
}{
|
|
||||||
{"good", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
|
||||||
{"good verify only", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusOK, ""},
|
|
||||||
{"bad naked domain uri", opts, nil, http.MethodGet, "", "/", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
|
||||||
{"bad naked domain uri verify only", opts, nil, http.MethodGet, "", "/verify", "a.naked.domain", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
|
||||||
{"bad empty verification uri", opts, nil, http.MethodGet, "", "/", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
|
||||||
{"bad empty verification uri verify only", opts, nil, http.MethodGet, "", "/verify", " ", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, "{\"error\":\"bad verification uri\"}\n"},
|
|
||||||
{"not authorized", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
|
||||||
{"not authorized verify endpoint", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"user@test.example is not authorized for some.domain.example\"}\n"},
|
|
||||||
{"not authorized expired, redirect to auth", opts, sessions.ErrExpired, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusFound, ""},
|
|
||||||
{"not authorized expired, don't redirect!", opts, sessions.ErrExpired, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
|
||||||
{"not authorized because of error", opts, nil, http.MethodGet, "", "/", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeError: errors.New("authz error")}, http.StatusInternalServerError, "{\"error\":\"authz error\"}\n"},
|
|
||||||
{"not authorized expired, do not redirect to auth", opts, nil, http.MethodGet, "", "/verify", "https://some.domain.example", &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(-10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: false}, http.StatusUnauthorized, "{\"error\":\"internal/sessions: validation failed, token is expired (exp)\"}\n"},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
p, err := New(tt.options)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
p.encoder = tt.cipher
|
|
||||||
p.sessionStore = tt.sessionStore
|
|
||||||
p.AuthorizeClient = tt.authorizer
|
|
||||||
p.UpdateOptions(tt.options)
|
|
||||||
uri := &url.URL{Path: tt.path}
|
|
||||||
queryString := uri.Query()
|
|
||||||
queryString.Set("donstrip", "ok")
|
|
||||||
if tt.qp != "" {
|
|
||||||
queryString.Set(tt.qp, "true")
|
|
||||||
}
|
|
||||||
if tt.verifyURI != "" {
|
|
||||||
queryString.Set("uri", tt.verifyURI)
|
|
||||||
}
|
|
||||||
|
|
||||||
uri.RawQuery = queryString.Encode()
|
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
|
||||||
state, err := tt.sessionStore.LoadSession(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
ctx := r.Context()
|
|
||||||
ctx = sessions.NewContext(ctx, state, tt.ctxError)
|
|
||||||
r = r.WithContext(ctx)
|
|
||||||
r.Header.Set("Authorization", "Bearer blah")
|
|
||||||
r.Header.Set("Accept", "application/json")
|
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
|
||||||
router := p.registerFwdAuthHandlers()
|
|
||||||
router.ServeHTTP(w, r)
|
|
||||||
if status := w.Code; status != tt.wantStatus {
|
|
||||||
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
|
||||||
t.Errorf("\n%+v", w.Body.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
if tt.wantBody != "" {
|
|
||||||
body := w.Body.String()
|
|
||||||
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
|
||||||
t.Errorf("wrong body\n%s", diff)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
func TestProxy_Callback(t *testing.T) {
|
func TestProxy_Callback(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
opts := testOptions(t)
|
opts := testOptions(t)
|
||||||
|
@ -311,7 +236,8 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
host string
|
host string
|
||||||
path string
|
path string
|
||||||
|
|
||||||
qp map[string]string
|
headers map[string]string
|
||||||
|
qp map[string]string
|
||||||
|
|
||||||
cipher encoding.MarshalUnmarshaler
|
cipher encoding.MarshalUnmarshaler
|
||||||
sessionStore sessions.SessionStore
|
sessionStore sessions.SessionStore
|
||||||
|
@ -319,11 +245,12 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_programmatic_destination_url": "ok", "pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
{"good", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"good programmatic", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "KBEjQ9rnCxaAX-GOqetGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusInternalServerError, ""},
|
{"bad decrypt", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "KBEjQ9rnCxaAX-GOqexGw9ivEQURqts3zZ2mNGy0wnVa3SbtM399KlBq2nZ-9wM21FfsZX52er4jlmC7kPEKM3P7uZ41zR0zeys1-_74a5tQp-vsf1WXZfRsgVOuBcWPkMiWEoc379JFHxGDudp5VhU8B-dcQt4f3_PtLTHARkuH54io1Va2gNMq4Hiy8sQ1MPGCQeltH_JMzzdDpXdmdusWrXUvCGkba24muvAV06D8XRVJj6Iu9eK94qFnqcHc7wzziEbb8ADBues9dwbtb6jl8vMWz5rN6XvXqA5YpZv_MQZlsrO4oXFFQDevdgB84cX1tVbVu6qZvK_yQBZqzpOjWA9uIaoSENMytoXuWAlFO_sXjswfX8JTNdGwzB7qQRNPqxVG_sM_tzY3QhPm8zqwEzsXG5DokxZfVt2I5WJRUEovFDb4BnK9KFnnkEzLEdMudixVnXeGmTtycgJvoTeTCQRPfDYkcgJ7oKf4tGea-W7z5UAVa2RduJM9ZoM6YtJX7jgDm__PvvqcE0knJUF87XHBzdcOjoDF-CUze9xDJgNBlvPbJqVshKrwoqSYpePSDH9GUCNKxGequW3Ma8GvlFfhwd0rK6IZG-XWkyk0XSWQIGkDSjAvhB1wsOusCCguDjbpVZpaW5MMyTkmx68pl6qlIKT5UCcrVPl4ix5ZEj91mUDF0O1t04haD7VZuLVFXVGmqtFrBKI76sdYN-zkokaa1_chPRTyqMQFlqu_8LD6-RiK3UccGM-dEmnX72i91NP9F9OK0WJr9Cheup1C_P0mjqAO4Cb8oIHm0Oxz_mRqv5QbTGJtb3xwPLPuVjVCiE4gGBcuU2ixpSVf5HUF7y1KicVMCKiX9ATCBtg8sTdQZQnPEtHcHHAvdsnDVwev1LGfqA-Gdvg="}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", map[string]string{"pomerium_jwt": "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad save session", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
{"bad base64", opts, http.MethodGet, "http", "example.com", "/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
{"malformed redirect", opts, http.MethodGet, "http", "example.com", "/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -345,13 +272,21 @@ func TestProxy_Callback(t *testing.T) {
|
||||||
uri := &url.URL{Path: "/"}
|
uri := &url.URL{Path: "/"}
|
||||||
if tt.qp != nil {
|
if tt.qp != nil {
|
||||||
qu := uri.Query()
|
qu := uri.Query()
|
||||||
qu.Set("redirect_uri", redirectURI.String())
|
for k, v := range tt.qp {
|
||||||
|
qu.Set(k, v)
|
||||||
|
}
|
||||||
|
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||||
uri.RawQuery = qu.Encode()
|
uri.RawQuery = qu.Encode()
|
||||||
}
|
}
|
||||||
|
|
||||||
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||||
|
|
||||||
r.Header.Set("Accept", "application/json")
|
r.Header.Set("Accept", "application/json")
|
||||||
|
if len(tt.headers) != 0 {
|
||||||
|
for k, v := range tt.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
w := httptest.NewRecorder()
|
w := httptest.NewRecorder()
|
||||||
p.Callback(w, r)
|
p.Callback(w, r)
|
||||||
|
@ -379,18 +314,20 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||||
|
|
||||||
method string
|
method string
|
||||||
|
|
||||||
scheme string
|
scheme string
|
||||||
host string
|
host string
|
||||||
path string
|
path string
|
||||||
qp map[string]string
|
headers map[string]string
|
||||||
|
qp map[string]string
|
||||||
|
|
||||||
wantStatus int
|
wantStatus int
|
||||||
wantBody string
|
wantBody string
|
||||||
}{
|
}{
|
||||||
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusOK, ""},
|
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
|
||||||
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
|
{"good body not checked", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusOK, ""},
|
||||||
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect_uri\"}\n"},
|
{"router miss, bad redirect_uri query", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{"bad_redirect_uri": "http://localhost"}, http.StatusNotFound, ""},
|
||||||
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", map[string]string{"redirect_uri": "http://localhost"}, http.StatusMethodNotAllowed, ""},
|
{"bad redirect_uri missing scheme", opts, http.MethodGet, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "localhost"}, http.StatusBadRequest, "{\"error\":\"malformed redirect uri\"}\n"},
|
||||||
|
{"bad http method", opts, http.MethodPost, "https", "corp.example.example", "/.pomerium/api/v1/login", nil, map[string]string{urlutil.QueryRedirectURI: "http://localhost"}, http.StatusMethodNotAllowed, ""},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
@ -428,3 +365,83 @@ func TestProxy_ProgrammaticLogin(t *testing.T) {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProxy_ProgrammaticCallback(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
opts := testOptions(t)
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
options config.Options
|
||||||
|
|
||||||
|
method string
|
||||||
|
|
||||||
|
redirectURI string
|
||||||
|
|
||||||
|
headers map[string]string
|
||||||
|
qp map[string]string
|
||||||
|
|
||||||
|
cipher encoding.MarshalUnmarshaler
|
||||||
|
sessionStore sessions.SessionStore
|
||||||
|
authorizer clients.Authorizer
|
||||||
|
wantStatus int
|
||||||
|
wantBody string
|
||||||
|
}{
|
||||||
|
{"good", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
|
{"good programmatic", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QueryIsProgrammatic: "true", urlutil.QueryCallbackURI: "ok", urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusFound, ""},
|
||||||
|
{"bad decrypt", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString + cryptutil.NewBase64Key()}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
{"bad save session", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: goodEncryptionString}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{SaveError: errors.New("hi")}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
{"bad base64", opts, http.MethodGet, "http://pomerium.io/", nil, map[string]string{urlutil.QuerySessionEncrypted: "^"}, &mock.Encoder{MarshalResponse: []byte("x")}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
{"malformed redirect", opts, http.MethodGet, "http://pomerium.io/", nil, nil, &mock.Encoder{}, &sessions.MockSessionStore{Session: &sessions.State{Email: "user@test.example", Expiry: jwt.NewNumericDate(time.Now().Add(10 * time.Minute))}}, clients.MockAuthorize{AuthorizeResponse: true}, http.StatusBadRequest, ""},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
p, err := New(tt.options)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
p.encoder = tt.cipher
|
||||||
|
p.sessionStore = tt.sessionStore
|
||||||
|
p.AuthorizeClient = tt.authorizer
|
||||||
|
p.UpdateOptions(tt.options)
|
||||||
|
redirectURI, _ := url.Parse(tt.redirectURI)
|
||||||
|
queryString := redirectURI.Query()
|
||||||
|
for k, v := range tt.qp {
|
||||||
|
queryString.Set(k, v)
|
||||||
|
}
|
||||||
|
redirectURI.RawQuery = queryString.Encode()
|
||||||
|
|
||||||
|
uri := &url.URL{Path: "/"}
|
||||||
|
if tt.qp != nil {
|
||||||
|
qu := uri.Query()
|
||||||
|
for k, v := range tt.qp {
|
||||||
|
qu.Set(k, v)
|
||||||
|
}
|
||||||
|
qu.Set(urlutil.QueryRedirectURI, redirectURI.String())
|
||||||
|
uri.RawQuery = qu.Encode()
|
||||||
|
}
|
||||||
|
|
||||||
|
r := httptest.NewRequest(tt.method, uri.String(), nil)
|
||||||
|
|
||||||
|
r.Header.Set("Accept", "application/json")
|
||||||
|
if len(tt.headers) != 0 {
|
||||||
|
for k, v := range tt.headers {
|
||||||
|
r.Header.Set(k, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
p.ProgrammaticCallback(w, r)
|
||||||
|
if status := w.Code; status != tt.wantStatus {
|
||||||
|
t.Errorf("status code: got %v want %v", status, tt.wantStatus)
|
||||||
|
t.Errorf("\n%+v", w.Body.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantBody != "" {
|
||||||
|
body := w.Body.String()
|
||||||
|
if diff := cmp.Diff(body, tt.wantBody); diff != "" {
|
||||||
|
t.Errorf("wrong body\n%s", diff)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -30,39 +30,37 @@ func (p *Proxy) AuthenticateSession(next http.Handler) http.Handler {
|
||||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
ctx, span := trace.StartSpan(r.Context(), "proxy.AuthenticateSession")
|
||||||
defer span.End()
|
defer span.End()
|
||||||
if err := p.authenticate(false, w, r.WithContext(ctx)); err != nil {
|
|
||||||
|
if s, err := sessions.FromContext(ctx); err != nil {
|
||||||
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
log.FromRequest(r).Debug().Err(err).Msg("proxy: authenticate session")
|
||||||
|
p.sessionStore.ClearSession(w, r)
|
||||||
|
if s != nil && s.Programmatic {
|
||||||
|
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
signinURL := *p.authenticateSigninURL
|
||||||
|
q := signinURL.Query()
|
||||||
|
q.Set(urlutil.QueryRedirectURI, urlutil.GetAbsoluteURL(r).String())
|
||||||
|
signinURL.RawQuery = q.Encode()
|
||||||
|
httputil.Redirect(w, r, urlutil.NewSignedURL(p.SharedKey, &signinURL).String(), http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
p.addPomeriumHeaders(w, r)
|
||||||
next.ServeHTTP(w, r.WithContext(ctx))
|
next.ServeHTTP(w, r.WithContext(ctx))
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// authenticate authenticates a user and sets an appropriate response type,
|
func (p *Proxy) addPomeriumHeaders(w http.ResponseWriter, r *http.Request) {
|
||||||
// redirect to authenticate or error handler depending on if err on failure is set.
|
|
||||||
func (p *Proxy) authenticate(errOnFailure bool, w http.ResponseWriter, r *http.Request) error {
|
|
||||||
s, err := sessions.FromContext(r.Context())
|
s, err := sessions.FromContext(r.Context())
|
||||||
if err != nil {
|
if err == nil && s != nil {
|
||||||
p.sessionStore.ClearSession(w, r)
|
r.Header.Set(HeaderUserID, s.Subject)
|
||||||
|
r.Header.Set(HeaderEmail, s.RequestEmail())
|
||||||
if errOnFailure || (s != nil && s.Programmatic) {
|
r.Header.Set(HeaderGroups, s.RequestGroups())
|
||||||
httputil.ErrorResponse(w, r, httputil.Error(err.Error(), http.StatusUnauthorized, err))
|
w.Header().Set(HeaderUserID, s.Subject)
|
||||||
return err
|
w.Header().Set(HeaderEmail, s.RequestEmail())
|
||||||
}
|
w.Header().Set(HeaderGroups, s.RequestGroups())
|
||||||
uri := urlutil.SignedRedirectURL(p.SharedKey, p.authenticateSigninURL, urlutil.GetAbsoluteURL(r))
|
|
||||||
httputil.Redirect(w, r, uri.String(), http.StatusFound)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
// add pomerium's headers to the downstream request
|
|
||||||
r.Header.Set(HeaderUserID, s.Subject)
|
|
||||||
r.Header.Set(HeaderEmail, s.RequestEmail())
|
|
||||||
r.Header.Set(HeaderGroups, s.RequestGroups())
|
|
||||||
// and upstream
|
|
||||||
w.Header().Set(HeaderUserID, s.Subject)
|
|
||||||
w.Header().Set(HeaderEmail, s.RequestEmail())
|
|
||||||
w.Header().Set(HeaderGroups, s.RequestGroups())
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
// AuthorizeSession is middleware to enforce a user is authorized for a request
|
||||||
|
|
|
@ -20,7 +20,6 @@ func TestProxy_AuthenticateSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
@ -70,7 +69,6 @@ func TestProxy_AuthorizeSession(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
@ -131,7 +129,6 @@ func TestProxy_SignRequest(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
fmt.Fprint(w, http.StatusText(http.StatusOK))
|
||||||
w.WriteHeader(http.StatusOK)
|
w.WriteHeader(http.StatusOK)
|
||||||
})
|
})
|
||||||
|
@ -184,7 +181,6 @@ func TestProxy_SetResponseHeaders(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
fn := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
|
||||||
var sb strings.Builder
|
var sb strings.Builder
|
||||||
for k, v := range r.Header {
|
for k, v := range r.Header {
|
||||||
k = strings.ToLower(k)
|
k = strings.ToLower(k)
|
||||||
|
|
|
@ -187,7 +187,7 @@ func (p *Proxy) UpdatePolicies(opts *config.Options) error {
|
||||||
|
|
||||||
if opts.ForwardAuthURL != nil {
|
if opts.ForwardAuthURL != nil {
|
||||||
// if a forward auth endpoint is set, register its handlers
|
// if a forward auth endpoint is set, register its handlers
|
||||||
h := r.Host(opts.ForwardAuthURL.Host).Subrouter()
|
h := r.Host(opts.ForwardAuthURL.Hostname()).Subrouter()
|
||||||
h.PathPrefix("/").Handler(p.registerFwdAuthHandlers())
|
h.PathPrefix("/").Handler(p.registerFwdAuthHandlers())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -82,7 +82,9 @@ def main():
|
||||||
# initial login to make sure we have our credential
|
# initial login to make sure we have our credential
|
||||||
if args.login:
|
if args.login:
|
||||||
dst = urllib.parse.urlparse(args.dst)
|
dst = urllib.parse.urlparse(args.dst)
|
||||||
query_params = {"redirect_uri": "http://{}:{}".format(args.server, args.port)}
|
query_params = {
|
||||||
|
"pomerium_redirect_uri": "http://{}:{}".format(args.server, args.port)
|
||||||
|
}
|
||||||
enc_query_params = urllib.parse.urlencode(query_params)
|
enc_query_params = urllib.parse.urlencode(query_params)
|
||||||
dst_login = "{}://{}{}?{}".format(
|
dst_login = "{}://{}{}?{}".format(
|
||||||
dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,
|
dst.scheme, dst.hostname, "/.pomerium/api/v1/login", enc_query_params,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue