ppl: fix not/nor rules (#2313)

* ppl: fix not/nor rules

* use set comprehension with count
This commit is contained in:
Caleb Doxsey 2021-06-25 05:41:24 -06:00 committed by GitHub
parent 41a2622736
commit 11a619390a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 54 additions and 33 deletions

View file

@ -20,7 +20,7 @@ func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return nil, err
}
g.fillViaAnd(rule, false, expressions)
g.fillViaAnd(rule, expressions)
dst.Add(rule)
return rule, nil
@ -40,7 +40,7 @@ func (g *Generator) generateNotRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return nil, err
}
g.fillViaAnd(rule, true, terms)
g.fillViaSetComprehension(rule, terms, true, true)
dst.Add(rule)
return rule, nil
@ -58,7 +58,7 @@ func (g *Generator) generateOrRule(dst *ast.RuleSet, policyCriteria []parser.Cri
return nil, err
}
g.fillViaOr(rule, false, terms)
g.fillViaOr(rule, terms)
dst.Add(rule)
return rule, nil
@ -78,7 +78,7 @@ func (g *Generator) generateNorRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return nil, err
}
g.fillViaOr(rule, true, terms)
g.fillViaSetComprehension(rule, terms, false, true)
dst.Add(rule)
return rule, nil
@ -103,21 +103,18 @@ func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []pa
return terms, nil
}
func (g *Generator) fillViaAnd(rule *ast.Rule, negated bool, terms []*ast.Term) {
func (g *Generator) fillViaAnd(rule *ast.Rule, terms []*ast.Term) {
currentRule := rule
currentRule.Head.Value = ast.VarTerm("v1")
for i, term := range terms {
nm := fmt.Sprintf("v%d", i+1)
currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term))
expr := ast.NewExpr(ast.VarTerm(nm))
if negated {
expr.Negated = true
}
currentRule.Body = append(currentRule.Body, expr)
}
}
func (g *Generator) fillViaOr(rule *ast.Rule, negated bool, terms []*ast.Term) {
func (g *Generator) fillViaOr(rule *ast.Rule, terms []*ast.Term) {
currentRule := rule
for i, term := range terms {
if i > 0 {
@ -129,9 +126,49 @@ func (g *Generator) fillViaOr(rule *ast.Rule, negated bool, terms []*ast.Term) {
currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term))
expr := ast.NewExpr(ast.VarTerm(nm))
if negated {
expr.Negated = true
}
currentRule.Body = append(currentRule.Body, expr)
}
}
func (g *Generator) fillViaSetComprehension(rule *ast.Rule, terms []*ast.Term, useIntersection, negated bool) {
sets := make([]*ast.Term, len(terms))
for i, term := range terms {
e := ast.NewExpr(term)
e.Negated = negated
sets[i] = ast.SetComprehensionTerm(ast.NumberTerm("1"), ast.NewBody(e))
}
var builtIn *ast.Builtin
if useIntersection {
builtIn = ast.And
} else {
builtIn = ast.Or
}
rule.Head.Value = ast.VarTerm("v")
rule.Body = ast.NewBody(
ast.Assign.Expr(
ast.VarTerm("v"),
ast.Equal.Call(
ast.Count.Call(
mergeTerms(builtIn, sets...),
),
ast.NumberTerm("1"),
),
),
)
}
func mergeTerms(builtIn *ast.Builtin, terms ...*ast.Term) *ast.Term {
// mergeTerms(AND, A, B, C, D) => AND(AND(A, B), AND(C, D))
switch len(terms) {
case 0:
return ast.NullTerm()
case 1:
return terms[0]
default:
return builtIn.Call(
mergeTerms(builtIn, terms[:len(terms)/2]...),
mergeTerms(builtIn, terms[len(terms)/2:]...),
)
}
}