diff --git a/pkg/policy/generator/conditionals.go b/pkg/policy/generator/conditionals.go index ddf63eea3..0d9d11f8f 100644 --- a/pkg/policy/generator/conditionals.go +++ b/pkg/policy/generator/conditionals.go @@ -20,7 +20,7 @@ func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Cr return nil, err } - g.fillViaAnd(rule, false, expressions) + g.fillViaAnd(rule, expressions) dst.Add(rule) return rule, nil @@ -40,7 +40,7 @@ func (g *Generator) generateNotRule(dst *ast.RuleSet, policyCriteria []parser.Cr return nil, err } - g.fillViaAnd(rule, true, terms) + g.fillViaSetComprehension(rule, terms, true, true) dst.Add(rule) return rule, nil @@ -58,7 +58,7 @@ func (g *Generator) generateOrRule(dst *ast.RuleSet, policyCriteria []parser.Cri return nil, err } - g.fillViaOr(rule, false, terms) + g.fillViaOr(rule, terms) dst.Add(rule) return rule, nil @@ -78,7 +78,7 @@ func (g *Generator) generateNorRule(dst *ast.RuleSet, policyCriteria []parser.Cr return nil, err } - g.fillViaOr(rule, true, terms) + g.fillViaSetComprehension(rule, terms, false, true) dst.Add(rule) return rule, nil @@ -103,21 +103,18 @@ func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []pa return terms, nil } -func (g *Generator) fillViaAnd(rule *ast.Rule, negated bool, terms []*ast.Term) { +func (g *Generator) fillViaAnd(rule *ast.Rule, terms []*ast.Term) { currentRule := rule currentRule.Head.Value = ast.VarTerm("v1") for i, term := range terms { nm := fmt.Sprintf("v%d", i+1) currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term)) expr := ast.NewExpr(ast.VarTerm(nm)) - if negated { - expr.Negated = true - } currentRule.Body = append(currentRule.Body, expr) } } -func (g *Generator) fillViaOr(rule *ast.Rule, negated bool, terms []*ast.Term) { +func (g *Generator) fillViaOr(rule *ast.Rule, terms []*ast.Term) { currentRule := rule for i, term := range terms { if i > 0 { @@ -129,9 +126,49 @@ func (g *Generator) fillViaOr(rule *ast.Rule, negated bool, terms []*ast.Term) { currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term)) expr := ast.NewExpr(ast.VarTerm(nm)) - if negated { - expr.Negated = true - } currentRule.Body = append(currentRule.Body, expr) } } + +func (g *Generator) fillViaSetComprehension(rule *ast.Rule, terms []*ast.Term, useIntersection, negated bool) { + sets := make([]*ast.Term, len(terms)) + for i, term := range terms { + e := ast.NewExpr(term) + e.Negated = negated + sets[i] = ast.SetComprehensionTerm(ast.NumberTerm("1"), ast.NewBody(e)) + } + + var builtIn *ast.Builtin + if useIntersection { + builtIn = ast.And + } else { + builtIn = ast.Or + } + rule.Head.Value = ast.VarTerm("v") + rule.Body = ast.NewBody( + ast.Assign.Expr( + ast.VarTerm("v"), + ast.Equal.Call( + ast.Count.Call( + mergeTerms(builtIn, sets...), + ), + ast.NumberTerm("1"), + ), + ), + ) +} + +func mergeTerms(builtIn *ast.Builtin, terms ...*ast.Term) *ast.Term { + // mergeTerms(AND, A, B, C, D) => AND(AND(A, B), AND(C, D)) + switch len(terms) { + case 0: + return ast.NullTerm() + case 1: + return terms[0] + default: + return builtIn.Call( + mergeTerms(builtIn, terms[:len(terms)/2]...), + mergeTerms(builtIn, terms[len(terms)/2:]...), + ) + } +} diff --git a/pkg/policy/generator/generator.go b/pkg/policy/generator/generator.go index 240f80ff7..b0ea5aee8 100644 --- a/pkg/policy/generator/generator.go +++ b/pkg/policy/generator/generator.go @@ -96,7 +96,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { Value: ast.VarTerm("v1"), }, } - g.fillViaOr(rule, false, terms) + g.fillViaOr(rule, terms) rules.Add(rule) } } diff --git a/pkg/policy/generator/generator_test.go b/pkg/policy/generator/generator_test.go index 7a499476a..0e2931540 100644 --- a/pkg/policy/generator/generator_test.go +++ b/pkg/policy/generator/generator_test.go @@ -119,13 +119,8 @@ accept_8 { 1 == 1 } -not_0 = v1 { - v1 := accept_6 - not v1 - v2 := accept_7 - not v2 - v3 := accept_8 - not v3 +not_0 = v { + v := count({1 | not accept_6} & ({1 | not accept_7} & {1 | not accept_8})) == 1 } accept_9 { @@ -140,19 +135,8 @@ accept_11 { 1 == 1 } -nor_0 = v1 { - v1 := accept_9 - not v1 -} - -else = v2 { - v2 := accept_10 - not v2 -} - -else = v3 { - v3 := accept_11 - not v3 +nor_0 = v { + v := count({1 | not accept_9} | ({1 | not accept_10} | {1 | not accept_11})) == 1 } accept_12 {