diff --git a/pkg/policy/criteria/domains.go b/pkg/policy/criteria/domains.go index 915529372..25d0589d9 100644 --- a/pkg/policy/criteria/domains.go +++ b/pkg/policy/criteria/domains.go @@ -36,7 +36,7 @@ func (c domainsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, switch data.(type) { case parser.String: - r.Body = append(r.Body, ast.MustParseExpr(`domain = rule_data`)) + r.Body = append(r.Body, ast.MustParseExpr(`domain == rule_data`)) default: return nil, nil, fmt.Errorf("unsupported value type: %T", data) } diff --git a/pkg/policy/criteria/emails.go b/pkg/policy/criteria/emails.go index a345c6c5a..d35b840c8 100644 --- a/pkg/policy/criteria/emails.go +++ b/pkg/policy/criteria/emails.go @@ -36,7 +36,7 @@ func (c emailsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, [ switch data.(type) { case parser.String: - r.Body = append(r.Body, ast.MustParseExpr(`email = rule_data`)) + r.Body = append(r.Body, ast.MustParseExpr(`email == rule_data`)) default: return nil, nil, fmt.Errorf("unsupported value type: %T", data) } diff --git a/pkg/policy/criteria/groups.go b/pkg/policy/criteria/groups.go index 821d9047d..ede9e36d7 100644 --- a/pkg/policy/criteria/groups.go +++ b/pkg/policy/criteria/groups.go @@ -61,7 +61,7 @@ func (c groupsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, [ switch data.(type) { case parser.String: - r.Body = append(r.Body, ast.MustParseExpr(`group = rule_data`)) + r.Body = append(r.Body, ast.MustParseExpr(`group == rule_data`)) default: return nil, nil, fmt.Errorf("unsupported value type: %T", data) } diff --git a/pkg/policy/criteria/users.go b/pkg/policy/criteria/users.go index 495f649ff..5217e516c 100644 --- a/pkg/policy/criteria/users.go +++ b/pkg/policy/criteria/users.go @@ -33,7 +33,7 @@ func (c usersCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, [] switch data.(type) { case parser.String: - r.Body = append(r.Body, ast.MustParseExpr(`user_id = rule_data`)) + r.Body = append(r.Body, ast.MustParseExpr(`user.id == rule_data`)) default: return nil, nil, fmt.Errorf("unsupported value type: %T", data) } diff --git a/pkg/policy/generator/conditionals.go b/pkg/policy/generator/conditionals.go index fdfb65a46..ddf63eea3 100644 --- a/pkg/policy/generator/conditionals.go +++ b/pkg/policy/generator/conditionals.go @@ -8,8 +8,6 @@ import ( "github.com/pomerium/pomerium/pkg/policy/parser" ) -type conditionalGenerator func(dst *ast.RuleSet, policyCriteria []parser.Criterion) (*ast.Rule, error) - func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Criterion) (*ast.Rule, error) { rule := g.NewRule("and") @@ -22,7 +20,7 @@ func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Cr return nil, err } - g.fillViaAnd(rule, expressions) + g.fillViaAnd(rule, false, expressions) dst.Add(rule) return rule, nil @@ -37,15 +35,12 @@ func (g *Generator) generateNotRule(dst *ast.RuleSet, policyCriteria []parser.Cr // NOT => (NOT A) AND (NOT B) - expressions, err := g.generateCriterionRules(dst, policyCriteria) + terms, err := g.generateCriterionRules(dst, policyCriteria) if err != nil { return nil, err } - for _, expr := range expressions { - expr.Negated = true - } - g.fillViaAnd(rule, expressions) + g.fillViaAnd(rule, true, terms) dst.Add(rule) return rule, nil @@ -58,12 +53,12 @@ func (g *Generator) generateOrRule(dst *ast.RuleSet, policyCriteria []parser.Cri return rule, nil } - expressions, err := g.generateCriterionRules(dst, policyCriteria) + terms, err := g.generateCriterionRules(dst, policyCriteria) if err != nil { return nil, err } - g.fillViaOr(rule, expressions) + g.fillViaOr(rule, false, terms) dst.Add(rule) return rule, nil @@ -78,22 +73,19 @@ func (g *Generator) generateNorRule(dst *ast.RuleSet, policyCriteria []parser.Cr // NOR => (NOT A) OR (NOT B) - expressions, err := g.generateCriterionRules(dst, policyCriteria) + terms, err := g.generateCriterionRules(dst, policyCriteria) if err != nil { return nil, err } - for _, expr := range expressions { - expr.Negated = true - } - g.fillViaOr(rule, expressions) + g.fillViaOr(rule, true, terms) dst.Add(rule) return rule, nil } -func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []parser.Criterion) ([]*ast.Expr, error) { - var expressions []*ast.Expr +func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []parser.Criterion) ([]*ast.Term, error) { + var terms []*ast.Term for _, policyCriterion := range policyCriteria { criterion, ok := g.criteria[policyCriterion.Name] if !ok { @@ -106,27 +98,40 @@ func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []pa *dst = dst.Merge(additionalRules) dst.Add(mainRule) - expr := ast.NewExpr(ast.VarTerm(string(mainRule.Head.Name))) - expressions = append(expressions, expr) + terms = append(terms, ast.VarTerm(string(mainRule.Head.Name))) } - return expressions, nil + return terms, nil } -func (g *Generator) fillViaAnd(rule *ast.Rule, expressions []*ast.Expr) { - for _, expr := range expressions { - rule.Body = append(rule.Body, expr) - } -} - -func (g *Generator) fillViaOr(rule *ast.Rule, expressions []*ast.Expr) { +func (g *Generator) fillViaAnd(rule *ast.Rule, negated bool, terms []*ast.Term) { currentRule := rule - for i, expr := range expressions { + currentRule.Head.Value = ast.VarTerm("v1") + for i, term := range terms { + nm := fmt.Sprintf("v%d", i+1) + currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term)) + expr := ast.NewExpr(ast.VarTerm(nm)) + if negated { + expr.Negated = true + } + currentRule.Body = append(currentRule.Body, expr) + } +} + +func (g *Generator) fillViaOr(rule *ast.Rule, negated bool, terms []*ast.Term) { + currentRule := rule + for i, term := range terms { if i > 0 { - currentRule.Else = &ast.Rule{ - Head: &ast.Head{}, - } + currentRule.Else = &ast.Rule{Head: &ast.Head{}} currentRule = currentRule.Else } - currentRule.Body = ast.Body{expr} + nm := fmt.Sprintf("v%d", i+1) + currentRule.Head.Value = ast.VarTerm(nm) + + currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term)) + expr := ast.NewExpr(ast.VarTerm(nm)) + if negated { + expr.Negated = true + } + currentRule.Body = append(currentRule.Body, expr) } } diff --git a/pkg/policy/generator/generator.go b/pkg/policy/generator/generator.go index 7ee65d7fa..240f80ff7 100644 --- a/pkg/policy/generator/generator.go +++ b/pkg/policy/generator/generator.go @@ -3,6 +3,7 @@ package generator import ( "fmt" + "sort" "github.com/open-policy-agent/opa/ast" @@ -52,35 +53,55 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { rules.Add(ast.MustParseRule(`default allow = false`)) rules.Add(ast.MustParseRule(`default deny = false`)) - for _, policyRule := range policy.Rules { - rule := &ast.Rule{ - Head: &ast.Head{Name: ast.Var(policyRule.Action)}, - } - - fields := []struct { - criteria []parser.Criterion - generator conditionalGenerator - }{ - {policyRule.And, g.generateAndRule}, - {policyRule.Or, g.generateOrRule}, - {policyRule.Not, g.generateNotRule}, - {policyRule.Nor, g.generateNorRule}, - } - for _, field := range fields { - if len(field.criteria) == 0 { + for _, action := range []parser.Action{parser.ActionAllow, parser.ActionDeny} { + var terms []*ast.Term + for _, policyRule := range policy.Rules { + if policyRule.Action != action { continue } - subRule, err := field.generator(&rules, field.criteria) - if err != nil { - return nil, err - } - rule.Body = append(rule.Body, ast.NewExpr(ast.VarTerm(string(subRule.Head.Name)))) - } - rules.Add(rule) + if len(policyRule.And) > 0 { + subRule, err := g.generateAndRule(&rules, policyRule.And) + if err != nil { + return nil, err + } + terms = append(terms, ast.VarTerm(string(subRule.Head.Name))) + } + if len(policyRule.Or) > 0 { + subRule, err := g.generateOrRule(&rules, policyRule.Or) + if err != nil { + return nil, err + } + terms = append(terms, ast.VarTerm(string(subRule.Head.Name))) + } + if len(policyRule.Not) > 0 { + subRule, err := g.generateNotRule(&rules, policyRule.Not) + if err != nil { + return nil, err + } + terms = append(terms, ast.VarTerm(string(subRule.Head.Name))) + } + if len(policyRule.Nor) > 0 { + subRule, err := g.generateNorRule(&rules, policyRule.Nor) + if err != nil { + return nil, err + } + terms = append(terms, ast.VarTerm(string(subRule.Head.Name))) + } + } + if len(terms) > 0 { + rule := &ast.Rule{ + Head: &ast.Head{ + Name: ast.Var(action), + Value: ast.VarTerm("v1"), + }, + } + g.fillViaOr(rule, false, terms) + rules.Add(rule) + } } - return &ast.Module{ + mod := &ast.Module{ Package: &ast.Package{ Path: ast.Ref{ ast.StringTerm("policy.rego"), @@ -89,7 +110,21 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) { }, }, Rules: rules, - }, nil + } + + // move functions to the end + sort.SliceStable(mod.Rules, func(i, j int) bool { + return len(mod.Rules[i].Head.Args) < len(mod.Rules[j].Head.Args) + }) + + i := 1 + ast.WalkRules(mod, func(r *ast.Rule) bool { + r.SetLoc(ast.NewLocation([]byte(r.String()), "", i, 1)) + i++ + return false + }) + + return mod, nil } // NewRule creates a new rule with a dynamically generated name. diff --git a/pkg/policy/generator/generator_test.go b/pkg/policy/generator/generator_test.go index a4aa48da9..e62634d01 100644 --- a/pkg/policy/generator/generator_test.go +++ b/pkg/policy/generator/generator_test.go @@ -44,6 +44,12 @@ func Test(t *testing.T) { {Name: "accept"}, }, }, + { + Action: parser.ActionAllow, + And: []parser.Criterion{ + {Name: "accept"}, + }, + }, }, }) assert.NoError(t, err) @@ -65,10 +71,13 @@ accept_2 { 1 == 1 } -and_0 { - accept_0 - accept_1 - accept_2 +and_0 = v1 { + v1 := accept_0 + v1 + v2 := accept_1 + v2 + v3 := accept_2 + v3 } accept_3 { @@ -83,16 +92,19 @@ accept_5 { 1 == 1 } -or_0 { - accept_3 +or_0 = v1 { + v1 := accept_3 + v1 } -else { - accept_4 +else = v2 { + v2 := accept_4 + v2 } -else { - accept_5 +else = v3 { + v3 := accept_5 + v3 } accept_6 { @@ -107,10 +119,13 @@ accept_8 { 1 == 1 } -not_0 { - not accept_6 - not accept_7 - not accept_8 +not_0 = v1 { + v1 := accept_6 + not v1 + v2 := accept_7 + not v2 + v3 := accept_8 + not v3 } accept_9 { @@ -125,23 +140,53 @@ accept_11 { 1 == 1 } -nor_0 { - not accept_9 +nor_0 = v1 { + v1 := accept_9 + not v1 } -else { - not accept_10 +else = v2 { + v2 := accept_10 + not v2 } -else { - not accept_11 +else = v3 { + v3 := accept_11 + not v3 } -allow { - and_0 - or_0 - not_0 - nor_0 +accept_12 { + 1 == 1 +} + +and_1 = v1 { + v1 := accept_12 + v1 +} + +allow = v1 { + v1 := and_0 + v1 +} + +else = v2 { + v2 := or_0 + v2 +} + +else = v3 { + v3 := not_0 + v3 +} + +else = v4 { + v4 := nor_0 + v4 +} + +else = v5 { + v5 := and_1 + v5 } `, string(format.MustAst(mod))) } diff --git a/pkg/policy/parser/json.go b/pkg/policy/parser/json.go index c141a0c2d..eb666972d 100644 --- a/pkg/policy/parser/json.go +++ b/pkg/policy/parser/json.go @@ -141,6 +141,9 @@ func (o Object) Clone() Value { func (o Object) RegoValue() ast.Value { kvps := make([][2]*ast.Term, 0, len(o)) for k, v := range o { + if v == nil { + v = Null{} + } kvps = append(kvps, [2]*ast.Term{ ast.StringTerm(k), ast.NewTerm(v.RegoValue()), diff --git a/pkg/policy/policy.go b/pkg/policy/policy.go index c5b47415e..0354eab9c 100644 --- a/pkg/policy/policy.go +++ b/pkg/policy/policy.go @@ -19,21 +19,24 @@ type ( CriterionConstructor = generator.CriterionConstructor ) -// GenerateRegoFromPPL generates a rego script from raw Pomerium Policy Language. -func GenerateRegoFromPPL(r io.Reader) (string, error) { - p := parser.New() +// GenerateRegoFromReader generates a rego script from raw Pomerium Policy Language. +func GenerateRegoFromReader(r io.Reader) (string, error) { + ppl, err := parser.ParseYAML(r) + if err != nil { + return "", err + } + return GenerateRegoFromPolicy(ppl) +} + +// GenerateRegoFromPolicy generates a rego script from a Pomerium Policy Language policy. +func GenerateRegoFromPolicy(p *parser.Policy) (string, error) { var gOpts []generator.Option for _, ctor := range criteria.All() { gOpts = append(gOpts, generator.WithCriterion(ctor)) } g := generator.New(gOpts...) - ppl, err := p.ParseYAML(r) - if err != nil { - return "", err - } - - mod, err := g.Generate(ppl) + mod, err := g.Generate(p) if err != nil { return "", err } diff --git a/pkg/policy/rules/rules.go b/pkg/policy/rules/rules.go index 5b845573c..c3e6a8d0d 100644 --- a/pkg/policy/rules/rules.go +++ b/pkg/policy/rules/rules.go @@ -7,11 +7,13 @@ import "github.com/open-policy-agent/opa/ast" func GetSession() *ast.Rule { return ast.MustParseRule(` get_session(id) = v { - v := get_databroker_record("type.googleapis.com/user.ServiceAccount", id) + v = get_databroker_record("type.googleapis.com/user.ServiceAccount", id) + v != null } else = v { - v := get_databroker_record("type.googleapis.com/session.Session", id) -} else = v { - v := {} + v = get_databroker_record("type.googleapis.com/session.Session", id) + v != null +} else = {} { + true } `) } @@ -20,11 +22,13 @@ get_session(id) = v { func GetUser() *ast.Rule { return ast.MustParseRule(` get_user(session) = v { - v := get_databroker_record("type.googleapis.com/user.User", session.impersonate_user_id) + v = get_databroker_record("type.googleapis.com/user.User", session.impersonate_user_id) + v != null } else = v { - v := get_databroker_record("type.googleapis.com/user.User", session.user_id) -} else = v { - v := {} + v = get_databroker_record("type.googleapis.com/user.User", session.user_id) + v != null +} else = {} { + true } `) } @@ -33,11 +37,11 @@ get_user(session) = v { func GetUserEmail() *ast.Rule { return ast.MustParseRule(` get_user_email(session, user) = v { - v := session.impersonate_email + v = session.impersonate_email } else = v { - v := user.email -} else = v { - v := "" + v = user.email +} else = "" { + true } `) } @@ -46,11 +50,13 @@ get_user_email(session, user) = v { func GetDirectoryUser() *ast.Rule { return ast.MustParseRule(` get_directory_user(session) = v { - v := get_databroker_record("type.googleapis.com/directory.User", session.impersonate_user_id) + v = get_databroker_record("type.googleapis.com/directory.User", session.impersonate_user_id) + v != null } else = v { - v := get_databroker_record("type.googleapis.com/directory.User", session.user_id) -} else = v { - v := {} + v = get_databroker_record("type.googleapis.com/directory.User", session.user_id) + v != null +} else = "" { + true } `) } @@ -59,9 +65,10 @@ get_directory_user(session) = v { func GetDirectoryGroup() *ast.Rule { return ast.MustParseRule(` get_directory_group(id) = v { - v := get_databroker_record("type.googleapis.com/directory.Group", id) -} else = v { - v := {} + v = get_databroker_record("type.googleapis.com/directory.Group", id) + v != null +} else = {} { + true } `) } @@ -70,11 +77,13 @@ get_directory_group(id) = v { func GetGroupIDs() *ast.Rule { return ast.MustParseRule(` get_group_ids(session, directory_user) = v { - v := session.impersonate_groups + v = session.impersonate_groups + v != null } else = v { - v := directory_user.group_ids -} else = v { - v := [] + v = directory_user.group_ids + v != null +} else = [] { + true } `) }