ppl: add support for additional data (#2696)

* ppl: add support for additional data

* remove unused NewCriterionDeviceRule
This commit is contained in:
Caleb Doxsey 2021-10-22 12:32:20 -06:00 committed by GitHub
parent 0638b07f4d
commit 6e48627b4d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
17 changed files with 205 additions and 59 deletions

View file

@ -28,22 +28,34 @@ type PolicyResponse struct {
Allow, Deny RuleResult
}
// NewPolicyResponse creates a new PolicyResponse.
func NewPolicyResponse() *PolicyResponse {
return &PolicyResponse{
Allow: NewRuleResult(false),
Deny: NewRuleResult(false),
}
}
// A RuleResult is the result of evaluating a rule.
type RuleResult struct {
Value bool
Reasons criteria.Reasons
Value bool
Reasons criteria.Reasons
AdditionalData map[string]interface{}
}
// NewRuleResult creates a new RuleResult.
func NewRuleResult(value bool, reasons ...criteria.Reason) RuleResult {
return RuleResult{
Value: value,
Reasons: criteria.NewReasons(reasons...),
Value: value,
Reasons: criteria.NewReasons(reasons...),
AdditionalData: map[string]interface{}{},
}
}
// MergeRuleResultsWithOr merges all the results using `or`.
func MergeRuleResultsWithOr(results ...RuleResult) (merged RuleResult) {
func MergeRuleResultsWithOr(results ...RuleResult) RuleResult {
merged := NewRuleResult(false)
var trueResults, falseResults []RuleResult
for _, result := range results {
if result.Value {
@ -57,11 +69,17 @@ func MergeRuleResultsWithOr(results ...RuleResult) (merged RuleResult) {
merged.Value = true
for _, result := range trueResults {
merged.Reasons = merged.Reasons.Union(result.Reasons)
for k, v := range result.AdditionalData {
merged.AdditionalData[k] = v
}
}
} else {
merged.Value = false
for _, result := range falseResults {
merged.Reasons = merged.Reasons.Union(result.Reasons)
for k, v := range result.AdditionalData {
merged.AdditionalData[k] = v
}
}
}
@ -145,7 +163,7 @@ func NewPolicyEvaluator(ctx context.Context, store *Store, configPolicy *config.
// Evaluate evaluates the policy rego scripts.
func (e *PolicyEvaluator) Evaluate(ctx context.Context, req *PolicyRequest) (*PolicyResponse, error) {
res := new(PolicyResponse)
res := NewPolicyResponse()
// run each query and merge the results
for _, query := range e.queries {
o, err := e.evaluateQuery(ctx, req, query)
@ -179,7 +197,7 @@ func (e *PolicyEvaluator) evaluateQuery(ctx context.Context, req *PolicyRequest,
return res, nil
}
// getRuleResult gets the rule result var. It expects a boolean or [boolean, []string].
// getRuleResult gets the rule result var. It expects a boolean, [boolean, []string] or [boolean, []string, object].
func (e *PolicyEvaluator) getRuleResult(name string, vars rego.Vars) (result RuleResult) {
result = NewRuleResult(false)
@ -193,14 +211,21 @@ func (e *PolicyEvaluator) getRuleResult(name string, vars rego.Vars) (result Rul
result.Value = t
case []interface{}:
switch len(t) {
case 3:
v, ok := t[2].(map[string]interface{})
if ok {
for k, vv := range v {
result.AdditionalData[k] = vv
}
}
fallthrough
case 2:
// fill in the reasons
v, ok := t[1].([]interface{})
if !ok {
return result
}
for _, vv := range v {
result.Reasons.Add(criteria.Reason(fmt.Sprint(vv)))
if ok {
for _, vv := range v {
result.Reasons.Add(criteria.Reason(fmt.Sprint(vv)))
}
}
fallthrough
case 1:

View file

@ -623,12 +623,14 @@ deny = v {
v := merge_with_or(normalized)
}
invert_criterion_result(result) = [false, result[1]] {
result[0]
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
}
else = [true, result[1]] {
not result[0]
else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
}
normalize_criterion_result(result) = v {
@ -645,26 +647,39 @@ else = v {
v = [false, set()]
}
merge_with_and(results) = [true, reasons] {
object_union(xs) = merged {
merged = {k: v |
some k
xs[_0][k]
vs := [xv | xv := xs[_][k]]
v := vs[minus(count(vs), 1)]
}
}
merge_with_and(results) = [true, reasons, additional_data] {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons] {
else = [false, reasons, additional_data] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
merge_with_or(results) = [true, reasons] {
merge_with_or(results) = [true, reasons, additional_data] {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons] {
else = [false, reasons, additional_data] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
get_session(id) = v {

View file

@ -13,6 +13,6 @@ allow:
- accept: 1
`, []dataBrokerRecord{}, Input{})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonAccept}}, res["allow"])
require.Equal(t, A{true, A{ReasonAccept}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
}

View file

@ -16,7 +16,7 @@ allow:
- authenticated_user: 1
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by domain", func(t *testing.T) {
@ -33,7 +33,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonUserOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -18,7 +18,7 @@ allow:
- claim/family_name: Smith
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("no claim", func(t *testing.T) {
@ -35,7 +35,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonClaimUnauthorized}}, res["allow"])
require.Equal(t, A{false, A{ReasonClaimUnauthorized}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by session claim", func(t *testing.T) {
@ -59,7 +59,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonClaimOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonClaimOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by user claim", func(t *testing.T) {
@ -83,7 +83,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonClaimOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonClaimOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -20,7 +20,7 @@ allow:
},
}})
require.NoError(t, err)
require.Equal(t, A{true, A{"cors-request"}}, res["allow"])
require.Equal(t, A{true, A{"cors-request"}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("false", func(t *testing.T) {
@ -32,7 +32,7 @@ allow:
Method: "OPTIONS",
}})
require.NoError(t, err)
require.Equal(t, A{false, A{"non-cors-request"}}, res["allow"])
require.Equal(t, A{false, A{"non-cors-request"}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -126,3 +126,24 @@ func NewCriterionTerm(value bool, reasons ...Reason) *ast.Term {
ast.SetTerm(terms...),
)
}
// NewCriterionTermWithAdditionalData creates a new rego term for a criterion with additional data:
//
// [true, {"reason"}, {"key": "value"}]
//
func NewCriterionTermWithAdditionalData(value bool, reason Reason, additionalData map[string]interface{}) *ast.Term {
var kvs [][2]*ast.Term
for k, v := range additionalData {
kvs = append(kvs, [2]*ast.Term{
ast.StringTerm(k),
ast.NewTerm(ast.MustInterfaceToValue(v)),
})
}
var terms []*ast.Term
terms = append(terms, ast.StringTerm(string(reason)))
return ast.ArrayTerm(
ast.BooleanTerm(value),
ast.SetTerm(terms...),
ast.ObjectTerm(kvs...),
)
}

View file

@ -18,7 +18,7 @@ allow:
is: example.com
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by domain", func(t *testing.T) {
@ -40,7 +40,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonDomainOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonDomainOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate email", func(t *testing.T) {
@ -62,7 +62,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonDomainOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonDomainOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -19,7 +19,7 @@ allow:
is: test@example.com
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by email", func(t *testing.T) {
@ -41,7 +41,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonEmailOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonEmailOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate session id", func(t *testing.T) {
@ -72,7 +72,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION1"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonEmailOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonEmailOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -20,7 +20,7 @@ allow:
has: group2
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by id", func(t *testing.T) {
@ -42,7 +42,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by email", func(t *testing.T) {
@ -68,7 +68,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by name", func(t *testing.T) {
@ -94,7 +94,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonGroupsOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -11,6 +11,9 @@ const (
ReasonClaimOK = "claim-ok"
ReasonClaimUnauthorized = "claim-unauthorized"
ReasonCORSRequest = "cors-request"
ReasonDeviceOK = "device-ok"
ReasonDeviceUnauthenticated = "device-unauthenticated"
ReasonDeviceUnauthorized = "device-unauthorized"
ReasonDomainOK = "domain-ok"
ReasonDomainUnauthorized = "domain-unauthorized"
ReasonEmailOK = "email-ok"

View file

@ -13,6 +13,6 @@ allow:
- reject: 1
`, []dataBrokerRecord{}, Input{})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonReject}}, res["allow"])
require.Equal(t, A{false, A{ReasonReject}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
}

View file

@ -18,7 +18,7 @@ allow:
is: USER_ID
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by user id", func(t *testing.T) {
@ -36,7 +36,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonUserOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate session id", func(t *testing.T) {
@ -59,7 +59,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION1"}})
require.NoError(t, err)
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{true, A{ReasonUserOK}, M{}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -53,6 +53,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
rs.Add(ast.MustParseRule(`default deny = [false, set()]`))
rs.Add(rules.InvertCriterionResult())
rs.Add(rules.NormalizeCriterionResult())
rs.Add(rules.ObjectUnion())
rs.Add(rules.MergeWithAnd())
rs.Add(rules.MergeWithOr())

View file

@ -178,12 +178,14 @@ deny = v {
v := merge_with_or(normalized)
}
invert_criterion_result(result) = [false, result[1]] {
result[0]
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
}
else = [true, result[1]] {
not result[0]
else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
}
normalize_criterion_result(result) = v {
@ -200,26 +202,39 @@ else = v {
v = [false, set()]
}
merge_with_and(results) = [true, reasons] {
object_union(xs) = merged {
merged = {k: v |
some k
xs[_][k]
vs := [xv | xv := xs[_][k]]
v := vs[minus(count(vs), 1)]
}
}
merge_with_and(results) = [true, reasons, additional_data] {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons] {
else = [false, reasons, additional_data] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
merge_with_or(results) = [true, reasons] {
merge_with_or(results) = [true, reasons, additional_data] {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
additional_data := object_union({x | x := true_results[i][2]})
}
else = [false, reasons] {
else = [false, reasons, additional_data] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
additional_data := object_union({x | x := false_results[i][2]})
}
`, string(format.MustAst(mod)))
}

View file

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"math"
"github.com/open-policy-agent/opa/ast"
)
@ -137,6 +138,29 @@ func (o Object) Clone() Value {
return no
}
// Falsy returns true if the value is considered Javascript falsy:
// https://developer.mozilla.org/en-US/docs/Glossary/Falsy.
// If the field is not found in the object it is *not* falsy.
func (o Object) Falsy(field string) bool {
v, ok := o[field]
if !ok {
return false
}
switch v := v.(type) {
case Boolean:
return !bool(v)
case Number:
return v.Float64() == 0 || math.IsNaN(v.Float64())
case String:
return v == ""
case Null:
return true
default:
return false
}
}
// RegoValue returns the Object as a rego Value.
func (o Object) RegoValue() ast.Value {
kvps := make([][2]*ast.Term, 0, len(o))
@ -158,6 +182,16 @@ func (o Object) String() string {
return string(bs)
}
// Truthy returns the opposite of Falsy, however if the field is not found in the object it is neither truthy nor falsy.
func (o Object) Truthy(field string) bool {
_, ok := o[field]
if !ok {
return false
}
return !o.Falsy(field)
}
// An Array is a slice of values.
type Array []Value
@ -216,6 +250,18 @@ func (n Number) Clone() Value {
return n
}
// Float64 returns the number as a float64.
func (n Number) Float64() float64 {
v, _ := json.Number(n).Float64()
return v
}
// Int64 returns the number as an int64.
func (n Number) Int64() int64 {
v, _ := json.Number(n).Int64()
return v
}
// RegoValue returns the Number as a rego Value.
func (n Number) RegoValue() ast.Value {
return ast.Number(n)

View file

@ -88,13 +88,15 @@ get_group_ids(session, directory_user) = v {
// MergeWithAnd merges criterion results using `and`.
func MergeWithAnd() *ast.Rule {
return ast.MustParseRule(`
merge_with_and(results) = [true, reasons] {
merge_with_and(results) = [true, reasons, additional_data] {
true_results := [x|x:=results[i];x[0]]
count(true_results) == count(results)
reasons := union({x|x:=true_results[i][1]})
} else = [false, reasons] {
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
}
`)
}
@ -102,13 +104,15 @@ merge_with_and(results) = [true, reasons] {
// MergeWithOr merges criterion results using `or`.
func MergeWithOr() *ast.Rule {
return ast.MustParseRule(`
merge_with_or(results) = [true, reasons] {
merge_with_or(results) = [true, reasons, additional_data] {
true_results := [x|x:=results[i];x[0]]
count(true_results) > 0
reasons := union({x|x:=true_results[i][1]})
} else = [false, reasons] {
additional_data := object_union({x|x:=true_results[i][2]})
} else = [false, reasons, additional_data] {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
additional_data := object_union({x|x:=false_results[i][2]})
}
`)
}
@ -117,10 +121,12 @@ merge_with_or(results) = [true, reasons] {
// true, or vice-versa.
func InvertCriterionResult() *ast.Rule {
return ast.MustParseRule(`
invert_criterion_result(result) = [false, result[1]] {
result[0]
} else = [true, result[1]] {
not result[0]
invert_criterion_result(in) = out {
in[0]
out = array.concat([false], array.slice(in, 1, count(in)))
} else = out {
not in[0]
out = array.concat([true], array.slice(in, 1, count(in)))
}
`)
}
@ -176,3 +182,17 @@ object_get(obj, key, def) = value {
}
`)
}
// ObjectUnion merges objects together. It expects a set of objects.
func ObjectUnion() *ast.Rule {
return ast.MustParseRule(`
object_union(xs) = merged {
merged = { k: v |
some k
xs[_][k]
vs := [ xv | xv := xs[_][k] ]
v := vs[count(vs)-1]
}
}
`)
}