diff --git a/authenticate/handlers.go b/authenticate/handlers.go index 9b0aba49c..3929f5b87 100644 --- a/authenticate/handlers.go +++ b/authenticate/handlers.go @@ -218,7 +218,12 @@ func (a *Authenticate) SignIn(w http.ResponseWriter, r *http.Request) error { a.logAuthenticateEvent(r, profile) - redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile) + encryptURLValues := hpke.EncryptURLValuesV1 + if hpke.IsEncryptedURLV2(r.Form) { + encryptURLValues = hpke.EncryptURLValuesV2 + } + + redirectTo, err := urlutil.CallbackURL(state.hpkePrivateKey, proxyPublicKey, requestParams, profile, encryptURLValues) if err != nil { return httputil.NewError(http.StatusInternalServerError, err) } diff --git a/go.mod b/go.mod index 37fc16143..15f6601de 100644 --- a/go.mod +++ b/go.mod @@ -156,7 +156,7 @@ require ( github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect github.com/jackc/pgx/v5 v5.3.1 - github.com/klauspost/compress v1.16.0 // indirect + github.com/klauspost/compress v1.16.0 github.com/klauspost/cpuid/v2 v2.2.4 // indirect github.com/libdns/libdns v0.2.1 // indirect github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect diff --git a/internal/urlutil/known.go b/internal/urlutil/known.go index 20d1cd03d..dd39c11ba 100644 --- a/internal/urlutil/known.go +++ b/internal/urlutil/known.go @@ -40,6 +40,7 @@ func CallbackURL( proxyPublicKey *hpke.PublicKey, requestParams url.Values, profile *identity.Profile, + encryptURLValues hpke.EncryptURLValuesFunc, ) (string, error) { redirectURL, err := ParseAndValidateURL(requestParams.Get(QueryRedirectURI)) if err != nil { @@ -76,7 +77,7 @@ func CallbackURL( BuildTimeParameters(callbackParams, signInExpiry) - callbackParams, err = hpke.EncryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams) + callbackParams, err = encryptURLValues(authenticatePrivateKey, proxyPublicKey, callbackParams) if err != nil { return "", fmt.Errorf("error encrypting callback params: %w", err) } @@ -115,7 +116,7 @@ func SignInURL( q.Set(QueryVersion, versionStr()) q.Set(QueryRequestUUID, uuid.NewString()) BuildTimeParameters(q, signInExpiry) - q, err := hpke.EncryptURLValues(senderPrivateKey, authenticatePublicKey, q) + q, err := hpke.EncryptURLValuesV2(senderPrivateKey, authenticatePublicKey, q) if err != nil { return "", err } diff --git a/internal/urlutil/known_test.go b/internal/urlutil/known_test.go index 5683e4738..d128daffd 100644 --- a/internal/urlutil/known_test.go +++ b/internal/urlutil/known_test.go @@ -23,7 +23,7 @@ func TestCallbackURL(t *testing.T) { QueryRedirectURI: {"https://redirect.example.com"}, }, &identity.Profile{ ProviderId: "IDP-1", - }) + }, hpke.EncryptURLValuesV1) require.NoError(t, err) signInURL, err := ParseAndValidateURL(rawSignInURL) diff --git a/pkg/hpke/url.go b/pkg/hpke/url.go index 49cbbbc19..12bbc2721 100644 --- a/pkg/hpke/url.go +++ b/pkg/hpke/url.go @@ -3,35 +3,76 @@ package hpke import ( "fmt" "net/url" + + "github.com/klauspost/compress/zstd" ) // URL Parameters const ( - ParamSenderPublicKey = "pomerium_hpke_sender_pub" - ParamQuery = "pomerium_hpke_query" + paramSenderPublicKey = "pomerium_hpke_sender_pub" + paramQuery = "pomerium_hpke_query" + + paramSenderPublicKeyV2 = "k" + paramQueryV2 = "q" ) // IsEncryptedURL returns true if the url.Values contain an HPKE encrypted query. func IsEncryptedURL(values url.Values) bool { - return values.Has(ParamSenderPublicKey) && values.Has(ParamQuery) + return IsEncryptedURLV1(values) || IsEncryptedURLV2(values) } -// EncryptURLValues encrypts URL values using the Seal method. -func EncryptURLValues( +// IsEncryptedURLV1 returns true if the url.Values contain a V1 HPKE encrypted query. +func IsEncryptedURLV1(values url.Values) bool { + return values.Has(paramSenderPublicKey) && values.Has(paramQuery) +} + +// IsEncryptedURLV2 returns true if the url.Values contains a V2 HPKE encrypted query. +func IsEncryptedURLV2(values url.Values) bool { + return values.Has(paramSenderPublicKeyV2) && values.Has(paramQueryV2) +} + +// An EncryptURLValuesFunc is a function that encrypts url values. +type EncryptURLValuesFunc func(senderPrivateKey *PrivateKey, receiverPublicKey *PublicKey, values url.Values) (encrypted url.Values, err error) + +// EncryptURLValuesV1 encrypts URL values using the Seal method. +func EncryptURLValuesV1( senderPrivateKey *PrivateKey, receiverPublicKey *PublicKey, values url.Values, ) (encrypted url.Values, err error) { values = withoutHPKEParams(values) - sealed, err := Seal(senderPrivateKey, receiverPublicKey, []byte(values.Encode())) + encoded := encodeQueryStringV1(values) + + sealed, err := Seal(senderPrivateKey, receiverPublicKey, encoded) if err != nil { return nil, fmt.Errorf("hpke: failed to seal URL values %w", err) } return url.Values{ - ParamSenderPublicKey: {senderPrivateKey.PublicKey().String()}, - ParamQuery: {encode(sealed)}, + paramSenderPublicKey: {senderPrivateKey.PublicKey().String()}, + paramQuery: {encode(sealed)}, + }, nil +} + +// EncryptURLValuesV2 encrypts URL values using the Seal method and compresses the query string. +func EncryptURLValuesV2( + senderPrivateKey *PrivateKey, + receiverPublicKey *PublicKey, + values url.Values, +) (encrypted url.Values, err error) { + values = withoutHPKEParams(values) + + encoded := encodeQueryStringV2(values) + + sealed, err := Seal(senderPrivateKey, receiverPublicKey, encoded) + if err != nil { + return nil, fmt.Errorf("hpke: failed to seal URL values %w", err) + } + + return url.Values{ + paramSenderPublicKeyV2: {senderPrivateKey.PublicKey().String()}, + paramQueryV2: {encode(sealed)}, }, nil } @@ -40,31 +81,50 @@ func DecryptURLValues( receiverPrivateKey *PrivateKey, encrypted url.Values, ) (senderPublicKey *PublicKey, values url.Values, err error) { - if !encrypted.Has(ParamSenderPublicKey) { - return nil, nil, fmt.Errorf("hpke: missing sender public key in query parameters") - } - if !encrypted.Has(ParamQuery) { - return nil, nil, fmt.Errorf("hpke: missing encrypted query in query parameters") - } + var decrypted url.Values + switch { + case IsEncryptedURLV1(encrypted): + senderPublicKey, err = PublicKeyFromString(encrypted.Get(paramSenderPublicKey)) + if err != nil { + return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) + } - senderPublicKey, err = PublicKeyFromString(encrypted.Get(ParamSenderPublicKey)) - if err != nil { - return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) - } + sealed, err := decode(encrypted.Get(paramQuery)) + if err != nil { + return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) + } - sealed, err := decode(encrypted.Get(ParamQuery)) - if err != nil { - return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) - } + message, err := Open(receiverPrivateKey, senderPublicKey, sealed) + if err != nil { + return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) + } - message, err := Open(receiverPrivateKey, senderPublicKey, sealed) - if err != nil { - return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) - } + decrypted, err = decodeQueryStringV1(message) + if err != nil { + return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) + } + case IsEncryptedURLV2(encrypted): + senderPublicKey, err = PublicKeyFromString(encrypted.Get(paramSenderPublicKeyV2)) + if err != nil { + return nil, nil, fmt.Errorf("hpke: invalid sender public key parameter: %w", err) + } - decrypted, err := url.ParseQuery(string(message)) - if err != nil { - return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) + sealed, err := decode(encrypted.Get(paramQueryV2)) + if err != nil { + return nil, nil, fmt.Errorf("hpke: failed decoding query parameter: %w", err) + } + + message, err := Open(receiverPrivateKey, senderPublicKey, sealed) + if err != nil { + return nil, nil, fmt.Errorf("hpke: failed to open sealed message: %w", err) + } + + decrypted, err = decodeQueryStringV2(message) + if err != nil { + return nil, nil, fmt.Errorf("hpke: invalid query parameter: %w", err) + } + default: + return nil, nil, fmt.Errorf("hpke: missing query parameters") } values = withoutHPKEParams(encrypted) @@ -78,9 +138,33 @@ func DecryptURLValues( func withoutHPKEParams(values url.Values) url.Values { filtered := make(url.Values) for k, vs := range values { - if k != ParamSenderPublicKey && k != ParamQuery { + if k != paramSenderPublicKey && k != paramQuery && k != paramSenderPublicKeyV2 && k != paramQueryV2 { filtered[k] = vs } } return filtered } + +var zstdEncoder, _ = zstd.NewWriter(nil, zstd.WithEncoderLevel(zstd.SpeedBestCompression)) + +func encodeQueryStringV1(values url.Values) []byte { + return []byte(values.Encode()) +} + +func encodeQueryStringV2(values url.Values) []byte { + return zstdEncoder.EncodeAll([]byte(values.Encode()), nil) +} + +var zstdDecoder, _ = zstd.NewReader(nil) + +func decodeQueryStringV1(raw []byte) (url.Values, error) { + return url.ParseQuery(string(raw)) +} + +func decodeQueryStringV2(raw []byte) (url.Values, error) { + bs, err := zstdDecoder.DecodeAll(raw, nil) + if err != nil { + return nil, err + } + return url.ParseQuery(string(bs)) +} diff --git a/pkg/hpke/url_test.go b/pkg/hpke/url_test.go index 461762ddb..32f673b26 100644 --- a/pkg/hpke/url_test.go +++ b/pkg/hpke/url_test.go @@ -2,6 +2,7 @@ package hpke import ( "net/url" + "strings" "testing" "github.com/stretchr/testify/assert" @@ -9,29 +10,70 @@ import ( ) func TestEncryptURLValues(t *testing.T) { + t.Parallel() + k1, err := GeneratePrivateKey() require.NoError(t, err) k2, err := GeneratePrivateKey() require.NoError(t, err) - encrypted, err := EncryptURLValues(k1, k2.PublicKey(), url.Values{ - "a": {"b", "c"}, - "x": {"y", "z"}, + t.Run("v1", func(t *testing.T) { + t.Parallel() + + encrypted, err := EncryptURLValuesV1(k1, k2.PublicKey(), url.Values{ + "a": {"b", "c"}, + "x": {"y", "z"}, + }) + assert.NoError(t, err) + assert.True(t, encrypted.Has(paramSenderPublicKey)) + assert.True(t, encrypted.Has(paramQuery)) + + assert.True(t, IsEncryptedURL(encrypted)) + + encrypted.Set("extra", "value") + encrypted.Set("a", "notb") + senderPublicKey, decrypted, err := DecryptURLValues(k2, encrypted) + assert.NoError(t, err) + assert.Equal(t, url.Values{ + "a": {"b", "c"}, + "x": {"y", "z"}, + "extra": {"value"}, + }, decrypted) + assert.Equal(t, k1.PublicKey().String(), senderPublicKey.String()) }) - assert.NoError(t, err) - assert.True(t, encrypted.Has(ParamSenderPublicKey)) - assert.True(t, encrypted.Has(ParamQuery)) + t.Run("v2", func(t *testing.T) { + t.Parallel() - assert.True(t, IsEncryptedURL(encrypted)) + encrypted, err := EncryptURLValuesV2(k1, k2.PublicKey(), url.Values{ + "a": {"b", "c"}, + "x": {"y", "z"}, + }) + assert.NoError(t, err) + assert.True(t, encrypted.Has(paramSenderPublicKeyV2)) + assert.True(t, encrypted.Has(paramQueryV2)) - encrypted.Set("extra", "value") - encrypted.Set("a", "notb") - senderPublicKey, decrypted, err := DecryptURLValues(k2, encrypted) - assert.NoError(t, err) - assert.Equal(t, url.Values{ - "a": {"b", "c"}, - "x": {"y", "z"}, - "extra": {"value"}, - }, decrypted) - assert.Equal(t, k1.PublicKey().String(), senderPublicKey.String()) + assert.True(t, IsEncryptedURL(encrypted)) + + encrypted.Set("extra", "value") + encrypted.Set("a", "notb") + senderPublicKey, decrypted, err := DecryptURLValues(k2, encrypted) + assert.NoError(t, err) + assert.Equal(t, url.Values{ + "a": {"b", "c"}, + "x": {"y", "z"}, + "extra": {"value"}, + }, decrypted) + assert.Equal(t, k1.PublicKey().String(), senderPublicKey.String()) + }) + + t.Run("compresses", func(t *testing.T) { + t.Parallel() + + encrypted, err := EncryptURLValuesV2(k1, k2.PublicKey(), url.Values{ + "a": {strings.Repeat("b", 1024*128)}, + }) + assert.NoError(t, err) + + assert.Less(t, len(encrypted.Encode()), 1024) + }) }