From 2f328e7de022c88ae4220a7d553884e83629e442 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 27 Jan 2022 16:10:47 -0700 Subject: [PATCH] authenticate: fix expiring user info endpoint (#2976) * authenticate: fix expiring user info endpoint * add test --- authenticate/authenticate.go | 20 ------------- authenticate/handlers.go | 30 ++++++++----------- authenticate/url.go | 57 ++++++++++++++++++++++++++++++++++++ authenticate/url_test.go | 52 ++++++++++++++++++++++++++++++++ 4 files changed, 122 insertions(+), 37 deletions(-) create mode 100644 authenticate/url.go create mode 100644 authenticate/url_test.go diff --git a/authenticate/authenticate.go b/authenticate/authenticate.go index 5183afca4..640d3094f 100644 --- a/authenticate/authenticate.go +++ b/authenticate/authenticate.go @@ -16,7 +16,6 @@ import ( "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/urlutil" "github.com/pomerium/pomerium/pkg/cryptutil" - "github.com/pomerium/pomerium/pkg/webauthnutil" ) // ValidateOptions checks that configuration are complete and valid. @@ -125,25 +124,6 @@ func (a *Authenticate) updateProvider(cfg *config.Config) error { return nil } -func (a *Authenticate) getWebAuthnURL(values url.Values) (*url.URL, error) { - uri, err := a.options.Load().GetAuthenticateURL() - if err != nil { - return nil, err - } - - uri = uri.ResolveReference(&url.URL{ - Path: "/.pomerium/webauthn", - RawQuery: buildURLValues(values, url.Values{ - urlutil.QueryDeviceType: {webauthnutil.DefaultDeviceType}, - urlutil.QueryEnrollmentToken: nil, - urlutil.QueryRedirectURI: {uri.ResolveReference(&url.URL{ - Path: "/.pomerium/device-enrolled", - }).String()}, - }).Encode(), - }) - return urlutil.NewSignedURL(a.state.Load().sharedKey, uri).Sign(), nil -} - // buildURLValues creates a new url.Values map by traversing the keys in `defaults` and using the values // from `values` if they exist, otherwise the provided defaults func buildURLValues(values, defaults url.Values) url.Values { diff --git a/authenticate/handlers.go b/authenticate/handlers.go index ce116b470..25a51591c 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -448,6 +448,19 @@ func (a *Authenticate) userInfo(w http.ResponseWriter, r *http.Request) error { ctx, span := trace.StartSpan(r.Context(), "authenticate.userInfo") defer span.End() + // if we came in with a redirect URI, save it to a cookie so it doesn't expire with the HMAC + if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" { + u := urlutil.GetAbsoluteURL(r) + u.RawQuery = "" + + http.SetCookie(w, &http.Cookie{ + Name: urlutil.QueryRedirectURI, + Value: redirectURI, + }) + http.Redirect(w, r, u.String(), http.StatusFound) + return nil + } + state := a.state.Load() s, err := a.getSessionFromCtx(ctx) @@ -626,23 +639,6 @@ func (a *Authenticate) revokeSession(ctx context.Context, w http.ResponseWriter, return rawIDToken } -func (a *Authenticate) getSignOutURL(r *http.Request) (*url.URL, error) { - uri, err := a.options.Load().GetAuthenticateURL() - if err != nil { - return nil, err - } - - uri = uri.ResolveReference(&url.URL{ - Path: "/.pomerium/sign_out", - }) - if redirectURI := r.FormValue(urlutil.QueryRedirectURI); redirectURI != "" { - uri.RawQuery = (&url.Values{ - urlutil.QueryRedirectURI: {redirectURI}, - }).Encode() - } - return urlutil.NewSignedURL(a.state.Load().sharedKey, uri).Sign(), nil -} - func (a *Authenticate) getCurrentSession(ctx context.Context) (s *session.Session, isImpersonated bool, err error) { client := a.state.Load().dataBrokerClient diff --git a/authenticate/url.go b/authenticate/url.go new file mode 100644 index 000000000..6f6b7b9b9 --- /dev/null +++ b/authenticate/url.go @@ -0,0 +1,57 @@ +package authenticate + +import ( + "net/http" + "net/url" + + "github.com/pomerium/pomerium/internal/urlutil" + "github.com/pomerium/pomerium/pkg/webauthnutil" +) + +func (a *Authenticate) getRedirectURI(r *http.Request) (string, bool) { + if v := r.FormValue(urlutil.QueryRedirectURI); v != "" { + return v, true + } + + if c, err := r.Cookie(urlutil.QueryRedirectURI); err == nil { + return c.Value, true + } + + return "", false +} + +func (a *Authenticate) getSignOutURL(r *http.Request) (*url.URL, error) { + uri, err := a.options.Load().GetAuthenticateURL() + if err != nil { + return nil, err + } + + uri = uri.ResolveReference(&url.URL{ + Path: "/.pomerium/sign_out", + }) + if redirectURI, ok := a.getRedirectURI(r); ok { + uri.RawQuery = (&url.Values{ + urlutil.QueryRedirectURI: {redirectURI}, + }).Encode() + } + return urlutil.NewSignedURL(a.state.Load().sharedKey, uri).Sign(), nil +} + +func (a *Authenticate) getWebAuthnURL(values url.Values) (*url.URL, error) { + uri, err := a.options.Load().GetAuthenticateURL() + if err != nil { + return nil, err + } + + uri = uri.ResolveReference(&url.URL{ + Path: "/.pomerium/webauthn", + RawQuery: buildURLValues(values, url.Values{ + urlutil.QueryDeviceType: {webauthnutil.DefaultDeviceType}, + urlutil.QueryEnrollmentToken: nil, + urlutil.QueryRedirectURI: {uri.ResolveReference(&url.URL{ + Path: "/.pomerium/device-enrolled", + }).String()}, + }).Encode(), + }) + return urlutil.NewSignedURL(a.state.Load().sharedKey, uri).Sign(), nil +} diff --git a/authenticate/url_test.go b/authenticate/url_test.go new file mode 100644 index 000000000..ff2e5dd06 --- /dev/null +++ b/authenticate/url_test.go @@ -0,0 +1,52 @@ +package authenticate + +import ( + "net/http" + "net/url" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/internal/urlutil" +) + +func TestAuthenticate_getRedirectURI(t *testing.T) { + t.Run("query", func(t *testing.T) { + r, err := http.NewRequest("GET", "https://www.example.com?"+(url.Values{ + urlutil.QueryRedirectURI: {"https://www.example.com/redirect"}, + }).Encode(), nil) + require.NoError(t, err) + + a := new(Authenticate) + redirectURI, ok := a.getRedirectURI(r) + assert.True(t, ok) + assert.Equal(t, "https://www.example.com/redirect", redirectURI) + }) + t.Run("form", func(t *testing.T) { + r, err := http.NewRequest("POST", "https://www.example.com", strings.NewReader((url.Values{ + urlutil.QueryRedirectURI: {"https://www.example.com/redirect"}, + }).Encode())) + require.NoError(t, err) + r.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + a := new(Authenticate) + redirectURI, ok := a.getRedirectURI(r) + assert.True(t, ok) + assert.Equal(t, "https://www.example.com/redirect", redirectURI) + }) + t.Run("cookie", func(t *testing.T) { + r, err := http.NewRequest("GET", "https://www.example.com", nil) + require.NoError(t, err) + r.AddCookie(&http.Cookie{ + Name: urlutil.QueryRedirectURI, + Value: "https://www.example.com/redirect", + }) + + a := new(Authenticate) + redirectURI, ok := a.getRedirectURI(r) + assert.True(t, ok) + assert.Equal(t, "https://www.example.com/redirect", redirectURI) + }) +}