diff --git a/authorize/evaluator/headers_evaluator.go b/authorize/evaluator/headers_evaluator.go index f0244029e..9ff615b4c 100644 --- a/authorize/evaluator/headers_evaluator.go +++ b/authorize/evaluator/headers_evaluator.go @@ -100,8 +100,8 @@ type HeadersEvaluator struct { } // NewHeadersEvaluator creates a new HeadersEvaluator. -func NewHeadersEvaluator(ctx context.Context, store *store.Store) (*HeadersEvaluator, error) { - r := rego.New( +func NewHeadersEvaluator(ctx context.Context, store *store.Store, options ...func(rego *rego.Rego)) (*HeadersEvaluator, error) { + r := rego.New(append([]func(*rego.Rego){ rego.Store(store), rego.Module("pomerium.headers", opa.HeadersRego), rego.Query("result := data.pomerium.headers"), @@ -110,7 +110,7 @@ func NewHeadersEvaluator(ctx context.Context, store *store.Store) (*HeadersEvalu variableSubstitutionFunctionRegoOption, store.GetDataBrokerRecordOption(), rego.SetRegoVersion(ast.RegoV1), - ) + }, options...)...) q, err := r.PrepareForEval(ctx) if err != nil { diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index d5036464d..fb04be425 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -5,13 +5,14 @@ import ( "context" "encoding/base64" "encoding/json" + "fmt" "math" - "strconv" "strings" "testing" "time" "github.com/go-jose/go-jose/v3/jwt" + "github.com/open-policy-agent/opa/rego" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "google.golang.org/protobuf/proto" @@ -74,13 +75,15 @@ func TestHeadersEvaluator(t *testing.T) { publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey) require.NoError(t, err) + evalTime := time.Now().Round(time.Second) + eval := func(t *testing.T, data []proto.Message, input *HeadersRequest) (*HeadersResponse, error) { ctx := context.Background() ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) store := store.New() store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) store.UpdateSigningKey(privateJWK) - e, err := NewHeadersEvaluator(ctx, store) + e, err := NewHeadersEvaluator(ctx, store, rego.Time(evalTime)) require.NoError(t, err) return e.Evaluate(ctx, input) } @@ -118,15 +121,11 @@ func TestHeadersEvaluator(t *testing.T) { err = d.Decode(&jwtPayloadDecoded) require.NoError(t, err) - // The 'iat' claim is set from the session store. - assert.Equal(t, json.Number("1686870680"), jwtPayloadDecoded["iat"], + // The 'iat' and 'exp' claims are set based on the current time. + assert.Equal(t, json.Number(fmt.Sprint(evalTime.Unix())), jwtPayloadDecoded["iat"], "unexpected 'iat' timestamp format") - - // The 'exp' claim will vary with the current time, but we can still - // use Atoi() to verify that it can be parsed as an integer. - exp := string(jwtPayloadDecoded["exp"].(json.Number)) - _, err = strconv.Atoi(exp) - assert.NoError(t, err, "unexpected 'exp' timestamp format") + assert.Equal(t, json.Number(fmt.Sprint(evalTime.Add(5*time.Minute).Unix())), jwtPayloadDecoded["exp"], + "unexpected 'exp' timestamp format") rawJWT, err := jwt.ParseSigned(jwtHeader) require.NoError(t, err) @@ -135,6 +134,7 @@ func TestHeadersEvaluator(t *testing.T) { err = rawJWT.Claims(publicJWK, &claims) require.NoError(t, err) + assert.NotEmpty(t, claims["jti"]) assert.Equal(t, claims["iss"], "from.example.com") assert.Equal(t, claims["aud"], "from.example.com") assert.Equal(t, claims["exp"], math.Round(claims["exp"].(float64))) diff --git a/authorize/evaluator/opa/policy/headers.rego b/authorize/evaluator/opa/policy/headers.rego index 379f575f8..a22493fcd 100644 --- a/authorize/evaluator/opa/policy/headers.rego +++ b/authorize/evaluator/opa/policy/headers.rego @@ -28,8 +28,7 @@ import rego.v1 # output: # identity_headers: map[string][]string -# 5 minutes from now in seconds -five_minutes := round((time.now_ns() / 1e9) + (60 * 5)) +now_s := round(time.now_ns() / 1e9) # get the session session := v if { @@ -82,23 +81,11 @@ jwt_payload_iss := v if { v := input.issuer } else := "" -jwt_payload_jti := v if { - v = session.id -} else := "" +jwt_payload_jti := uuid.rfc4122("jti") -jwt_payload_exp := v if { - v = min([five_minutes, round(session.expires_at.seconds)]) -} else := v if { - v = five_minutes -} else := null +jwt_payload_iat := now_s -jwt_payload_iat := v if { - # sessions store the issued_at on the id_token - v = round(session.id_token.issued_at.seconds) -} else := v if { - # service accounts store the issued at directly - v = round(session.issued_at.seconds) -} else := null +jwt_payload_exp := now_s + (5*60) # 5 minutes from now jwt_payload_sub := v if { v = session.user_id @@ -135,8 +122,8 @@ base_jwt_claims := [ ["iss", jwt_payload_iss], ["aud", jwt_payload_aud], ["jti", jwt_payload_jti], - ["exp", jwt_payload_exp], ["iat", jwt_payload_iat], + ["exp", jwt_payload_exp], ["sub", jwt_payload_sub], ["user", jwt_payload_user], ["email", jwt_payload_email], diff --git a/integration/policy_test.go b/integration/policy_test.go index c29a156d3..b9899847d 100644 --- a/integration/policy_test.go +++ b/integration/policy_test.go @@ -537,7 +537,7 @@ func TestPomeriumJWT(t *testing.T) { // Obtain a Pomerium attestation JWT from the /.pomerium/jwt endpoint. The // contents should be identical to the JWT header (except possibly the - // timestamps). (https://github.com/pomerium/pomerium/issues/4210) + // timestamps and the jtis). (https://github.com/pomerium/pomerium/issues/4210) res, err = client.Get("https://restricted-httpdetails.localhost.pomerium.io/.pomerium/jwt") require.NoError(t, err) defer res.Body.Close() @@ -549,8 +549,10 @@ func TestPomeriumJWT(t *testing.T) { // Remove timestamps before comparing. delete(p, "iat") delete(p, "exp") + delete(p, "jti") delete(p2, "iat") delete(p2, "exp") + delete(p2, "jti") assert.Equal(t, p, p2) }