diff --git a/authorize/evaluator/headers_evaluator.go b/authorize/evaluator/headers_evaluator.go index 88351c932..7574121ca 100644 --- a/authorize/evaluator/headers_evaluator.go +++ b/authorize/evaluator/headers_evaluator.go @@ -104,10 +104,11 @@ func NewHeadersEvaluator(ctx context.Context, store *store.Store) (*HeadersEvalu r := rego.New( rego.Store(store), rego.Module("pomerium.headers", opa.HeadersRego), - rego.Query("result = data.pomerium.headers"), + rego.Query("result := data.pomerium.headers"), getGoogleCloudServerlessHeadersRegoOption, variableSubstitutionFunctionRegoOption, store.GetDataBrokerRecordOption(), + rego.SetRegoVersion(ast.RegoV1), ) q, err := r.PrepareForEval(ctx) diff --git a/authorize/evaluator/opa/policy/headers.rego b/authorize/evaluator/opa/policy/headers.rego index 6f3f178c7..20d55d0d9 100644 --- a/authorize/evaluator/opa/policy/headers.rego +++ b/authorize/evaluator/opa/policy/headers.rego @@ -1,5 +1,7 @@ package pomerium.headers +import rego.v1 + # input: # enable_google_cloud_serverless_authentication: boolean # enable_routing_key: boolean @@ -30,11 +32,11 @@ package pomerium.headers five_minutes := round((time.now_ns() / 1e9) + (60 * 5)) # get the session -session = v { +session := v if { # try a service account v = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id) v != null -} else = iv { +} else := iv if { # try an impersonated session v = get_databroker_record("type.googleapis.com/session.Session", input.session.id) v != null @@ -42,89 +44,89 @@ session = v { iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id) iv != null -} else = v { +} else := v if { # try a normal session v = get_databroker_record("type.googleapis.com/session.Session", input.session.id) v != null object.get(v, "impersonate_session_id", "") == "" -} else = {} +} else := {} -user = u { +user := u if { u = get_databroker_record("type.googleapis.com/user.User", session.user_id) u != null -} else = {} +} else := {} -directory_user = du { +directory_user := du if { du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id) du != null -} else = {} +} else := {} -group_ids = gs { +group_ids := gs if { gs = directory_user.group_ids gs != null -} else = [] +} else := [] groups := array.concat(group_ids, array.concat(get_databroker_group_names(group_ids), get_databroker_group_emails(group_ids))) -jwt_headers = { +jwt_headers := { "typ": "JWT", "alg": data.signing_key.alg, "kid": data.signing_key.kid, } -jwt_payload_aud = v { +jwt_payload_aud := v if { v := input.issuer -} else = "" +} else := "" -jwt_payload_iss = v { +jwt_payload_iss := v if { v := input.issuer -} else = "" +} else := "" -jwt_payload_jti = v { +jwt_payload_jti := v if { v = session.id -} else = "" +} else := "" -jwt_payload_exp = v { +jwt_payload_exp := v if { v = min([five_minutes, round(session.expires_at.seconds)]) -} else = v { +} else := v if { v = five_minutes -} else = null +} else := null -jwt_payload_iat = v { +jwt_payload_iat := v if { # sessions store the issued_at on the id_token v = round(session.id_token.issued_at.seconds) -} else = v { +} else := v if { # service accounts store the issued at directly v = round(session.issued_at.seconds) -} else = null +} else := null -jwt_payload_sub = v { +jwt_payload_sub := v if { v = session.user_id -} else = "" +} else := "" -jwt_payload_user = v { +jwt_payload_user := v if { v = session.user_id -} else = "" +} else := "" -jwt_payload_email = v { +jwt_payload_email := v if { v = directory_user.email -} else = v { +} else := v if { v = user.email -} else = "" +} else := "" -jwt_payload_groups = v { +jwt_payload_groups := v if { v = array.concat(group_ids, get_databroker_group_names(group_ids)) v != [] -} else = v { +} else := v if { v = session.claims.groups v != null -} else = [] +} else := [] -jwt_payload_name = v { +jwt_payload_name := v if { v = get_header_string_value(session.claims.name) -} else = v { +} else := v if { v = get_header_string_value(user.claims.name) -} else = "" +} else := "" # the session id is always set to the input session id, even if impersonating jwt_payload_sid := input.session.id @@ -162,62 +164,62 @@ additional_jwt_claims := [[k, v] | jwt_claims := array.concat(base_jwt_claims, additional_jwt_claims) -jwt_payload = {key: value | +jwt_payload := {key: value | # use a comprehension over an array to remove nil values [key, value] := jwt_claims[_] value != null } -signed_jwt = io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key) +signed_jwt := io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key) -kubernetes_headers = h { +kubernetes_headers := h if { input.kubernetes_service_account_token != "" h := [ ["Authorization", concat(" ", ["Bearer", input.kubernetes_service_account_token])], ["Impersonate-User", jwt_payload_email], ["Impersonate-Group", get_header_string_value(jwt_payload_groups)], ] -} else = [] +} else := [] -google_cloud_serverless_authentication_service_account = s { +google_cloud_serverless_authentication_service_account := s if { s := data.google_cloud_serverless_authentication_service_account -} else = "" +} else := "" -google_cloud_serverless_headers = h { +google_cloud_serverless_headers := h if { input.enable_google_cloud_serverless_authentication h := get_google_cloud_serverless_headers(google_cloud_serverless_authentication_service_account, input.to_audience) -} else = {} +} else := {} -routing_key_headers = h { +routing_key_headers := h if { input.enable_routing_key h := [["x-pomerium-routing-key", crypto.sha256(input.session.id)]] -} else = [] +} else := [] -session_id_token = v { +session_id_token := v if { v := session.id_token.raw -} else = "" +} else := "" -session_access_token = v { +session_access_token := v if { v := session.oauth_token.access_token -} else = "" +} else := "" -client_cert_fingerprint = v { - cert := crypto.x509.parse_certificates(trim_space(input.client_certificate.leaf))[0] - v := crypto.sha256(base64.decode(cert.Raw)) -} else = "" +client_cert_fingerprint := v if { + cert := crypto.x509.parse_certificates(trim_space(input.client_certificate.leaf))[0] + v := crypto.sha256(base64.decode(cert.Raw)) +} else := "" -set_request_headers = h { - replacements := { - "pomerium.id_token": session_id_token, - "pomerium.access_token": session_access_token, +set_request_headers := h if { + replacements := { + "pomerium.id_token": session_id_token, + "pomerium.access_token": session_access_token, "pomerium.client_cert_fingerprint": client_cert_fingerprint, - } + } h := [[header_name, header_value] | some header_name v := input.set_request_headers[header_name] header_value := pomerium.variable_substitution(v, replacements) ] -} else = [] +} else := [] identity_headers := {key: values | h1 := [["x-pomerium-jwt-assertion", signed_jwt]] @@ -251,17 +253,17 @@ identity_headers := {key: values | ] } -get_databroker_group_names(ids) = gs { +get_databroker_group_names(ids) := gs if { gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name] } -get_databroker_group_emails(ids) = gs { +get_databroker_group_emails(ids) := gs if { gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email] } -get_header_string_value(obj) = s { +get_header_string_value(obj) := s if { is_array(obj) s := concat(",", obj) -} else = s { +} else := s if { s := concat(",", [obj]) } diff --git a/config/policy_ppl_test.go b/config/policy_ppl_test.go index 13ff31209..adaaac64d 100644 --- a/config/policy_ppl_test.go +++ b/config/policy_ppl_test.go @@ -53,104 +53,106 @@ func TestPolicy_ToPPL(t *testing.T) { require.NoError(t, err) assert.Equal(t, `package pomerium.policy -default allow = [false, set()] +import rego.v1 -default deny = [false, set()] +default allow := [false, set()] -accept_0 = [true, {"accept"}] +default deny := [false, set()] -cors_preflight_0 = [true, {"cors-request"}] { +accept_0 := [true, {"accept"}] + +cors_preflight_0 := [true, {"cors-request"}] if { input.http.method == "OPTIONS" count(object.get(input.http.headers, "Access-Control-Request-Method", [])) > 0 count(object.get(input.http.headers, "Origin", [])) > 0 } -else = [false, {"non-cors-request"}] +else := [false, {"non-cors-request"}] -authenticated_user_0 = [true, {"user-ok"}] { +authenticated_user_0 := [true, {"user-ok"}] if { session := get_session(input.session.id) session.user_id != null session.user_id != "" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -domain_0 = [true, {"domain-ok"}] { +domain_0 := [true, {"domain-ok"}] if { session := get_session(input.session.id) user := get_user(session) domain := split(get_user_email(session, user), "@")[1] domain == "a.example.com" } -else = [false, {"domain-unauthorized"}] { +else := [false, {"domain-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -domain_1 = [true, {"domain-ok"}] { +domain_1 := [true, {"domain-ok"}] if { session := get_session(input.session.id) user := get_user(session) domain := split(get_user_email(session, user), "@")[1] domain == "b.example.com" } -else = [false, {"domain-unauthorized"}] { +else := [false, {"domain-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -domain_2 = [true, {"domain-ok"}] { +domain_2 := [true, {"domain-ok"}] if { session := get_session(input.session.id) user := get_user(session) domain := split(get_user_email(session, user), "@")[1] domain == "c.example.com" } -else = [false, {"domain-unauthorized"}] { +else := [false, {"domain-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -domain_3 = [true, {"domain-ok"}] { +domain_3 := [true, {"domain-ok"}] if { session := get_session(input.session.id) user := get_user(session) domain := split(get_user_email(session, user), "@")[1] domain == "d.example.com" } -else = [false, {"domain-unauthorized"}] { +else := [false, {"domain-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -domain_4 = [true, {"domain-ok"}] { +domain_4 := [true, {"domain-ok"}] if { session := get_session(input.session.id) user := get_user(session) domain := split(get_user_email(session, user), "@")[1] domain == "e.example.com" } -else = [false, {"domain-unauthorized"}] { +else := [false, {"domain-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -claim_0 = [true, {"claim-ok"}] { +claim_0 := [true, {"claim-ok"}] if { rule_data := "Smith" rule_path := "family_name" session := get_session(input.session.id) @@ -162,14 +164,14 @@ claim_0 = [true, {"claim-ok"}] { rule_data == values[_0] } -else = [false, {"claim-unauthorized"}] { +else := [false, {"claim-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -claim_1 = [true, {"claim-ok"}] { +claim_1 := [true, {"claim-ok"}] if { rule_data := "Jones" rule_path := "family_name" session := get_session(input.session.id) @@ -181,14 +183,14 @@ claim_1 = [true, {"claim-ok"}] { rule_data == values[_0] } -else = [false, {"claim-unauthorized"}] { +else := [false, {"claim-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -claim_2 = [true, {"claim-ok"}] { +claim_2 := [true, {"claim-ok"}] if { rule_data := "John" rule_path := "given_name" session := get_session(input.session.id) @@ -200,14 +202,14 @@ claim_2 = [true, {"claim-ok"}] { rule_data == values[_0] } -else = [false, {"claim-unauthorized"}] { +else := [false, {"claim-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -claim_3 = [true, {"claim-ok"}] { +claim_3 := [true, {"claim-ok"}] if { rule_data := "EST" rule_path := "timezone" session := get_session(input.session.id) @@ -219,204 +221,204 @@ claim_3 = [true, {"claim-ok"}] { rule_data == values[_0] } -else = [false, {"claim-unauthorized"}] { +else := [false, {"claim-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -user_0 = [true, {"user-ok"}] { +user_0 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user1" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -email_0 = [true, {"email-ok"}] { +email_0 := [true, {"email-ok"}] if { session := get_session(input.session.id) user := get_user(session) email := get_user_email(session, user) email == "user1" } -else = [false, {"email-unauthorized"}] { +else := [false, {"email-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -user_1 = [true, {"user-ok"}] { +user_1 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user2" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -email_1 = [true, {"email-ok"}] { +email_1 := [true, {"email-ok"}] if { session := get_session(input.session.id) user := get_user(session) email := get_user_email(session, user) email == "user2" } -else = [false, {"email-unauthorized"}] { +else := [false, {"email-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -user_2 = [true, {"user-ok"}] { +user_2 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user3" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -email_2 = [true, {"email-ok"}] { +email_2 := [true, {"email-ok"}] if { session := get_session(input.session.id) user := get_user(session) email := get_user_email(session, user) email == "user3" } -else = [false, {"email-unauthorized"}] { +else := [false, {"email-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -user_3 = [true, {"user-ok"}] { +user_3 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user4" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -email_3 = [true, {"email-ok"}] { +email_3 := [true, {"email-ok"}] if { session := get_session(input.session.id) user := get_user(session) email := get_user_email(session, user) email == "user4" } -else = [false, {"email-unauthorized"}] { +else := [false, {"email-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -user_4 = [true, {"user-ok"}] { +user_4 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user5" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -email_4 = [true, {"email-ok"}] { +email_4 := [true, {"email-ok"}] if { session := get_session(input.session.id) user := get_user(session) email := get_user_email(session, user) email == "user5" } -else = [false, {"email-unauthorized"}] { +else := [false, {"email-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -or_0 = v { +or_0 := v if { results := [accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -user_5 = [true, {"user-ok"}] { +user_5 := [true, {"user-ok"}] if { session := get_session(input.session.id) user_id := session.user_id user_id == "user6" } -else = [false, {"user-unauthorized"}] { +else := [false, {"user-unauthorized"}] if { session := get_session(input.session.id) session.id != "" } -else = [false, {"user-unauthenticated"}] +else := [false, {"user-unauthenticated"}] -or_1 = v { +or_1 := v if { results := [user_5] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -allow = v { +allow := v if { results := [or_0, or_1] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -invert_criterion_result(in) = out { - in[0] - out = array.concat([false], array.slice(in, 1, count(in))) +invert_criterion_result(v) := out if { + v[0] + out = array.concat([false], array.slice(v, 1, count(v))) } -else = out { - not in[0] - out = array.concat([true], array.slice(in, 1, count(in))) +else := out if { + not v[0] + out = array.concat([true], array.slice(v, 1, count(v))) } -normalize_criterion_result(result) = v { +normalize_criterion_result(result) := v if { is_boolean(result) v = [result, set()] } -else = v { +else := v if { is_array(result) v = result } -else = v { +else := v if { v = [false, set()] } -object_union(xs) = merged { +object_union(xs) := merged if { merged = {k: v | some k xs[_0][k] @@ -425,38 +427,38 @@ object_union(xs) = merged { } } -merge_with_and(results) = [true, reasons, additional_data] { +merge_with_and(results) := [true, reasons, additional_data] if { true_results := [x | x := results[i]; x[0]] count(true_results) == count(results) reasons := union({x | x := true_results[i][1]}) additional_data := object_union({x | x := true_results[i][2]}) } -else = [false, reasons, additional_data] { +else := [false, reasons, additional_data] if { false_results := [x | x := results[i]; not x[0]] reasons := union({x | x := false_results[i][1]}) additional_data := object_union({x | x := false_results[i][2]}) } -merge_with_or(results) = [true, reasons, additional_data] { +merge_with_or(results) := [true, reasons, additional_data] if { true_results := [x | x := results[i]; x[0]] count(true_results) > 0 reasons := union({x | x := true_results[i][1]}) additional_data := object_union({x | x := true_results[i][2]}) } -else = [false, reasons, additional_data] { +else := [false, reasons, additional_data] if { false_results := [x | x := results[i]; not x[0]] reasons := union({x | x := false_results[i][1]}) additional_data := object_union({x | x := false_results[i][2]}) } -get_session(id) = v { +get_session(id) := v if { v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id) v != null } -else = iv { +else := iv if { v = get_databroker_record("type.googleapis.com/session.Session", id) v != null object.get(v, "impersonate_session_id", "") != "" @@ -465,41 +467,41 @@ else = iv { iv != null } -else = v { +else := v if { v = get_databroker_record("type.googleapis.com/session.Session", id) v != null object.get(v, "impersonate_session_id", "") == "" } -else = {} +else := {} -get_user(session) = v { +get_user(session) := v if { v = get_databroker_record("type.googleapis.com/user.User", session.user_id) v != null } -else = {} +else := {} -get_user_email(session, user) = v { +get_user_email(session, user) := v if { v = user.email } -else = "" +else := "" -object_get(obj, key, def) = value { +object_get(obj, key, def) := value if { undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa" value = object.get(obj, key, undefined) value != undefined } -else = value { +else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 2 o1 := object.get(obj, segments[0], {}) value = object.get(o1, segments[1], def) } -else = value { +else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 3 o1 := object.get(obj, segments[0], {}) @@ -507,7 +509,7 @@ else = value { value = object.get(o2, segments[2], def) } -else = value { +else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 4 o1 := object.get(obj, segments[0], {}) @@ -516,7 +518,7 @@ else = value { value = object.get(o3, segments[3], def) } -else = value { +else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 5 o1 := object.get(obj, segments[0], {}) @@ -526,7 +528,7 @@ else = value { value = object.get(o4, segments[4], def) } -else = value { +else := value if { value = object.get(obj, key, def) } `, str) diff --git a/pkg/policy/criteria/criteria.go b/pkg/policy/criteria/criteria.go index cd6924c1f..763b38bdb 100644 --- a/pkg/policy/criteria/criteria.go +++ b/pkg/policy/criteria/criteria.go @@ -66,9 +66,7 @@ func NewCriterionRule( r1.Body = body r2 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTerm(false, failReason), - }, + Head: generator.NewHead("", NewCriterionTerm(false, failReason)), Body: ast.Body{ ast.NewExpr(ast.BooleanTerm(true)), }, @@ -107,9 +105,7 @@ func NewCriterionDeviceRule( // case 2: rule fails, session exists, device exists r2 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTermWithAdditionalData(false, failReason, additionalData), - }, + Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, failReason, additionalData)), Body: append(sharedBody, ast.Body{ ast.MustParseExpr(`session.id != ""`), ast.MustParseExpr(`device_credential.id != ""`), @@ -120,9 +116,7 @@ func NewCriterionDeviceRule( // case 3: device not authenticated, session exists, device does not exist r3 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData), - }, + Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData)), Body: append(sharedBody, ast.Body{ ast.MustParseExpr(`session.id != ""`), }...), @@ -131,9 +125,7 @@ func NewCriterionDeviceRule( // case 4: user not authenticated, session does not exist r4 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData), - }, + Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData)), Body: ast.Body{ ast.NewExpr(ast.BooleanTerm(true)), }, @@ -157,9 +149,7 @@ func NewCriterionSessionRule( r1.Body = body r2 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTerm(false, failReason), - }, + Head: generator.NewHead("", NewCriterionTerm(false, failReason)), Body: ast.Body{ ast.MustParseExpr(`session := get_session(input.session.id)`), ast.MustParseExpr(`session.id != ""`), @@ -168,9 +158,7 @@ func NewCriterionSessionRule( r1.Else = r2 r3 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTerm(false, ReasonUserUnauthenticated), - }, + Head: generator.NewHead("", NewCriterionTerm(false, ReasonUserUnauthenticated)), Body: ast.Body{ ast.NewExpr(ast.BooleanTerm(true)), }, diff --git a/pkg/policy/criteria/criteria_test.go b/pkg/policy/criteria/criteria_test.go index decc13b3e..d4b74027c 100644 --- a/pkg/policy/criteria/criteria_test.go +++ b/pkg/policy/criteria/criteria_test.go @@ -121,6 +121,7 @@ func evaluate(t *testing.T, return nil, nil }), rego.Input(input), + rego.SetRegoVersion(ast.RegoV1), ) preparedQuery, err := r.PrepareForEval(context.Background()) if err != nil { diff --git a/pkg/policy/criteria/invalid_client_certificate.go b/pkg/policy/criteria/invalid_client_certificate.go index 8fada3693..c45e600a7 100644 --- a/pkg/policy/criteria/invalid_client_certificate.go +++ b/pkg/policy/criteria/invalid_client_certificate.go @@ -35,17 +35,13 @@ func (c invalidClientCertificateCriterion) GenerateRule(_ string, _ parser.Value r1.Body = validClientCertificateBody r2 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTerm(true, ReasonClientCertificateRequired), - }, + Head: generator.NewHead("", NewCriterionTerm(true, ReasonClientCertificateRequired)), Body: noClientCertificateBody, } r1.Else = r2 r3 := &ast.Rule{ - Head: &ast.Head{ - Value: NewCriterionTerm(true, ReasonInvalidClientCertificate), - }, + Head: generator.NewHead("", NewCriterionTerm(true, ReasonInvalidClientCertificate)), } r2.Else = r3 diff --git a/pkg/policy/generator/generator.go b/pkg/policy/generator/generator.go index a68f03cfd..b96f0f5a2 100644 --- a/pkg/policy/generator/generator.go +++ b/pkg/policy/generator/generator.go @@ -49,8 +49,8 @@ func (g *Generator) GetCriterion(name string) (Criterion, bool) { // Generate generates the rego module from a policy. func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { rs := ast.NewRuleSet() - rs.Add(ast.MustParseRule(`default allow = [false, set()]`)) - rs.Add(ast.MustParseRule(`default deny = [false, set()]`)) + rs.Add(rules.MustParse(`default allow := [false, set()]`)) + rs.Add(rules.MustParse(`default deny := [false, set()]`)) rs.Add(rules.InvertCriterionResult()) rs.Add(rules.NormalizeCriterionResult()) rs.Add(rules.ObjectUnion()) @@ -95,10 +95,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { } if len(terms) > 0 { rule := &ast.Rule{ - Head: &ast.Head{ - Name: ast.Var(action), - Value: ast.VarTerm("v"), - }, + Head: NewHead(ast.Var(action), ast.VarTerm("v")), Body: append(ast.Body{ ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)), }, orBody...), @@ -115,6 +112,9 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { ast.StringTerm("policy"), }, }, + Imports: []*ast.Import{{ + Path: ast.RefTerm(ast.VarTerm("rego"), ast.StringTerm("v1")), + }}, Rules: rs, } @@ -148,8 +148,15 @@ func (g *Generator) NewRule(name string) *ast.Rule { id := g.ids[name] g.ids[name]++ return &ast.Rule{ - Head: &ast.Head{ - Name: ast.Var(fmt.Sprintf("%s_%d", name, id)), - }, + Head: NewHead(ast.Var(fmt.Sprintf("%s_%d", name, id)), nil), + } +} + +// NewHead creates a new AST Head. +func NewHead(name ast.Var, value *ast.Term) *ast.Head { + return &ast.Head{ + Name: name, + Value: value, + Assign: true, } } diff --git a/pkg/policy/generator/generator_test.go b/pkg/policy/generator/generator_test.go index e6f178422..0635d7a72 100644 --- a/pkg/policy/generator/generator_test.go +++ b/pkg/policy/generator/generator_test.go @@ -63,146 +63,118 @@ func Test(t *testing.T) { require.NoError(t, err) assert.Equal(t, `package pomerium.policy -default allow = [false, set()] +import rego.v1 -default deny = [false, set()] +default allow := [false, set()] -accept_0 { - 1 == 1 -} +default deny := [false, set()] -accept_1 { - 1 == 1 -} +accept_0 if 1 == 1 -accept_2 { - 1 == 1 -} +accept_1 if 1 == 1 -and_0 = v { +accept_2 if 1 == 1 + +and_0 := v if { results := [accept_0, accept_1, accept_2] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_and(normalized) } -accept_3 { - 1 == 1 -} +accept_3 if 1 == 1 -accept_4 { - 1 == 1 -} +accept_4 if 1 == 1 -accept_5 { - 1 == 1 -} +accept_5 if 1 == 1 -or_0 = v { +or_0 := v if { results := [accept_3, accept_4, accept_5] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -accept_6 { - 1 == 1 -} +accept_6 if 1 == 1 -accept_7 { - 1 == 1 -} +accept_7 if 1 == 1 -accept_8 { - 1 == 1 -} +accept_8 if 1 == 1 -not_0 = v { +not_0 := v if { results := [accept_6, accept_7, accept_8] normalized := [normalize_criterion_result(x) | x := results[i]] inverted := [invert_criterion_result(x) | x := results[i]] v := merge_with_and(inverted) } -accept_9 { - 1 == 1 -} +accept_9 if 1 == 1 -accept_10 { - 1 == 1 -} +accept_10 if 1 == 1 -accept_11 { - 1 == 1 -} +accept_11 if 1 == 1 -nor_0 = v { +nor_0 := v if { results := [accept_9, accept_10, accept_11] normalized := [normalize_criterion_result(x) | x := results[i]] inverted := [invert_criterion_result(x) | x := results[i]] v := merge_with_or(inverted) } -accept_12 { - 1 == 1 -} +accept_12 if 1 == 1 -and_1 = v { +and_1 := v if { results := [accept_12] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_and(normalized) } -allow = v { +allow := v if { results := [and_0, or_0, not_0, nor_0, and_1] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -accept_13 { - 1 == 1 -} +accept_13 if 1 == 1 -accept_14 { - 1 == 1 -} +accept_14 if 1 == 1 -nor_1 = v { +nor_1 := v if { results := [accept_13, accept_14] normalized := [normalize_criterion_result(x) | x := results[i]] inverted := [invert_criterion_result(x) | x := results[i]] v := merge_with_or(inverted) } -deny = v { +deny := v if { results := [nor_1] normalized := [normalize_criterion_result(x) | x := results[i]] v := merge_with_or(normalized) } -invert_criterion_result(in) = out { - in[0] - out = array.concat([false], array.slice(in, 1, count(in))) +invert_criterion_result(v) := out if { + v[0] + out = array.concat([false], array.slice(v, 1, count(v))) } -else = out { - not in[0] - out = array.concat([true], array.slice(in, 1, count(in))) +else := out if { + not v[0] + out = array.concat([true], array.slice(v, 1, count(v))) } -normalize_criterion_result(result) = v { +normalize_criterion_result(result) := v if { is_boolean(result) v = [result, set()] } -else = v { +else := v if { is_array(result) v = result } -else = v { +else := v if { v = [false, set()] } -object_union(xs) = merged { +object_union(xs) := merged if { merged = {k: v | some k xs[_][k] @@ -211,27 +183,27 @@ object_union(xs) = merged { } } -merge_with_and(results) = [true, reasons, additional_data] { +merge_with_and(results) := [true, reasons, additional_data] if { true_results := [x | x := results[i]; x[0]] count(true_results) == count(results) reasons := union({x | x := true_results[i][1]}) additional_data := object_union({x | x := true_results[i][2]}) } -else = [false, reasons, additional_data] { +else := [false, reasons, additional_data] if { false_results := [x | x := results[i]; not x[0]] reasons := union({x | x := false_results[i][1]}) additional_data := object_union({x | x := false_results[i][2]}) } -merge_with_or(results) = [true, reasons, additional_data] { +merge_with_or(results) := [true, reasons, additional_data] if { true_results := [x | x := results[i]; x[0]] count(true_results) > 0 reasons := union({x | x := true_results[i][1]}) additional_data := object_union({x | x := true_results[i][2]}) } -else = [false, reasons, additional_data] { +else := [false, reasons, additional_data] if { false_results := [x | x := results[i]; not x[0]] reasons := union({x | x := false_results[i][1]}) additional_data := object_union({x | x := false_results[i][2]}) diff --git a/pkg/policy/rules/rules.go b/pkg/policy/rules/rules.go index 63542c9b9..a43db38f7 100644 --- a/pkg/policy/rules/rules.go +++ b/pkg/policy/rules/rules.go @@ -5,84 +5,74 @@ import "github.com/open-policy-agent/opa/ast" // GetSession gets the session for the given id. func GetSession() *ast.Rule { - return ast.MustParseRule(` -get_session(id) = v { + return MustParse(` +get_session(id) := v if { v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id) v != null -} else = iv { +} else := iv if { v = get_databroker_record("type.googleapis.com/session.Session", id) v != null object.get(v, "impersonate_session_id", "") != "" iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id) iv != null -} else = v { +} else := v if { v = get_databroker_record("type.googleapis.com/session.Session", id) v != null object.get(v, "impersonate_session_id", "") == "" -} else = {} { - true -} +} else := {} `) } // GetUser returns the user for the given session. func GetUser() *ast.Rule { - return ast.MustParseRule(` -get_user(session) = v { + return MustParse(` +get_user(session) := v if { v = get_databroker_record("type.googleapis.com/user.User", session.user_id) v != null -} else = {} { - true -} +} else := {} `) } // GetUserEmail gets the user email, either the impersonate email, or the user email. func GetUserEmail() *ast.Rule { - return ast.MustParseRule(` -get_user_email(session, user) = v { + return MustParse(` +get_user_email(session, user) := v if { v = user.email -} else = "" { - true -} +} else := "" `) } // GetDeviceCredential gets the device credential for the given session. func GetDeviceCredential() *ast.Rule { - return ast.MustParseRule(` -get_device_credential(session, device_type_id) = v { + return MustParse(` +get_device_credential(session, device_type_id) := v if { device_credential_id := [x.Credential.Id|x:=session.device_credentials[_];x.type_id==device_type_id][0] v = get_databroker_record("type.googleapis.com/pomerium.device.Credential", device_credential_id) v != null -} else = {} { - true -} +} else := {} `) } // GetDeviceEnrollment gets the device enrollment for the given device credential. func GetDeviceEnrollment() *ast.Rule { - return ast.MustParseRule(` -get_device_enrollment(device_credential) = v { + return MustParse(` +get_device_enrollment(device_credential) := v if { v = get_databroker_record("type.googleapis.com/pomerium.device.Enrollment", device_credential.enrollment_id) v != null -} else = {} { - true -} +} else := {} `) } // MergeWithAnd merges criterion results using `and`. func MergeWithAnd() *ast.Rule { - return ast.MustParseRule(` -merge_with_and(results) = [true, reasons, additional_data] { + return MustParse(` +merge_with_and(results) := [true, reasons, additional_data] if { true_results := [x|x:=results[i];x[0]] count(true_results) == count(results) reasons := union({x|x:=true_results[i][1]}) additional_data := object_union({x|x:=true_results[i][2]}) -} else = [false, reasons, additional_data] { +} else := [false, reasons, additional_data] if { false_results := [x|x:=results[i];not x[0]] reasons := union({x|x:=false_results[i][1]}) additional_data := object_union({x|x:=false_results[i][2]}) @@ -92,13 +82,13 @@ merge_with_and(results) = [true, reasons, additional_data] { // MergeWithOr merges criterion results using `or`. func MergeWithOr() *ast.Rule { - return ast.MustParseRule(` -merge_with_or(results) = [true, reasons, additional_data] { + return MustParse(` +merge_with_or(results) := [true, reasons, additional_data] if { true_results := [x|x:=results[i];x[0]] count(true_results) > 0 reasons := union({x|x:=true_results[i][1]}) additional_data := object_union({x|x:=true_results[i][2]}) -} else = [false, reasons, additional_data] { +} else := [false, reasons, additional_data] if { false_results := [x|x:=results[i];not x[0]] reasons := union({x|x:=false_results[i][1]}) additional_data := object_union({x|x:=false_results[i][2]}) @@ -109,27 +99,27 @@ merge_with_or(results) = [true, reasons, additional_data] { // InvertCriterionResult changes the criterion result's value from false to // true, or vice-versa. func InvertCriterionResult() *ast.Rule { - return ast.MustParseRule(` -invert_criterion_result(in) = out { - in[0] - out = array.concat([false], array.slice(in, 1, count(in))) -} else = out { - not in[0] - out = array.concat([true], array.slice(in, 1, count(in))) + return MustParse(` +invert_criterion_result(v) := out if { + v[0] + out = array.concat([false], array.slice(v, 1, count(v))) +} else := out if { + not v[0] + out = array.concat([true], array.slice(v, 1, count(v))) } `) } // NormalizeCriterionResult converts a criterion result into a standard form. func NormalizeCriterionResult() *ast.Rule { - return ast.MustParseRule(` -normalize_criterion_result(result) = v { + return MustParse(` +normalize_criterion_result(result) := v if { is_boolean(result) v = [result, set()] -} else = v { +} else := v if { is_array(result) v = result -} else = v { +} else := v if { v = [false, set()] } `) @@ -137,33 +127,33 @@ normalize_criterion_result(result) = v { // ObjectGet recursively gets a value from an object. func ObjectGet() *ast.Rule { - return ast.MustParseRule(` + return MustParse(` # object_get is like object.get, but supports converting "/" in keys to separate lookups # rego doesn't support recursion, so we hard code a limited number of /'s -object_get(obj, key, def) = value { +object_get(obj, key, def) := value if { undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa" value = object.get(obj, key, undefined) value != undefined -} else = value { +} else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 2 o1 := object.get(obj, segments[0], {}) value = object.get(o1, segments[1], def) -} else = value { +} else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 3 o1 := object.get(obj, segments[0], {}) o2 := object.get(o1, segments[1], {}) value = object.get(o2, segments[2], def) -} else = value { +} else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 4 o1 := object.get(obj, segments[0], {}) o2 := object.get(o1, segments[1], {}) o3 := object.get(o2, segments[2], {}) value = object.get(o3, segments[3], def) -} else = value { +} else := value if { segments := split(replace(key, ".", "/"), "/") count(segments) == 5 o1 := object.get(obj, segments[0], {}) @@ -171,7 +161,7 @@ object_get(obj, key, def) = value { o3 := object.get(o2, segments[2], {}) o4 := object.get(o3, segments[3], {}) value = object.get(o4, segments[4], def) -} else = value { +} else := value if { value = object.get(obj, key, def) } `) @@ -179,8 +169,8 @@ object_get(obj, key, def) = value { // ObjectUnion merges objects together. It expects a set of objects. func ObjectUnion() *ast.Rule { - return ast.MustParseRule(` -object_union(xs) = merged { + return MustParse(` +object_union(xs) := merged if { merged = { k: v | some k xs[_][k] @@ -190,3 +180,14 @@ object_union(xs) = merged { } `) } + +// MustParse parses an AST rule. +func MustParse(str string) *ast.Rule { + r, err := ast.ParseRuleWithOpts(str, ast.ParserOptions{ + RegoVersion: ast.RegoV1, + }) + if err != nil { + panic(err) + } + return r +}