urlutil: add version to query string

This commit is contained in:
Caleb Doxsey 2023-02-28 11:07:29 -07:00
parent 78a221cdbf
commit 076a5d2123
7 changed files with 167 additions and 96 deletions

View file

@ -212,7 +212,7 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error {
return httputil.NewError(http.StatusBadRequest, err) return httputil.NewError(http.StatusBadRequest, err)
} }
redirectTo, err := handlers.BuildCallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile) redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile)
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, err) return httputil.NewError(http.StatusInternalServerError, err)
} }

View file

@ -18,7 +18,6 @@ import (
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"github.com/pomerium/pomerium/authorize/evaluator" "github.com/pomerium/pomerium/authorize/evaluator"
"github.com/pomerium/pomerium/internal/handlers"
"github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/log" "github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/requestid" "github.com/pomerium/pomerium/internal/telemetry/requestid"
@ -201,7 +200,7 @@ func (a *Authorize) requireLoginResponse(
checkRequestURL := getCheckRequestURL(in) checkRequestURL := getCheckRequestURL(in)
checkRequestURL.Scheme = "https" checkRequestURL.Scheme = "https"
redirectTo, err := handlers.BuildSignInURL( redirectTo, err := urlutil.SignInURL(
state.hpkePrivateKey, state.hpkePrivateKey,
authenticateHPKEPublicKey, authenticateHPKEPublicKey,
authenticateURL, authenticateURL,

View file

@ -1,89 +0,0 @@
package handlers
import (
"fmt"
"net/url"
"time"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/hpke"
)
const signInExpiry = time.Minute * 5
// BuildCallbackURL builds the callback URL using an HPKE encrypted query string.
func BuildCallbackURL(
authenticatePrivateKey *hpke.PrivateKey,
proxyPublicKey *hpke.PublicKey,
requestParams url.Values,
profile *identity.Profile,
) (string, error) {
redirectURL, err := urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryRedirectURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", urlutil.QueryRedirectURI, err)
}
var callbackURL *url.URL
if requestParams.Has(urlutil.QueryCallbackURI) {
callbackURL, err = urlutil.ParseAndValidateURL(requestParams.Get(urlutil.QueryCallbackURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", urlutil.QueryCallbackURI, err)
}
} else {
callbackURL, err = urlutil.DeepCopy(redirectURL)
if err != nil {
return "", fmt.Errorf("error copying %s: %w", urlutil.QueryRedirectURI, err)
}
callbackURL.Path = "/.pomerium/callback/"
callbackURL.RawQuery = ""
}
callbackParams := callbackURL.Query()
if requestParams.Has(urlutil.QueryIsProgrammatic) {
callbackParams.Set(urlutil.QueryIsProgrammatic, "true")
}
callbackParams.Set(urlutil.QueryRedirectURI, redirectURL.String())
rawProfile, err := protojson.Marshal(profile)
if err != nil {
return "", fmt.Errorf("error marshaling identity profile: %w", err)
}
callbackParams.Set(urlutil.QueryIdentityProfile, string(rawProfile))
urlutil.BuildTimeParameters(callbackParams, signInExpiry)
callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams)
if err != nil {
return "", fmt.Errorf("error encrypting callback params: %w", err)
}
callbackURL.RawQuery = callbackParams.Encode()
return callbackURL.String(), nil
}
// BuildSignInURL buidls the sign in URL using an HPKE encrypted query string.
func BuildSignInURL(
senderPrivateKey *hpke.PrivateKey,
authenticatePublicKey *hpke.PublicKey,
authenticateURL *url.URL,
redirectURL *url.URL,
idpID string,
) (string, error) {
signInURL := *authenticateURL
signInURL.Path = "/.pomerium/sign_in"
q := signInURL.Query()
q.Set(urlutil.QueryRedirectURI, redirectURL.String())
q.Set(urlutil.QueryIdentityProviderID, idpID)
urlutil.BuildTimeParameters(q, signInExpiry)
q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q)
if err != nil {
return "", err
}
signInURL.RawQuery = q.Encode()
return signInURL.String(), nil
}

View file

@ -1,13 +1,74 @@
package urlutil package urlutil
import ( import (
"fmt"
"net/http" "net/http"
"net/url" "net/url"
"time"
"google.golang.org/protobuf/encoding/protojson"
"github.com/pomerium/pomerium/internal/version"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/hpke"
) )
// DefaultDeviceType is the default device type when none is specified. // DefaultDeviceType is the default device type when none is specified.
const DefaultDeviceType = "any" const DefaultDeviceType = "any"
const signInExpiry = time.Minute * 5
// CallbackURL builds the callback URL using an HPKE encrypted query string.
func CallbackURL(
authenticatePrivateKey *hpke.PrivateKey,
proxyPublicKey *hpke.PublicKey,
requestParams url.Values,
profile *identity.Profile,
) (string, error) {
redirectURL, err := ParseAndValidateURL(requestParams.Get(QueryRedirectURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", QueryRedirectURI, err)
}
var callbackURL *url.URL
if requestParams.Has(QueryCallbackURI) {
callbackURL, err = ParseAndValidateURL(requestParams.Get(QueryCallbackURI))
if err != nil {
return "", fmt.Errorf("invalid %s: %w", QueryCallbackURI, err)
}
} else {
callbackURL, err = DeepCopy(redirectURL)
if err != nil {
return "", fmt.Errorf("error copying %s: %w", QueryRedirectURI, err)
}
callbackURL.Path = "/.pomerium/callback/"
callbackURL.RawQuery = ""
}
callbackParams := callbackURL.Query()
if requestParams.Has(QueryIsProgrammatic) {
callbackParams.Set(QueryIsProgrammatic, "true")
}
callbackParams.Set(QueryRedirectURI, redirectURL.String())
rawProfile, err := protojson.Marshal(profile)
if err != nil {
return "", fmt.Errorf("error marshaling identity profile: %w", err)
}
callbackParams.Set(QueryIdentityProfile, string(rawProfile))
callbackParams.Set(QueryVersion, version.FullVersion())
BuildTimeParameters(callbackParams, signInExpiry)
callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams)
if err != nil {
return "", fmt.Errorf("error encrypting callback params: %w", err)
}
callbackURL.RawQuery = callbackParams.Encode()
return callbackURL.String(), nil
}
// RedirectURL returns the redirect URL from the query string or a cookie. // RedirectURL returns the redirect URL from the query string or a cookie.
func RedirectURL(r *http.Request) (string, bool) { func RedirectURL(r *http.Request) (string, bool) {
if v := r.FormValue(QueryRedirectURI); v != "" { if v := r.FormValue(QueryRedirectURI); v != "" {
@ -21,16 +82,42 @@ func RedirectURL(r *http.Request) (string, bool) {
return "", false return "", false
} }
// SignInURL builds the sign in URL using an HPKE encrypted query string.
func SignInURL(
senderPrivateKey *hpke.PrivateKey,
authenticatePublicKey *hpke.PublicKey,
authenticateURL *url.URL,
redirectURL *url.URL,
idpID string,
) (string, error) {
signInURL := *authenticateURL
signInURL.Path = "/.pomerium/sign_in"
q := signInURL.Query()
q.Set(QueryRedirectURI, redirectURL.String())
q.Set(QueryIdentityProviderID, idpID)
q.Set(QueryVersion, version.FullVersion())
BuildTimeParameters(q, signInExpiry)
q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q)
if err != nil {
return "", err
}
signInURL.RawQuery = q.Encode()
return signInURL.String(), nil
}
// SignOutURL returns the /.pomerium/sign_out URL. // SignOutURL returns the /.pomerium/sign_out URL.
func SignOutURL(r *http.Request, authenticateURL *url.URL, key []byte) string { func SignOutURL(r *http.Request, authenticateURL *url.URL, key []byte) string {
u := authenticateURL.ResolveReference(&url.URL{ u := authenticateURL.ResolveReference(&url.URL{
Path: "/.pomerium/sign_out", Path: "/.pomerium/sign_out",
}) })
q := u.Query()
if redirectURI, ok := RedirectURL(r); ok { if redirectURI, ok := RedirectURL(r); ok {
u.RawQuery = (&url.Values{ q.Set(QueryRedirectURI, redirectURI)
QueryRedirectURI: {redirectURI},
}).Encode()
} }
q.Set(QueryVersion, version.FullVersion())
u.RawQuery = q.Encode()
return NewSignedURL(key, u).Sign().String() return NewSignedURL(key, u).Sign().String()
} }

View file

@ -2,14 +2,43 @@ package urlutil
import ( import (
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/grpc/identity"
"github.com/pomerium/pomerium/pkg/hpke"
) )
func TestCallbackURL(t *testing.T) {
t.Parallel()
k1 := hpke.DerivePrivateKey([]byte("sender"))
k2 := hpke.DerivePrivateKey([]byte("receiver"))
rawSignInURL, err := CallbackURL(k1, k2.PublicKey(), url.Values{
QueryRedirectURI: {"https://redirect.example.com"},
}, &identity.Profile{
ProviderId: "IDP-1",
})
require.NoError(t, err)
signInURL, err := ParseAndValidateURL(rawSignInURL)
require.NoError(t, err)
k3, q, err := hpke.DecryptURLValues(k2, signInURL.Query())
require.NoError(t, err)
assert.Equal(t, k1.PublicKey(), k3)
assert.NotEmpty(t, q.Get(QueryExpiry))
assert.NotEmpty(t, q.Get(QueryIssued))
assert.NotEmpty(t, q.Get(QueryVersion))
assert.Equal(t, "https://redirect.example.com", q.Get(QueryRedirectURI))
assert.JSONEq(t, `{ "providerId": "IDP-1" }`, q.Get(QueryIdentityProfile))
}
func TestRedirectURI(t *testing.T) { func TestRedirectURI(t *testing.T) {
t.Run("query", func(t *testing.T) { t.Run("query", func(t *testing.T) {
r, err := http.NewRequest("GET", "https://www.example.com?"+(url.Values{ r, err := http.NewRequest("GET", "https://www.example.com?"+(url.Values{
@ -45,3 +74,47 @@ func TestRedirectURI(t *testing.T) {
assert.Equal(t, "https://www.example.com/redirect", redirectURI) assert.Equal(t, "https://www.example.com/redirect", redirectURI)
}) })
} }
func TestSignInURL(t *testing.T) {
t.Parallel()
k1 := hpke.DerivePrivateKey([]byte("sender"))
k2 := hpke.DerivePrivateKey([]byte("receiver"))
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
redirectURL := MustParseAndValidateURL("https://redirect.example.com")
rawSignInURL, err := SignInURL(k1, k2.PublicKey(), &authenticateURL, &redirectURL, "IDP-1")
require.NoError(t, err)
signInURL, err := ParseAndValidateURL(rawSignInURL)
require.NoError(t, err)
k3, q, err := hpke.DecryptURLValues(k2, signInURL.Query())
require.NoError(t, err)
assert.Equal(t, k1.PublicKey(), k3)
assert.NotEmpty(t, q.Get(QueryExpiry))
assert.NotEmpty(t, q.Get(QueryIssued))
assert.NotEmpty(t, q.Get(QueryVersion))
assert.Equal(t, "https://redirect.example.com", q.Get(QueryRedirectURI))
assert.Equal(t, "IDP-1", q.Get(QueryIdentityProviderID))
}
func TestSignOutURL(t *testing.T) {
t.Parallel()
r := httptest.NewRequest("GET", "https://route.example.com?"+(url.Values{
QueryRedirectURI: {"https://www.example.com/redirect"},
}).Encode(), nil)
authenticateURL := MustParseAndValidateURL("https://authenticate.example.com")
rawSignOutURL := SignOutURL(r, &authenticateURL, []byte("TEST"))
signOutURL, err := ParseAndValidateURL(rawSignOutURL)
require.NoError(t, err)
q := signOutURL.Query()
assert.NotEmpty(t, q.Get(QueryExpiry))
assert.NotEmpty(t, q.Get(QueryIssued))
assert.NotEmpty(t, q.Get(QueryVersion))
assert.NotEmpty(t, q.Get(QueryHmacSignature))
assert.Equal(t, "https://www.example.com/redirect", q.Get(QueryRedirectURI))
}

View file

@ -18,6 +18,7 @@ const (
QuerySession = "pomerium_session" QuerySession = "pomerium_session"
QuerySessionEncrypted = "pomerium_session_encrypted" QuerySessionEncrypted = "pomerium_session_encrypted"
QuerySessionState = "pomerium_session_state" QuerySessionState = "pomerium_session_state"
QueryVersion = "pomerium_version"
) )
// URL signature based query params used for verifying the authenticity of a URL. // URL signature based query params used for verifying the authenticity of a URL.

View file

@ -228,7 +228,7 @@ func (p *Proxy) ProgrammaticLogin(w http.ResponseWriter, r *http.Request) error
q.Set(urlutil.QueryIsProgrammatic, "true") q.Set(urlutil.QueryIsProgrammatic, "true")
signinURL.RawQuery = q.Encode() signinURL.RawQuery = q.Encode()
rawURL, err := handlers.BuildSignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId()) rawURL, err := urlutil.SignInURL(state.hpkePrivateKey, hpkeAuthenticateKey, &signinURL, redirectURI, idp.GetId())
if err != nil { if err != nil {
return httputil.NewError(http.StatusInternalServerError, err) return httputil.NewError(http.StatusInternalServerError, err)
} }