core/opa: update for rego 1.0 (#4895)

* core/opa: update headers rego script

* core/opa: update ppl

* further updates
This commit is contained in:
Caleb Doxsey 2024-01-16 09:43:35 -07:00 committed by GitHub
parent 5e0079c649
commit 24b04bed35
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 289 additions and 319 deletions

View file

@ -66,9 +66,7 @@ func NewCriterionRule(
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Head: generator.NewHead("", NewCriterionTerm(false, failReason)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
@ -107,9 +105,7 @@ func NewCriterionDeviceRule(
// case 2: rule fails, session exists, device exists
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, failReason, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, failReason, additionalData)),
Body: append(sharedBody, ast.Body{
ast.MustParseExpr(`session.id != ""`),
ast.MustParseExpr(`device_credential.id != ""`),
@ -120,9 +116,7 @@ func NewCriterionDeviceRule(
// case 3: device not authenticated, session exists, device does not exist
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonDeviceUnauthenticated, additionalData)),
Body: append(sharedBody, ast.Body{
ast.MustParseExpr(`session.id != ""`),
}...),
@ -131,9 +125,7 @@ func NewCriterionDeviceRule(
// case 4: user not authenticated, session does not exist
r4 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData),
},
Head: generator.NewHead("", NewCriterionTermWithAdditionalData(false, ReasonUserUnauthenticated, additionalData)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
@ -157,9 +149,7 @@ func NewCriterionSessionRule(
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Head: generator.NewHead("", NewCriterionTerm(false, failReason)),
Body: ast.Body{
ast.MustParseExpr(`session := get_session(input.session.id)`),
ast.MustParseExpr(`session.id != ""`),
@ -168,9 +158,7 @@ func NewCriterionSessionRule(
r1.Else = r2
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, ReasonUserUnauthenticated),
},
Head: generator.NewHead("", NewCriterionTerm(false, ReasonUserUnauthenticated)),
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},

View file

@ -121,6 +121,7 @@ func evaluate(t *testing.T,
return nil, nil
}),
rego.Input(input),
rego.SetRegoVersion(ast.RegoV1),
)
preparedQuery, err := r.PrepareForEval(context.Background())
if err != nil {

View file

@ -35,17 +35,13 @@ func (c invalidClientCertificateCriterion) GenerateRule(_ string, _ parser.Value
r1.Body = validClientCertificateBody
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(true, ReasonClientCertificateRequired),
},
Head: generator.NewHead("", NewCriterionTerm(true, ReasonClientCertificateRequired)),
Body: noClientCertificateBody,
}
r1.Else = r2
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(true, ReasonInvalidClientCertificate),
},
Head: generator.NewHead("", NewCriterionTerm(true, ReasonInvalidClientCertificate)),
}
r2.Else = r3

View file

@ -49,8 +49,8 @@ func (g *Generator) GetCriterion(name string) (Criterion, bool) {
// Generate generates the rego module from a policy.
func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
rs := ast.NewRuleSet()
rs.Add(ast.MustParseRule(`default allow = [false, set()]`))
rs.Add(ast.MustParseRule(`default deny = [false, set()]`))
rs.Add(rules.MustParse(`default allow := [false, set()]`))
rs.Add(rules.MustParse(`default deny := [false, set()]`))
rs.Add(rules.InvertCriterionResult())
rs.Add(rules.NormalizeCriterionResult())
rs.Add(rules.ObjectUnion())
@ -95,10 +95,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
}
if len(terms) > 0 {
rule := &ast.Rule{
Head: &ast.Head{
Name: ast.Var(action),
Value: ast.VarTerm("v"),
},
Head: NewHead(ast.Var(action), ast.VarTerm("v")),
Body: append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, orBody...),
@ -115,6 +112,9 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
ast.StringTerm("policy"),
},
},
Imports: []*ast.Import{{
Path: ast.RefTerm(ast.VarTerm("rego"), ast.StringTerm("v1")),
}},
Rules: rs,
}
@ -148,8 +148,15 @@ func (g *Generator) NewRule(name string) *ast.Rule {
id := g.ids[name]
g.ids[name]++
return &ast.Rule{
Head: &ast.Head{
Name: ast.Var(fmt.Sprintf("%s_%d", name, id)),
},
Head: NewHead(ast.Var(fmt.Sprintf("%s_%d", name, id)), nil),
}
}
// NewHead creates a new AST Head.
func NewHead(name ast.Var, value *ast.Term) *ast.Head {
return &ast.Head{
Name: name,
Value: value,
Assign: true,
}
}

View file

