diff --git a/integration/policy_test.go b/integration/policy_test.go index 05475b75f..eb4ea65da 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -1,12 +1,16 @@ package main import ( + "bytes" "context" "crypto/tls" + "encoding/base64" "encoding/json" "io" "net/http" "net/url" + "regexp" + "strings" "testing" "time" @@ -496,3 +500,70 @@ func TestMultipleDownstreamClientCAs(t *testing.T) { assert.Equal(t, httputil.StatusInvalidClientCertificate, res.StatusCode) }) } + +func TestPomeriumJWT(t *testing.T) { + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*30) + defer clearTimeout() + + client := getClient(t) + + // Obtain a Pomerium attestation JWT from the httpdetails service. + res, err := flows.Authenticate(ctx, client, + mustParseURL("https://restricted-httpdetails.localhost.pomerium.io/"), + flows.WithEmail("user1@dogs.test")) + require.NoError(t, err) + defer res.Body.Close() + + var m map[string]interface{} + err = json.NewDecoder(res.Body).Decode(&m) + require.NoError(t, err) + + headers, ok := m["headers"].(map[string]interface{}) + require.True(t, ok) + headerJWT, ok := headers["x-pomerium-jwt-assertion"].(string) + require.True(t, ok) + + // Manually decode the payload section of the JWT in order to verify the + // format of the iat and exp timestamps. + // (https://github.com/pomerium/pomerium/issues/4149) + p := rawJWTPayload(t, headerJWT) + var digitsOnly = regexp.MustCompile(`^\d+$`) + assert.Regexp(t, digitsOnly, p["iat"]) + assert.Regexp(t, digitsOnly, p["exp"]) + + // Also verify the issuer and audience claims. + assert.Equal(t, "restricted-httpdetails.localhost.pomerium.io", p["iss"]) + assert.Equal(t, "restricted-httpdetails.localhost.pomerium.io", p["aud"]) + + // Obtain a Pomerium attestation JWT from the /.pomerium/jwt endpoint. The + // contents should be identical to the JWT header (except possibly the + // timestamps). (https://github.com/pomerium/pomerium/issues/4210) + res, err = client.Get("https://restricted-httpdetails.localhost.pomerium.io/.pomerium/jwt") + require.NoError(t, err) + defer res.Body.Close() + spaJWT, err := io.ReadAll(res.Body) + require.NoError(t, err) + + p2 := rawJWTPayload(t, string(spaJWT)) + + // Remove timestamps before comparing. + delete(p, "iat") + delete(p, "exp") + delete(p2, "iat") + delete(p2, "exp") + assert.Equal(t, p, p2) +} + +func rawJWTPayload(t *testing.T, jwt string) map[string]interface{} { + t.Helper() + s := strings.Split(jwt, ".") + require.Equal(t, 3, len(s), "unexpected JWT format") + payload, err := base64.RawURLEncoding.DecodeString(s[1]) + require.NoError(t, err, "JWT payload could not be decoded") + d := json.NewDecoder(bytes.NewReader(payload)) + d.UseNumber() + var decoded map[string]interface{} + err = d.Decode(&decoded) + require.NoError(t, err, "JWT payload could not be deserialized") + return decoded +}