core/opa: update for rego 1.0 (#4895)

* core/opa: update headers rego script

* core/opa: update ppl

* further updates
This commit is contained in:
Caleb Doxsey 2024-01-16 09:43:35 -07:00 committed by GitHub
parent 5e0079c649
commit 24b04bed35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 289 additions and 319 deletions

View file

@ -104,10 +104,11 @@ func NewHeadersEvaluator(ctx context.Context, store *store.Store) (*HeadersEvalu
r := rego.New(
rego.Store(store),
rego.Module("pomerium.headers", opa.HeadersRego),
rego.Query("result = data.pomerium.headers"),
rego.Query("result := data.pomerium.headers"),
getGoogleCloudServerlessHeadersRegoOption,
variableSubstitutionFunctionRegoOption,
store.GetDataBrokerRecordOption(),
rego.SetRegoVersion(ast.RegoV1),
)
q, err := r.PrepareForEval(ctx)

View file

@ -1,5 +1,7 @@
package pomerium.headers
import rego.v1
# input:
# enable_google_cloud_serverless_authentication: boolean
# enable_routing_key: boolean
@ -30,11 +32,11 @@ package pomerium.headers
five_minutes := round((time.now_ns() / 1e9) + (60 * 5))
# get the session
session = v {
session := v if {
# try a service account
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", input.session.id)
v != null
} else = iv {
} else := iv if {
# try an impersonated session
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
v != null
@ -42,89 +44,89 @@ session = v {
iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id)
iv != null
} else = v {
} else := v if {
# try a normal session
v = get_databroker_record("type.googleapis.com/session.Session", input.session.id)
v != null
object.get(v, "impersonate_session_id", "") == ""
} else = {}
} else := {}
user = u {
user := u if {
u = get_databroker_record("type.googleapis.com/user.User", session.user_id)
u != null
} else = {}
} else := {}
directory_user = du {
directory_user := du if {
du = get_databroker_record("pomerium.io/DirectoryUser", session.user_id)
du != null
} else = {}
} else := {}
group_ids = gs {
group_ids := gs if {
gs = directory_user.group_ids
gs != null
} else = []
} else := []
groups := array.concat(group_ids, array.concat(get_databroker_group_names(group_ids), get_databroker_group_emails(group_ids)))
jwt_headers = {
jwt_headers := {
"typ": "JWT",
"alg": data.signing_key.alg,
"kid": data.signing_key.kid,
}
jwt_payload_aud = v {
jwt_payload_aud := v if {
v := input.issuer
} else = ""
} else := ""
jwt_payload_iss = v {
jwt_payload_iss := v if {
v := input.issuer
} else = ""
} else := ""
jwt_payload_jti = v {
jwt_payload_jti := v if {
v = session.id
} else = ""
} else := ""
jwt_payload_exp = v {
jwt_payload_exp := v if {
v = min([five_minutes, round(session.expires_at.seconds)])
} else = v {
} else := v if {
v = five_minutes
} else = null
} else := null
jwt_payload_iat = v {
jwt_payload_iat := v if {
# sessions store the issued_at on the id_token
v = round(session.id_token.issued_at.seconds)
} else = v {
} else := v if {
# service accounts store the issued at directly
v = round(session.issued_at.seconds)
} else = null
} else := null
jwt_payload_sub = v {
jwt_payload_sub := v if {
v = session.user_id
} else = ""
} else := ""
jwt_payload_user = v {
jwt_payload_user := v if {
v = session.user_id
} else = ""
} else := ""
jwt_payload_email = v {
jwt_payload_email := v if {
v = directory_user.email
} else = v {
} else := v if {
v = user.email
} else = ""
} else := ""
jwt_payload_groups = v {
jwt_payload_groups := v if {
v = array.concat(group_ids, get_databroker_group_names(group_ids))
v != []
} else = v {
} else := v if {
v = session.claims.groups
v != null
} else = []
} else := []
jwt_payload_name = v {
jwt_payload_name := v if {
v = get_header_string_value(session.claims.name)
} else = v {
} else := v if {
v = get_header_string_value(user.claims.name)
} else = ""
} else := ""
# the session id is always set to the input session id, even if impersonating
jwt_payload_sid := input.session.id
@ -162,51 +164,51 @@ additional_jwt_claims := [[k, v] |
jwt_claims := array.concat(base_jwt_claims, additional_jwt_claims)
jwt_payload = {key: value |
jwt_payload := {key: value |
# use a comprehension over an array to remove nil values
[key, value] := jwt_claims[_]
value != null
}
signed_jwt = io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key)
signed_jwt := io.jwt.encode_sign(jwt_headers, jwt_payload, data.signing_key)
kubernetes_headers = h {
kubernetes_headers := h if {
input.kubernetes_service_account_token != ""
h := [
["Authorization", concat(" ", ["Bearer", input.kubernetes_service_account_token])],
["Impersonate-User", jwt_payload_email],
["Impersonate-Group", get_header_string_value(jwt_payload_groups)],
]
} else = []
} else := []
google_cloud_serverless_authentication_service_account = s {
google_cloud_serverless_authentication_service_account := s if {
s := data.google_cloud_serverless_authentication_service_account
} else = ""
} else := ""
google_cloud_serverless_headers = h {
google_cloud_serverless_headers := h if {
input.enable_google_cloud_serverless_authentication
h := get_google_cloud_serverless_headers(google_cloud_serverless_authentication_service_account, input.to_audience)
} else = {}
} else := {}
routing_key_headers = h {
routing_key_headers := h if {
input.enable_routing_key
h := [["x-pomerium-routing-key", crypto.sha256(input.session.id)]]
} else = []
} else := []
session_id_token = v {
session_id_token := v if {
v := session.id_token.raw
} else = ""
} else := ""
session_access_token = v {
session_access_token := v if {
v := session.oauth_token.access_token
} else = ""
} else := ""
client_cert_fingerprint = v {
client_cert_fingerprint := v if {
cert := crypto.x509.parse_certificates(trim_space(input.client_certificate.leaf))[0]
v := crypto.sha256(base64.decode(cert.Raw))
} else = ""
} else := ""
set_request_headers = h {
set_request_headers := h if {
replacements := {
"pomerium.id_token": session_id_token,
"pomerium.access_token": session_access_token,
@ -217,7 +219,7 @@ set_request_headers = h {
v := input.set_request_headers[header_name]
header_value := pomerium.variable_substitution(v, replacements)
]
} else = []
} else := []
identity_headers := {key: values |
h1 := [["x-pomerium-jwt-assertion", signed_jwt]]
@ -251,17 +253,17 @@ identity_headers := {key: values |
]
}
get_databroker_group_names(ids) = gs {
get_databroker_group_names(ids) := gs if {
gs := [name | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); name := group.name]
}
get_databroker_group_emails(ids) = gs {
get_databroker_group_emails(ids) := gs if {
gs := [email | id := ids[i]; group := get_databroker_record("pomerium.io/DirectoryGroup", id); email := group.email]
}
get_header_string_value(obj) = s {
get_header_string_value(obj) := s if {
is_array(obj)
s := concat(",", obj)
} else = s {
} else := s if {
s := concat(",", [obj])
}

View file

@ -53,104 +53,106 @@ func TestPolicy_ToPPL(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
default allow = [false, set()]
import rego.v1
default deny = [false, set()]
default allow := [false, set()]
accept_0 = [true, {"accept"}]
default deny := [false, set()]
cors_preflight_0 = [true, {"cors-request"}] {
accept_0 := [true, {"accept"}]
cors_preflight_0 := [true, {"cors-request"}] if {
input.http.method == "OPTIONS"
count(object.get(input.http.headers, "Access-Control-Request-Method", [])) > 0
count(object.get(input.http.headers, "Origin", [])) > 0
}
else = [false, {"non-cors-request"}]
else := [false, {"non-cors-request"}]
authenticated_user_0 = [true, {"user-ok"}] {
authenticated_user_0 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
session.user_id != null
session.user_id != ""
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
domain_0 = [true, {"domain-ok"}] {
domain_0 := [true, {"domain-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "a.example.com"
}
else = [false, {"domain-unauthorized"}] {
else := [false, {"domain-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
domain_1 = [true, {"domain-ok"}] {
domain_1 := [true, {"domain-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "b.example.com"
}
else = [false, {"domain-unauthorized"}] {
else := [false, {"domain-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
domain_2 = [true, {"domain-ok"}] {
domain_2 := [true, {"domain-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "c.example.com"
}
else = [false, {"domain-unauthorized"}] {
else := [false, {"domain-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
domain_3 = [true, {"domain-ok"}] {
domain_3 := [true, {"domain-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "d.example.com"
}
else = [false, {"domain-unauthorized"}] {
else := [false, {"domain-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
domain_4 = [true, {"domain-ok"}] {
domain_4 := [true, {"domain-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "e.example.com"
}
else = [false, {"domain-unauthorized"}] {
else := [false, {"domain-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
claim_0 = [true, {"claim-ok"}] {
claim_0 := [true, {"claim-ok"}] if {
rule_data := "Smith"
rule_path := "family_name"
session := get_session(input.session.id)
@ -162,14 +164,14 @@ claim_0 = [true, {"claim-ok"}] {
rule_data == values[_0]
}
else = [false, {"claim-unauthorized"}] {
else := [false, {"claim-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
claim_1 = [true, {"claim-ok"}] {
claim_1 := [true, {"claim-ok"}] if {
rule_data := "Jones"
rule_path := "family_name"
session := get_session(input.session.id)
@ -181,14 +183,14 @@ claim_1 = [true, {"claim-ok"}] {
rule_data == values[_0]
}
else = [false, {"claim-unauthorized"}] {
else := [false, {"claim-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
claim_2 = [true, {"claim-ok"}] {
claim_2 := [true, {"claim-ok"}] if {
rule_data := "John"
rule_path := "given_name"
session := get_session(input.session.id)
@ -200,14 +202,14 @@ claim_2 = [true, {"claim-ok"}] {
rule_data == values[_0]
}
else = [false, {"claim-unauthorized"}] {
else := [false, {"claim-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
claim_3 = [true, {"claim-ok"}] {
claim_3 := [true, {"claim-ok"}] if {
rule_data := "EST"
rule_path := "timezone"
session := get_session(input.session.id)
@ -219,204 +221,204 @@ claim_3 = [true, {"claim-ok"}] {
rule_data == values[_0]
}
else = [false, {"claim-unauthorized"}] {
else := [false, {"claim-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
user_0 = [true, {"user-ok"}] {
user_0 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user1"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
email_0 = [true, {"email-ok"}] {
email_0 := [true, {"email-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user1"
}
else = [false, {"email-unauthorized"}] {
else := [false, {"email-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
user_1 = [true, {"user-ok"}] {
user_1 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user2"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
email_1 = [true, {"email-ok"}] {
email_1 := [true, {"email-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user2"
}
else = [false, {"email-unauthorized"}] {
else := [false, {"email-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
user_2 = [true, {"user-ok"}] {
user_2 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user3"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
email_2 = [true, {"email-ok"}] {
email_2 := [true, {"email-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user3"
}
else = [false, {"email-unauthorized"}] {
else := [false, {"email-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
user_3 = [true, {"user-ok"}] {
user_3 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user4"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
email_3 = [true, {"email-ok"}] {
email_3 := [true, {"email-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user4"
}
else = [false, {"email-unauthorized"}] {
else := [false, {"email-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
user_4 = [true, {"user-ok"}] {
user_4 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user5"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
email_4 = [true, {"email-ok"}] {
email_4 := [true, {"email-ok"}] if {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user5"
}
else = [false, {"email-unauthorized"}] {
else := [false, {"email-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
or_0 = v {
or_0 := v if {
results := [accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
user_5 = [true, {"user-ok"}] {
user_5 := [true, {"user-ok"}] if {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user6"
}
else = [false, {"user-unauthorized"}] {
else := [false, {"user-unauthorized"}] if {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}]
else := [false, {"user-unauthenticated"}]
or_1 = v {
or_1 := v if {
results := [user_5]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
allow = v {
allow := v if {
results := [or_0, or_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
invert_criterion_result(v) := out if {
v[0]
out = array.concat([false], array.slice(v, 1, count(v)))
}
else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
else := out if {
not v[0]
out = array.concat([true], array.slice(v, 1, count(v)))
}
normalize_criterion_result(result) = v {
normalize_criterion_result(result) := v if {
is_boolean(result)
v = [result, set()]
}
else = v {
else := v if {
is_array(result)
v = result
}
else = v {
else := v if {
v = [false, set()]
}
object_union(xs) = merged {
object_union(xs) := merged if {
merged = {k: v |
some k
xs[_0][k]
@ -425,38 +427,38 @@ object_union(xs) = merged {
}
}
merge_with_and(results) = [true, reasons, additional_data] {
merge_with_and(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
merge_with_or(results) = [true, reasons, additional_data] {
merge_with_or(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
get_session(id) = v {
get_session(id) := v if {
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id)
v != null
}
else = iv {
else := iv if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") != ""
@ -465,41 +467,41 @@ else = iv {
iv != null
}
else = v {
else := v if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") == ""
}
else = {}
else := {}
get_user(session) = v {
get_user(session) := v if {
v = get_databroker_record("type.googleapis.com/user.User", session.user_id)
v != null
}
else = {}
else := {}
get_user_email(session, user) = v {
get_user_email(session, user) := v if {
v = user.email
}
else = ""
else := ""
object_get(obj, key, def) = value {
object_get(obj, key, def) := value if {
undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa"
value = object.get(obj, key, undefined)
value != undefined
}
else = value {
else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 2
o1 := object.get(obj, segments[0], {})
value = object.get(o1, segments[1], def)
}
else = value {
else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 3
o1 := object.get(obj, segments[0], {})
@ -507,7 +509,7 @@ else = value {
value = object.get(o2, segments[2], def)
}
else = value {
else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 4
o1 := object.get(obj, segments[0], {})
@ -516,7 +518,7 @@ else = value {
value = object.get(o3, segments[3], def)
}
else = value {
else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 5
o1 := object.get(obj, segments[0], {})
@ -526,7 +528,7 @@ else = value {
value = object.get(o4, segments[4], def)
}
else = value {
else := value if {
value = object.get(obj, key, def)
}
`, str)

View file

@ -66,9 +66,7 @@ func NewCriterionRule(
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Head: generator.NewHead("", NewCriterionTerm(false, failReason)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
@ -107,9 +105,7 @@ func NewCriterionDeviceRule(
// case 2: rule fails, session exists, device exists
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, failReason, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, failReason, additionalData)),
Body: append(sharedBody, ast.Body{
ast.MustParseExpr(`session.id != ""`),
ast.MustParseExpr(`device_credential.id != ""`),
@ -120,9 +116,7 @@ func NewCriterionDeviceRule(
// case 3: device not authenticated, session exists, device does not exist
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData)),
Body: append(sharedBody, ast.Body{
ast.MustParseExpr(`session.id != ""`),
}...),
@ -131,9 +125,7 @@ func NewCriterionDeviceRule(
// case 4: user not authenticated, session does not exist
r4 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
@ -157,9 +149,7 @@ func NewCriterionSessionRule(
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Head: generator.NewHead("", NewCriterionTerm(false, failReason)),
Body: ast.Body{
ast.MustParseExpr(`session := get_session(input.session.id)`),
ast.MustParseExpr(`session.id != ""`),
@ -168,9 +158,7 @@ func NewCriterionSessionRule(
r1.Else = r2
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, ReasonUserUnauthenticated),
},
Head: generator.NewHead("", NewCriterionTerm(false, ReasonUserUnauthenticated)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},

View file

@ -121,6 +121,7 @@ func evaluate(t *testing.T,
return nil, nil
}),
rego.Input(input),
rego.SetRegoVersion(ast.RegoV1),
)
preparedQuery, err := r.PrepareForEval(context.Background())
if err != nil {

View file

@ -35,17 +35,13 @@ func (c invalidClientCertificateCriterion) GenerateRule(_ string, _ parser.Value
r1.Body = validClientCertificateBody
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(true, ReasonClientCertificateRequired),
},
Head: generator.NewHead("", NewCriterionTerm(true, ReasonClientCertificateRequired)),
Body: noClientCertificateBody,
}
r1.Else = r2
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(true, ReasonInvalidClientCertificate),
},
Head: generator.NewHead("", NewCriterionTerm(true, ReasonInvalidClientCertificate)),
}
r2.Else = r3

View file

@ -49,8 +49,8 @@ func (g *Generator) GetCriterion(name string) (Criterion, bool) {
// Generate generates the rego module from a policy.
func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
rs := ast.NewRuleSet()
rs.Add(ast.MustParseRule(`default allow = [false, set()]`))
rs.Add(ast.MustParseRule(`default deny = [false, set()]`))
rs.Add(rules.MustParse(`default allow := [false, set()]`))
rs.Add(rules.MustParse(`default deny := [false, set()]`))
rs.Add(rules.InvertCriterionResult())
rs.Add(rules.NormalizeCriterionResult())
rs.Add(rules.ObjectUnion())
@ -95,10 +95,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
}
if len(terms) > 0 {
rule := &ast.Rule{
Head: &ast.Head{
Name: ast.Var(action),
Value: ast.VarTerm("v"),
},
Head: NewHead(ast.Var(action), ast.VarTerm("v")),
Body: append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, orBody...),
@ -115,6 +112,9 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
ast.StringTerm("policy"),
},
},
Imports: []*ast.Import{{
Path: ast.RefTerm(ast.VarTerm("rego"), ast.StringTerm("v1")),
}},
Rules: rs,
}
@ -148,8 +148,15 @@ func (g *Generator) NewRule(name string) *ast.Rule {
id := g.ids[name]
g.ids[name]++
return &ast.Rule{
Head: &ast.Head{
Name: ast.Var(fmt.Sprintf("%s_%d", name, id)),
},
Head: NewHead(ast.Var(fmt.Sprintf("%s_%d", name, id)), nil),
}
}
// NewHead creates a new AST Head.
func NewHead(name ast.Var, value *ast.Term) *ast.Head {
return &ast.Head{
Name: name,
Value: value,
Assign: true,
}
}

View file

@ -63,146 +63,118 @@ func Test(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
default allow = [false, set()]
import rego.v1
default deny = [false, set()]
default allow := [false, set()]
accept_0 {
1 == 1
}
default deny := [false, set()]
accept_1 {
1 == 1
}
accept_0 if 1 == 1
accept_2 {
1 == 1
}
accept_1 if 1 == 1
and_0 = v {
accept_2 if 1 == 1
and_0 := v if {
results := [accept_0, accept_1, accept_2]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
accept_3 {
1 == 1
}
accept_3 if 1 == 1
accept_4 {
1 == 1
}
accept_4 if 1 == 1
accept_5 {
1 == 1
}
accept_5 if 1 == 1
or_0 = v {
or_0 := v if {
results := [accept_3, accept_4, accept_5]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_6 {
1 == 1
}
accept_6 if 1 == 1
accept_7 {
1 == 1
}
accept_7 if 1 == 1
accept_8 {
1 == 1
}
accept_8 if 1 == 1
not_0 = v {
not_0 := v if {
results := [accept_6, accept_7, accept_8]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_and(inverted)
}
accept_9 {
1 == 1
}
accept_9 if 1 == 1
accept_10 {
1 == 1
}
accept_10 if 1 == 1
accept_11 {
1 == 1
}
accept_11 if 1 == 1
nor_0 = v {
nor_0 := v if {
results := [accept_9, accept_10, accept_11]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
accept_12 {
1 == 1
}
accept_12 if 1 == 1
and_1 = v {
and_1 := v if {
results := [accept_12]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
allow = v {
allow := v if {
results := [and_0, or_0, not_0, nor_0, and_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_13 {
1 == 1
}
accept_13 if 1 == 1
accept_14 {
1 == 1
}
accept_14 if 1 == 1
nor_1 = v {
nor_1 := v if {
results := [accept_13, accept_14]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
deny = v {
deny := v if {
results := [nor_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
invert_criterion_result(v) := out if {
v[0]
out = array.concat([false], array.slice(v, 1, count(v)))
}
else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
else := out if {
not v[0]
out = array.concat([true], array.slice(v, 1, count(v)))
}
normalize_criterion_result(result) = v {
normalize_criterion_result(result) := v if {
is_boolean(result)
v = [result, set()]
}
else = v {
else := v if {
is_array(result)
v = result
}
else = v {
else := v if {
v = [false, set()]
}
object_union(xs) = merged {
object_union(xs) := merged if {
merged = {k: v |
some k
xs[_][k]
@ -211,27 +183,27 @@ object_union(xs) = merged {
}
}
merge_with_and(results) = [true, reasons, additional_data] {
merge_with_and(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
merge_with_or(results) = [true, reasons, additional_data] {
merge_with_or(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})

View file

@ -5,84 +5,74 @@ import "github.com/open-policy-agent/opa/ast"
// GetSession gets the session for the given id.
func GetSession() *ast.Rule {
return ast.MustParseRule(`
get_session(id) = v {
return MustParse(`
get_session(id) := v if {
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id)
v != null
} else = iv {
} else := iv if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") != ""
iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id)
iv != null
} else = v {
} else := v if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") == ""
} else = {} {
true
}
} else := {}
`)
}
// GetUser returns the user for the given session.
func GetUser() *ast.Rule {
return ast.MustParseRule(`
get_user(session) = v {
return MustParse(`
get_user(session) := v if {
v = get_databroker_record("type.googleapis.com/user.User", session.user_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// GetUserEmail gets the user email, either the impersonate email, or the user email.
func GetUserEmail() *ast.Rule {
return ast.MustParseRule(`
get_user_email(session, user) = v {
return MustParse(`
get_user_email(session, user) := v if {
v = user.email
} else = "" {
true
}
} else := ""
`)
}
// GetDeviceCredential gets the device credential for the given session.
func GetDeviceCredential() *ast.Rule {
return ast.MustParseRule(`
get_device_credential(session, device_type_id) = v {
return MustParse(`
get_device_credential(session, device_type_id) := v if {
device_credential_id := [x.Credential.Id|x:=session.device_credentials[_];x.type_id==device_type_id][0]
v = get_databroker_record("type.googleapis.com/pomerium.device.Credential", device_credential_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// GetDeviceEnrollment gets the device enrollment for the given device credential.
func GetDeviceEnrollment() *ast.Rule {
return ast.MustParseRule(`
get_device_enrollment(device_credential) = v {
return MustParse(`
get_device_enrollment(device_credential) := v if {
v = get_databroker_record("type.googleapis.com/pomerium.device.Enrollment", device_credential.enrollment_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// MergeWithAnd merges criterion results using `and`.
func MergeWithAnd() *ast.Rule {
return ast.MustParseRule(`
merge_with_and(results) = [true, reasons, additional_data] {
return MustParse(`
merge_with_and(results) := [true, reasons, additional_data] if {
true_results := [x|x:=results[i];x[0]]
count(true_results) == count(results)
reasons := union({x|x:=true_results[i][1]})
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
} else := [false, reasons, additional_data] if {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
@ -92,13 +82,13 @@ merge_with_and(results) = [true, reasons, additional_data] {
// MergeWithOr merges criterion results using `or`.
func MergeWithOr() *ast.Rule {
return ast.MustParseRule(`
merge_with_or(results) = [true, reasons, additional_data] {
return MustParse(`
merge_with_or(results) := [true, reasons, additional_data] if {
true_results := [x|x:=results[i];x[0]]
count(true_results) > 0
reasons := union({x|x:=true_results[i][1]})
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
} else := [false, reasons, additional_data] if {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
@ -109,27 +99,27 @@ merge_with_or(results) = [true, reasons, additional_data] {
// InvertCriterionResult changes the criterion result's value from false to
// true, or vice-versa.
func InvertCriterionResult() *ast.Rule {
return ast.MustParseRule(`
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
} else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
return MustParse(`
invert_criterion_result(v) := out if {
v[0]
out = array.concat([false], array.slice(v, 1, count(v)))
} else := out if {
not v[0]
out = array.concat([true], array.slice(v, 1, count(v)))
}
`)
}
// NormalizeCriterionResult converts a criterion result into a standard form.
func NormalizeCriterionResult() *ast.Rule {
return ast.MustParseRule(`
normalize_criterion_result(result) = v {
return MustParse(`
normalize_criterion_result(result) := v if {
is_boolean(result)
v = [result, set()]
} else = v {
} else := v if {
is_array(result)
v = result
} else = v {
} else := v if {
v = [false, set()]
}
`)
@ -137,33 +127,33 @@ normalize_criterion_result(result) = v {
// ObjectGet recursively gets a value from an object.
func ObjectGet() *ast.Rule {
return ast.MustParseRule(`
return MustParse(`
# object_get is like object.get, but supports converting "/" in keys to separate lookups
# rego doesn't support recursion, so we hard code a limited number of /'s
object_get(obj, key, def) = value {
object_get(obj, key, def) := value if {
undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa"
value = object.get(obj, key, undefined)
value != undefined
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 2
o1 := object.get(obj, segments[0], {})
value = object.get(o1, segments[1], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 3
o1 := object.get(obj, segments[0], {})
o2 := object.get(o1, segments[1], {})
value = object.get(o2, segments[2], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 4
o1 := object.get(obj, segments[0], {})
o2 := object.get(o1, segments[1], {})
o3 := object.get(o2, segments[2], {})
value = object.get(o3, segments[3], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 5
o1 := object.get(obj, segments[0], {})
@ -171,7 +161,7 @@ object_get(obj, key, def) = value {
o3 := object.get(o2, segments[2], {})
o4 := object.get(o3, segments[3], {})
value = object.get(o4, segments[4], def)
} else = value {
} else := value if {
value = object.get(obj, key, def)
}
`)
@ -179,8 +169,8 @@ object_get(obj, key, def) = value {
// ObjectUnion merges objects together. It expects a set of objects.
func ObjectUnion() *ast.Rule {
return ast.MustParseRule(`
object_union(xs) = merged {
return MustParse(`
object_union(xs) := merged if {
merged = { k: v |
some k
xs[_][k]
@ -190,3 +180,14 @@ object_union(xs) = merged {
}
`)
}
// MustParse parses an AST rule.
func MustParse(str string) *ast.Rule {
r, err := ast.ParseRuleWithOpts(str, ast.ParserOptions{
RegoVersion: ast.RegoV1,
})
if err != nil {
panic(err)
}
return r
}