@ -63,146 +63,118 @@ func Test(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
default allow = [false, set()]
import rego.v1
default deny = [false, set()]
default allow := [false, set()]
accept_0 {
1 == 1
}
default deny := [false, set()]
accept_1 {
1 == 1
}
accept_0 if 1 == 1
accept_2 {
1 == 1
}
accept_1 if 1 == 1
and_0 = v {
accept_2 if 1 == 1
and_0 := v if {
results := [accept_0, accept_1, accept_2]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
accept_3 {
1 == 1
}
accept_3 if 1 == 1
accept_4 {
1 == 1
}
accept_4 if 1 == 1
accept_5 {
1 == 1
}
accept_5 if 1 == 1
or_0 = v {
or_0 := v if {
results := [accept_3, accept_4, accept_5]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_6 {
1 == 1
}
accept_6 if 1 == 1
accept_7 {
1 == 1
}
accept_7 if 1 == 1
accept_8 {
1 == 1
}
accept_8 if 1 == 1
not_0 = v {
not_0 := v if {
results := [accept_6, accept_7, accept_8]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_and(inverted)
}
accept_9 {
1 == 1
}
accept_9 if 1 == 1
accept_10 {
1 == 1
}
accept_10 if 1 == 1
accept_11 {
1 == 1
}
accept_11 if 1 == 1
nor_0 = v {
nor_0 := v if {
results := [accept_9, accept_10, accept_11]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
accept_12 {
1 == 1
}
accept_12 if 1 == 1
and_1 = v {
and_1 := v if {
results := [accept_12]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
allow = v {
allow := v if {
results := [and_0, or_0, not_0, nor_0, and_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_13 {
1 == 1
}
accept_13 if 1 == 1
accept_14 {
1 == 1
}
accept_14 if 1 == 1
nor_1 = v {
nor_1 := v if {
results := [accept_13, accept_14]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
deny = v {
deny := v if {
results := [nor_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
invert_criterion_result(v) := out if {
v[0]
out = array.concat([false], array.slice(v, 1, count(v)))
}
else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
else := out if {
not v[0]
out = array.concat([true], array.slice(v, 1, count(v)))
}
normalize_criterion_result(result) = v {
normalize_criterion_result(result) := v if {
is_boolean(result)
v = [result, set()]
}
else = v {
else := v if {
is_array(result)
v = result
}
else = v {
else := v if {
v = [false, set()]
}
object_union(xs) = merged {
object_union(xs) := merged if {
merged = {k: v |
some k
xs[_][k]
@ -211,27 +183,27 @@ object_union(xs) = merged {
}
}
merge_with_and(results) = [true, reasons, additional_data] {
merge_with_and(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
merge_with_or(results) = [true, reasons, additional_data] {
merge_with_or(results) := [true, reasons, additional_data] if {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons, additional_data] {
else := [false, reasons, additional_data] if {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})

View file

@ -5,84 +5,74 @@ import "github.com/open-policy-agent/opa/ast"
// GetSession gets the session for the given id.
func GetSession() *ast.Rule {
return ast.MustParseRule(`
get_session(id) = v {
return MustParse(`
get_session(id) := v if {
v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id)
v != null
} else = iv {
} else := iv if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") != ""
iv = get_databroker_record("type.googleapis.com/session.Session", v.impersonate_session_id)
iv != null
} else = v {
} else := v if {
v = get_databroker_record("type.googleapis.com/session.Session", id)
v != null
object.get(v, "impersonate_session_id", "") == ""
} else = {} {
true
}
} else := {}
`)
}
// GetUser returns the user for the given session.
func GetUser() *ast.Rule {
return ast.MustParseRule(`
get_user(session) = v {
return MustParse(`
get_user(session) := v if {
v = get_databroker_record("type.googleapis.com/user.User", session.user_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// GetUserEmail gets the user email, either the impersonate email, or the user email.
func GetUserEmail() *ast.Rule {
return ast.MustParseRule(`
get_user_email(session, user) = v {
return MustParse(`
get_user_email(session, user) := v if {
v = user.email
} else = "" {
true
}
} else := ""
`)
}
// GetDeviceCredential gets the device credential for the given session.
func GetDeviceCredential() *ast.Rule {
return ast.MustParseRule(`
get_device_credential(session, device_type_id) = v {
return MustParse(`
get_device_credential(session, device_type_id) := v if {
device_credential_id := [x.Credential.Id|x:=session.device_credentials[_];x.type_id==device_type_id][0]
v = get_databroker_record("type.googleapis.com/pomerium.device.Credential", device_credential_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// GetDeviceEnrollment gets the device enrollment for the given device credential.
func GetDeviceEnrollment() *ast.Rule {
return ast.MustParseRule(`
get_device_enrollment(device_credential) = v {
return MustParse(`
get_device_enrollment(device_credential) := v if {
v = get_databroker_record("type.googleapis.com/pomerium.device.Enrollment", device_credential.enrollment_id)
v != null
} else = {} {
true
}
} else := {}
`)
}
// MergeWithAnd merges criterion results using `and`.
func MergeWithAnd() *ast.Rule {
return ast.MustParseRule(`
merge_with_and(results) = [true, reasons, additional_data] {
return MustParse(`
merge_with_and(results) := [true, reasons, additional_data] if {
true_results := [x|x:=results[i];x[0]]
count(true_results) == count(results)
reasons := union({x|x:=true_results[i][1]})
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
} else := [false, reasons, additional_data] if {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
@ -92,13 +82,13 @@ merge_with_and(results) = [true, reasons, additional_data] {
// MergeWithOr merges criterion results using `or`.
func MergeWithOr() *ast.Rule {
return ast.MustParseRule(`
merge_with_or(results) = [true, reasons, additional_data] {
return MustParse(`
merge_with_or(results) := [true, reasons, additional_data] if {
true_results := [x|x:=results[i];x[0]]
count(true_results) > 0
reasons := union({x|x:=true_results[i][1]})
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
} else := [false, reasons, additional_data] if {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
@ -109,27 +99,27 @@ merge_with_or(results) = [true, reasons, additional_data] {
// InvertCriterionResult changes the criterion result's value from false to
// true, or vice-versa.
func InvertCriterionResult() *ast.Rule {
return ast.MustParseRule(`
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
} else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
return MustParse(`
invert_criterion_result(v) := out if {
v[0]
out = array.concat([false], array.slice(v, 1, count(v)))
} else := out if {
not v[0]
out = array.concat([true], array.slice(v, 1, count(v)))
}
`)
}
// NormalizeCriterionResult converts a criterion result into a standard form.
func NormalizeCriterionResult() *ast.Rule {
return ast.MustParseRule(`
normalize_criterion_result(result) = v {
return MustParse(`
normalize_criterion_result(result) := v if {
is_boolean(result)
v = [result, set()]
} else = v {
} else := v if {
is_array(result)
v = result
} else = v {
} else := v if {
v = [false, set()]
}
`)
@ -137,33 +127,33 @@ normalize_criterion_result(result) = v {
// ObjectGet recursively gets a value from an object.
func ObjectGet() *ast.Rule {
return ast.MustParseRule(`
return MustParse(`
# object_get is like object.get, but supports converting "/" in keys to separate lookups
# rego doesn't support recursion, so we hard code a limited number of /'s
object_get(obj, key, def) = value {
object_get(obj, key, def) := value if {
undefined := "10a0fd35-0f1a-4e5b-97ce-631e89e1bafa"
value = object.get(obj, key, undefined)
value != undefined
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 2
o1 := object.get(obj, segments[0], {})
value = object.get(o1, segments[1], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 3
o1 := object.get(obj, segments[0], {})
o2 := object.get(o1, segments[1], {})
value = object.get(o2, segments[2], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 4
o1 := object.get(obj, segments[0], {})
o2 := object.get(o1, segments[1], {})
o3 := object.get(o2, segments[2], {})
value = object.get(o3, segments[3], def)
} else = value {
} else := value if {
segments := split(replace(key, ".", "/"), "/")
count(segments) == 5
o1 := object.get(obj, segments[0], {})
@ -171,7 +161,7 @@ object_get(obj, key, def) = value {
o3 := object.get(o2, segments[2], {})
o4 := object.get(o3, segments[3], {})
value = object.get(o4, segments[4], def)
} else = value {
} else := value if {
value = object.get(obj, key, def)
}
`)
@ -179,8 +169,8 @@ object_get(obj, key, def) = value {
// ObjectUnion merges objects together. It expects a set of objects.
func ObjectUnion() *ast.Rule {
return ast.MustParseRule(`
object_union(xs) = merged {
return MustParse(`
object_union(xs) := merged if {
merged = { k: v |
some k
xs[_][k]
@ -190,3 +180,14 @@ object_union(xs) = merged {
}
`)
}
// MustParse parses an AST rule.
func MustParse(str string) *ast.Rule {
r, err := ast.ParseRuleWithOpts(str, ast.ParserOptions{
RegoVersion: ast.RegoV1,
})
if err != nil {
panic(err)
}
return r
}