diff --git a/authenticate/handlers.go b/authenticate/handlers.go index a370279e4..e3c3a838a 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -212,7 +212,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { return httputil.NewError(http.StatusBadRequest, err) } - redirectTo, err := handlers.BuildCallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile) + redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } diff --git a/authorize/check_response.go b/authorize/check_response.go index 89f898f23..4125921c1 100644 --- a/authorize/check_response.go +++ b/authorize/check_response.go @@ -18,7 +18,6 @@ import ( "google.golang.org/grpc/codes" "github.com/pomerium/pomerium/authorize/evaluator" - "github.com/pomerium/pomerium/internal/handlers" "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/telemetry/requestid" @@ -201,7 +200,7 @@ func (a *Authorize) requireLoginResponse( checkRequestURL := getCheckRequestURL(in) checkRequestURL.Scheme = "https" - redirectTo, err := handlers.BuildSignInURL( + redirectTo, err := urlutil.SignInURL( state.hpkePrivateKey, authenticateHPKEPublicKey, authenticateURL, diff --git a/internal/handlers/sign_in.go b/internal/handlers/sign_in.go deleted file mode 100644 index 077d63452..000000000 --- a/internal/handlers/sign_in.go +++ /dev/null @@ -1,89 +0,0 @@ -package handlers - -import ( - "fmt" - "net/url" - "time" - - "google.golang.org/protobuf/encoding/protojson" - - "github.com/pomerium/pomerium/internal/urlutil" - "github.com/pomerium/pomerium/pkg/grpc/identity" - "github.com/pomerium/pomerium/pkg/hpke" -) - -const signInExpiry = time.Minute * 5 - -// BuildCallbackURL builds the callback URL using an HPKE encrypted query string. -func BuildCallbackURL( - authenticatePrivateKey *hpke.PrivateKey, - proxyPublicKey *hpke.PublicKey, - requestParams url.Values, - profile *identity.Profile, -) (string, error) { - redirectURL, err := urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryRedirectURI)) - if err != nil { - return "", fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err) - } - - var callbackURL *url.URL - if requestParams.Has(urlutil.QueryCallbackURI) { - callbackURL, err = urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryCallbackURI)) - if err != nil { - return "", fmt.Errorf("invalid %s: %w", urlutil.QueryCallbackURI, err) - } - } else { - callbackURL, err = urlutil.DeepCopy(redirectURL) - if err != nil { - return "", fmt.Errorf("error copying %s: %w", urlutil.QueryRedirectURI, err) - } - callbackURL.Path = "/.pomerium/callback/" - callbackURL.RawQuery = "" - } - - callbackParams := callbackURL.Query() - if requestParams.Has(urlutil.QueryIsProgrammatic) { - callbackParams.Set(urlutil.QueryIsProgrammatic, "true") - } - callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String()) - - rawProfile, err := protojson.Marshal(profile) - if err != nil { - return "", fmt.Errorf("error marshaling identity profile: %w", err) - } - callbackParams.Set(urlutil.QueryIdentityProfile, string(rawProfile)) - - urlutil.BuildTimeParameters(callbackParams, signInExpiry) - - callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams) - if err != nil { - return "", fmt.Errorf("error encrypting callback params: %w", err) - } - callbackURL.RawQuery = callbackParams.Encode() - - return callbackURL.String(), nil -} - -// BuildSignInURL buidls the sign in URL using an HPKE encrypted query string. -func BuildSignInURL( - senderPrivateKey *hpke.PrivateKey, - authenticatePublicKey *hpke.PublicKey, - authenticateURL *url.URL, - redirectURL *url.URL, - idpID string, -) (string, error) { - signInURL := *authenticateURL - signInURL.Path = "/.pomerium/sign_in" - - q := signInURL.Query() - q.Set(urlutil.QueryRedirectURI, redirectURL.String()) - q.Set(urlutil.QueryIdentityProviderID, idpID) - urlutil.BuildTimeParameters(q, signInExpiry) - q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q) - if err != nil { - return "", err - } - signInURL.RawQuery = q.Encode() - - return signInURL.String(), nil -} diff --git a/internal/urlutil/known.go b/internal/urlutil/known.go index 6ef226cb9..79bfecaae 100644 --- a/internal/urlutil/known.go +++ b/internal/urlutil/known.go @@ -1,13 +1,74 @@ package urlutil import ( + "fmt" "net/http" "net/url" + "time" + + "google.golang.org/protobuf/encoding/protojson" + + "github.com/pomerium/pomerium/internal/version" + "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/hpke" ) // DefaultDeviceType is the default device type when none is specified. const DefaultDeviceType = "any" +const signInExpiry = time.Minute * 5 + +// CallbackURL builds the callback URL using an HPKE encrypted query string. +func CallbackURL( + authenticatePrivateKey *hpke.PrivateKey, + proxyPublicKey *hpke.PublicKey, + requestParams url.Values, + profile *identity.Profile, +) (string, error) { + redirectURL, err := ParseAndValidateURL(requestParams.Get(QueryRedirectURI)) + if err != nil { + return "", fmt.Errorf("invalid %s: %w", QueryRedirectURI, err) + } + + var callbackURL *url.URL + if requestParams.Has(QueryCallbackURI) { + callbackURL, err = ParseAndValidateURL(requestParams.Get(QueryCallbackURI)) + if err != nil { + return "", fmt.Errorf("invalid %s: %w", QueryCallbackURI, err) + } + } else { + callbackURL, err = DeepCopy(redirectURL) + if err != nil { + return "", fmt.Errorf("error copying %s: %w", QueryRedirectURI, err) + } + callbackURL.Path = "/.pomerium/callback/" + callbackURL.RawQuery = "" + } + + callbackParams := callbackURL.Query() + if requestParams.Has(QueryIsProgrammatic) { + callbackParams.Set(QueryIsProgrammatic, "true") + } + callbackParams.Set(QueryRedirectURI, redirectURL.String()) + + rawProfile, err := protojson.Marshal(profile) + if err != nil { + return "", fmt.Errorf("error marshaling identity profile: %w", err) + } + callbackParams.Set(QueryIdentityProfile, string(rawProfile)) + callbackParams.Set(QueryVersion, version.FullVersion()) + + BuildTimeParameters(callbackParams, signInExpiry) + + callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams) + if err != nil { + return "", fmt.Errorf("error encrypting callback params: %w", err) + } + callbackURL.RawQuery = callbackParams.Encode() + + return callbackURL.String(), nil +} + // RedirectURL returns the redirect URL from the query string or a cookie. func RedirectURL(r *http.Request) (string, bool) { if v := r.FormValue(QueryRedirectURI); v != "" { @@ -21,16 +82,42 @@ func RedirectURL(r *http.Request) (string, bool) { return "", false } +// SignInURL builds the sign in URL using an HPKE encrypted query string. +func SignInURL( + senderPrivateKey *hpke.PrivateKey, + authenticatePublicKey *hpke.PublicKey, + authenticateURL *url.URL, + redirectURL *url.URL, + idpID string, +) (string, error) { + signInURL := *authenticateURL + signInURL.Path = "/.pomerium/sign_in" + + q := signInURL.Query() + q.Set(QueryRedirectURI, redirectURL.String()) + q.Set(QueryIdentityProviderID, idpID) + q.Set(QueryVersion, version.FullVersion()) + BuildTimeParameters(q, signInExpiry) + q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q) + if err != nil { + return "", err + } + signInURL.RawQuery = q.Encode() + + return signInURL.String(), nil +} + // SignOutURL returns the /.pomerium/sign_out URL. func SignOutURL(r *http.Request, authenticateURL *url.URL, key []byte) string { u := authenticateURL.ResolveReference(&url.URL{ Path: "/.pomerium/sign_out", }) + q := u.Query() if redirectURI, ok := RedirectURL(r); ok { - u.RawQuery = (&url.Values{ - QueryRedirectURI: {redirectURI}, - }).Encode() + q.Set(QueryRedirectURI, redirectURI) } + q.Set(QueryVersion, version.FullVersion()) + u.RawQuery = q.Encode() return NewSignedURL(key, u).Sign().String() } diff --git a/internal/urlutil/known_test.go b/internal/urlutil/known_test.go index df25b9f40..0e5f9d574 100644 --- a/internal/urlutil/known_test.go +++ b/internal/urlutil/known_test.go @@ -2,14 +2,43 @@ package urlutil import ( "net/http" + "net/http/httptest" "net/url" "strings" "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + + "github.com/pomerium/pomerium/pkg/grpc/identity" + "github.com/pomerium/pomerium/pkg/hpke" ) +func TestCallbackURL(t *testing.T) { + t.Parallel() + k1 := hpke.DerivePrivateKey([]byte("sender")) + k2 := hpke.DerivePrivateKey([]byte("receiver")) + + rawSignInURL, err := CallbackURL(k1, k2.PublicKey(), url.Values{ + QueryRedirectURI: {"https://redirect.example.com"}, + }, &identity.Profile{ + ProviderId: "IDP-1", + }) + require.NoError(t, err) + + signInURL, err := ParseAndValidateURL(rawSignInURL) + require.NoError(t, err) + + k3, q, err := hpke.DecryptURLValues(k2, signInURL.Query()) + require.NoError(t, err) + assert.Equal(t, k1.PublicKey(), k3) + assert.NotEmpty(t, q.Get(QueryExpiry)) + assert.NotEmpty(t, q.Get(QueryIssued)) + assert.NotEmpty(t, q.Get(QueryVersion)) + assert.Equal(t, "https://redirect.example.com", q.Get(QueryRedirectURI)) + assert.JSONEq(t, `{ "providerId": "IDP-1" }`, q.Get(QueryIdentityProfile)) +} + func TestRedirectURI(t *testing.T) { t.Run("query", func(t *testing.T) { r, err := http.NewRequest("GET", "https://www.example.com?"+(url.Values{ @@ -45,3 +74,47 @@ func TestRedirectURI(t *testing.T) { assert.Equal(t, "https://www.example.com/redirect", redirectURI) }) } + +func TestSignInURL(t *testing.T) { + t.Parallel() + k1 := hpke.DerivePrivateKey([]byte("sender")) + k2 := hpke.DerivePrivateKey([]byte("receiver")) + + authenticateURL := MustParseAndValidateURL("https://authenticate.example.com") + redirectURL := MustParseAndValidateURL("https://redirect.example.com") + + rawSignInURL, err := SignInURL(k1, k2.PublicKey(), &authenticateURL, &redirectURL, "IDP-1") + require.NoError(t, err) + + signInURL, err := ParseAndValidateURL(rawSignInURL) + require.NoError(t, err) + + k3, q, err := hpke.DecryptURLValues(k2, signInURL.Query()) + require.NoError(t, err) + assert.Equal(t, k1.PublicKey(), k3) + assert.NotEmpty(t, q.Get(QueryExpiry)) + assert.NotEmpty(t, q.Get(QueryIssued)) + assert.NotEmpty(t, q.Get(QueryVersion)) + assert.Equal(t, "https://redirect.example.com", q.Get(QueryRedirectURI)) + assert.Equal(t, "IDP-1", q.Get(QueryIdentityProviderID)) +} + +func TestSignOutURL(t *testing.T) { + t.Parallel() + + r := httptest.NewRequest("GET", "https://route.example.com?"+(url.Values{ + QueryRedirectURI: {"https://www.example.com/redirect"}, + }).Encode(), nil) + authenticateURL := MustParseAndValidateURL("https://authenticate.example.com") + + rawSignOutURL := SignOutURL(r, &authenticateURL, []byte("TEST")) + signOutURL, err := ParseAndValidateURL(rawSignOutURL) + require.NoError(t, err) + + q := signOutURL.Query() + assert.NotEmpty(t, q.Get(QueryExpiry)) + assert.NotEmpty(t, q.Get(QueryIssued)) + assert.NotEmpty(t, q.Get(QueryVersion)) + assert.NotEmpty(t, q.Get(QueryHmacSignature)) + assert.Equal(t, "https://www.example.com/redirect", q.Get(QueryRedirectURI)) +} diff --git a/internal/urlutil/query_params.go b/internal/urlutil/query_params.go index 9992dee80..15a0c850a 100644 --- a/internal/urlutil/query_params.go +++ b/internal/urlutil/query_params.go @@ -18,6 +18,7 @@ const ( QuerySession = "pomerium_session" QuerySessionEncrypted = "pomerium_session_encrypted" QuerySessionState = "pomerium_session_state" + QueryVersion = "pomerium_version" ) // URL signature based query params used for verifying the authenticity of a URL. diff --git a/proxy/handlers.go b/proxy/handlers.go index c05f3db98..132114b06 100644 --- a/proxy/handlers.go +++ b/proxy/handlers.go @@ -228,7 +228,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error q.Set(urlutil.QueryIsProgrammatic, "true") signinURL.RawQuery = q.Encode() - rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId()) + rawURL, err := urlutil.SignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId()) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) }