ppl: pass contextual information through policy (#2612)

* ppl: pass contextual information through policy

* maybe fix nginx

* fix nginx

* pr comments

* go mod tidy
This commit is contained in:
Caleb Doxsey 2021-09-20 16:02:26 -06:00 committed by GitHub
parent 5340f55c20
commit efffe57bf0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
40 changed files with 1144 additions and 703 deletions

View file

@ -23,11 +23,56 @@ import (
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/internal/telemetry/requestid"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/policy/criteria"
)
func (a *Authorize) okResponse(reply *evaluator.Result) *envoy_service_auth_v3.CheckResponse {
func (a *Authorize) handleResultAllowed(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
result *evaluator.Result,
) (*envoy_service_auth_v3.CheckResponse, error) {
return a.okResponse(result.Headers), nil
}
func (a *Authorize) handleResultDenied(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
result *evaluator.Result,
) (*envoy_service_auth_v3.CheckResponse, error) {
denyStatusCode := int32(http.StatusForbidden)
denyStatusText := http.StatusText(http.StatusForbidden)
switch {
case result.Deny.Reasons.Has(criteria.ReasonRouteNotFound):
denyStatusCode = http.StatusNotFound
denyStatusText = http.StatusText(http.StatusNotFound)
case result.Deny.Reasons.Has(criteria.ReasonInvalidClientCertificate):
denyStatusCode = httputil.StatusInvalidClientCertificate
denyStatusText = httputil.StatusText(httputil.StatusInvalidClientCertificate)
}
return a.deniedResponse(ctx, in, denyStatusCode, denyStatusText, nil)
}
func (a *Authorize) handleResultNotAllowed(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
result *evaluator.Result,
isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) {
switch {
case result.Allow.Reasons.Has(criteria.ReasonUserUnauthenticated):
// when the user is unauthenticated it means they haven't
// logged in yet, so redirect to authenticate
return a.requireLoginResponse(ctx, in, isForwardAuthVerify)
}
return a.deniedResponse(ctx, in, http.StatusForbidden, http.StatusText(http.StatusForbidden), nil)
}
func (a *Authorize) okResponse(headers http.Header) *envoy_service_auth_v3.CheckResponse {
var requestHeaders []*envoy_config_core_v3.HeaderValueOption
for k, vs := range reply.Headers {
for k, vs := range headers {
requestHeaders = append(requestHeaders, mkHeader(k, strings.Join(vs, ","), false))
}
// ensure request headers are sorted by key for deterministic output
@ -105,7 +150,11 @@ func (a *Authorize) deniedResponse(
}, nil
}
func (a *Authorize) requireLoginResponse(ctx context.Context, in *envoy_service_auth_v3.CheckRequest) (*envoy_service_auth_v3.CheckResponse, error) {
func (a *Authorize) requireLoginResponse(
ctx context.Context,
in *envoy_service_auth_v3.CheckRequest,
isForwardAuthVerify bool,
) (*envoy_service_auth_v3.CheckResponse, error) {
opts := a.currentOptions.Load()
state := a.state.Load()
authenticateURL, err := opts.GetAuthenticateURL()
@ -113,7 +162,7 @@ func (a *Authorize) requireLoginResponse(ctx context.Context, in *envoy_service_
return nil, err
}
if !shouldRedirect(in) {
if !a.shouldRedirect(in) || isForwardAuthVerify {
return a.deniedResponse(ctx, in, http.StatusUnauthorized, http.StatusText(http.StatusUnauthorized), nil)
}
@ -186,7 +235,7 @@ func (a *Authorize) userInfoEndpointURL(in *envoy_service_auth_v3.CheckRequest)
return urlutil.NewSignedURL(a.state.Load().sharedKey, debugEndpoint).Sign(), nil
}
func shouldRedirect(in *envoy_service_auth_v3.CheckRequest) bool {
func (a *Authorize) shouldRedirect(in *envoy_service_auth_v3.CheckRequest) bool {
requestHeaders := in.GetAttributes().GetRequest().GetHttp().GetHeaders()
if requestHeaders == nil {
return true
@ -196,12 +245,12 @@ func shouldRedirect(in *envoy_service_auth_v3.CheckRequest) bool {
return false
}
a, err := rfc7231.ParseAccept(requestHeaders["accept"])
accept, err := rfc7231.ParseAccept(requestHeaders["accept"])
if err != nil {
return true
}
mediaType, ok := a.MostAcceptable([]string{
mediaType, ok := accept.MostAcceptable([]string{
"text/html",
"application/json",
"text/plain",

View file

@ -63,28 +63,28 @@ func TestAuthorize_okResponse(t *testing.T) {
}{
{
"ok reply",
&evaluator.Result{Allow: true},
&evaluator.Result{Allow: evaluator.NewRuleResult(true)},
&envoy_service_auth_v3.CheckResponse{
Status: &status.Status{Code: 0, Message: "OK"},
},
},
{
"ok reply with k8s svc",
&evaluator.Result{Allow: true},
&evaluator.Result{Allow: evaluator.NewRuleResult(true)},
&envoy_service_auth_v3.CheckResponse{
Status: &status.Status{Code: 0, Message: "OK"},
},
},
{
"ok reply with k8s svc impersonate",
&evaluator.Result{Allow: true},
&evaluator.Result{Allow: evaluator.NewRuleResult(true)},
&envoy_service_auth_v3.CheckResponse{
Status: &status.Status{Code: 0, Message: "OK"},
},
},
{
"ok reply with jwt claims header",
&evaluator.Result{Allow: true},
&evaluator.Result{Allow: evaluator.NewRuleResult(true)},
&envoy_service_auth_v3.CheckResponse{
Status: &status.Status{Code: 0, Message: "OK"},
},
@ -93,7 +93,7 @@ func TestAuthorize_okResponse(t *testing.T) {
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := a.okResponse(tc.reply)
got := a.okResponse(tc.reply.Headers)
assert.Equal(t, tc.want.Status.Code, got.Status.Code)
assert.Equal(t, tc.want.Status.Message, got.Status.Message)
want, _ := protojson.Marshal(tc.want.GetOkResponse())
@ -175,7 +175,7 @@ func TestRequireLogin(t *testing.T) {
require.NoError(t, err)
t.Run("accept empty", func(t *testing.T) {
res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{})
res, err := a.requireLoginResponse(context.Background(), &envoy_service_auth_v3.CheckRequest{}, false)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
@ -190,7 +190,7 @@ func TestRequireLogin(t *testing.T) {
},
},
},
})
}, false)
require.NoError(t, err)
assert.Equal(t, http.StatusFound, int(res.GetDeniedResponse().GetStatus().GetCode()))
})
@ -205,7 +205,7 @@ func TestRequireLogin(t *testing.T) {
},
},
},
})
}, false)
require.NoError(t, err)
assert.Equal(t, http.StatusUnauthorized, int(res.GetDeniedResponse().GetStatus().GetCode()))
})

View file

@ -16,15 +16,12 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/policy/criteria"
)
// notFoundOutput is what's returned if a route isn't found for a policy.
var notFoundOutput = &Result{
Allow: false,
Deny: &Denial{
Status: http.StatusNotFound,
Message: "route not found",
},
Deny: NewRuleResult(true, criteria.ReasonRouteNotFound),
Headers: make(http.Header),
}
@ -50,8 +47,8 @@ type RequestSession struct {
// Result is the result of evaluation.
type Result struct {
Allow bool
Deny *Denial
Allow RuleResult
Deny RuleResult
Headers http.Header
DataBrokerServerVersion, DataBrokerRecordVersion uint64

View file

@ -21,6 +21,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/directory"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy/criteria"
"github.com/pomerium/pomerium/pkg/protoutil"
)
@ -98,7 +99,7 @@ func TestEvaluator(t *testing.T) {
Policy: &policies[0],
})
require.NoError(t, err)
assert.Equal(t, &Denial{Status: 495, Message: "invalid client certificate"}, res.Deny)
assert.Equal(t, NewRuleResult(true, criteria.ReasonInvalidClientCertificate), res.Deny)
})
t.Run("valid", func(t *testing.T) {
res, err := eval(t, options, nil, &Request{
@ -108,7 +109,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.Nil(t, res.Deny)
assert.False(t, res.Deny.Value)
})
})
t.Run("identity_headers", func(t *testing.T) {
@ -186,7 +187,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("allowed sub", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -210,7 +211,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("denied", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -234,7 +235,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.False(t, res.Allow)
assert.False(t, res.Allow.Value)
})
})
t.Run("impersonate email", func(t *testing.T) {
@ -265,7 +266,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
})
t.Run("user_id", func(t *testing.T) {
@ -290,7 +291,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("domain", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -314,7 +315,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("impersonate domain", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -343,7 +344,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("groups", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -376,7 +377,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("any authenticated user", func(t *testing.T) {
res, err := eval(t, options, []proto.Message{
@ -399,7 +400,7 @@ func TestEvaluator(t *testing.T) {
},
})
require.NoError(t, err)
assert.True(t, res.Allow)
assert.True(t, res.Allow.Value)
})
t.Run("carry over assertion header", func(t *testing.T) {
tcs := []struct {

View file

@ -3,8 +3,6 @@ package evaluator
import (
"context"
"fmt"
"net/http"
"strconv"
"strings"
"github.com/open-policy-agent/opa/rego"
@ -15,6 +13,7 @@ import (
"github.com/pomerium/pomerium/internal/telemetry/trace"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/policy"
"github.com/pomerium/pomerium/pkg/policy/criteria"
)
// PolicyRequest is the input to policy evaluation.
@ -26,29 +25,49 @@ type PolicyRequest struct {
// PolicyResponse is the result of evaluating a policy.
type PolicyResponse struct {
Allow bool
Deny *Denial
Allow, Deny RuleResult
}
// Merge merges another PolicyResponse into this PolicyResponse. Access is allowed if either is allowed. Access is denied if
// either is denied. (and denials take precedence)
func (res *PolicyResponse) Merge(other *PolicyResponse) *PolicyResponse {
merged := &PolicyResponse{
Allow: res.Allow || other.Allow,
Deny: res.Deny,
// A RuleResult is the result of evaluating a rule.
type RuleResult struct {
Value bool
Reasons criteria.Reasons
}
// NewRuleResult creates a new RuleResult.
func NewRuleResult(value bool, reasons ...criteria.Reason) RuleResult {
return RuleResult{
Value: value,
Reasons: criteria.NewReasons(reasons...),
}
if other.Deny != nil {
merged.Deny = other.Deny
}
// MergeRuleResultsWithOr merges all the results using `or`.
func MergeRuleResultsWithOr(results ...RuleResult) (merged RuleResult) {
var trueResults, falseResults []RuleResult
for _, result := range results {
if result.Value {
trueResults = append(trueResults, result)
} else {
falseResults = append(falseResults, result)
}
}
if len(trueResults) > 0 {
merged.Value = true
for _, result := range trueResults {
merged.Reasons = merged.Reasons.Union(result.Reasons)
}
} else {
merged.Value = false
for _, result := range falseResults {
merged.Reasons = merged.Reasons.Union(result.Reasons)
}
}
return merged
}
// A Denial indicates the request should be denied (even if otherwise allowed).
type Denial struct {
Status int
Message string
}
type policyQuery struct {
rego.PreparedEvalQuery
checksum string
@ -133,7 +152,8 @@ func (e *PolicyEvaluator) Evaluate(ctx context.Context, req *PolicyRequest) (*Po
if err != nil {
return nil, err
}
res = res.Merge(o)
res.Allow = MergeRuleResultsWithOr(res.Allow, o.Allow)
res.Deny = MergeRuleResultsWithOr(res.Deny, o.Deny)
}
return res, nil
}
@ -153,67 +173,45 @@ func (e *PolicyEvaluator) evaluateQuery(ctx context.Context, req *PolicyRequest,
}
res := &PolicyResponse{
Allow: e.getAllow(rs[0].Bindings),
Deny: e.getDeny(ctx, rs[0].Bindings),
Allow: e.getRuleResult("allow", rs[0].Bindings),
Deny: e.getRuleResult("deny", rs[0].Bindings),
}
return res, nil
}
// getAllow gets the allow var. It expects a boolean.
func (e *PolicyEvaluator) getAllow(vars rego.Vars) bool {
// getRuleResult gets the rule result var. It expects a boolean or [boolean, []string].
func (e *PolicyEvaluator) getRuleResult(name string, vars rego.Vars) (result RuleResult) {
result = NewRuleResult(false)
m, ok := vars["result"].(map[string]interface{})
if !ok {
return false
return result
}
allow, ok := m["allow"].(bool)
if !ok {
return false
}
return allow
}
// getDeny gets the deny var. It expects an (http status code, message) pair.
func (e *PolicyEvaluator) getDeny(ctx context.Context, vars rego.Vars) *Denial {
m, ok := vars["result"].(map[string]interface{})
if !ok {
return nil
}
var status int
var reason string
switch t := m["deny"].(type) {
switch t := m[name].(type) {
case bool:
if t {
status = http.StatusForbidden
reason = ""
} else {
return nil
}
result.Value = t
case []interface{}:
switch len(t) {
case 0:
return nil
case 2:
var err error
status, err = strconv.Atoi(fmt.Sprint(t[0]))
if err != nil {
log.Error(ctx).Err(err).Msg("invalid type in deny")
return nil
// fill in the reasons
v, ok := t[1].([]interface{})
if !ok {
return result
}
reason = fmt.Sprint(t[1])
default:
log.Error(ctx).Interface("deny", t).Msg("invalid size in deny")
return nil
for _, vv := range v {
result.Reasons.Add(criteria.Reason(fmt.Sprint(vv)))
}
fallthrough
case 1:
// fill in the value
v, ok := t[0].(bool)
if !ok {
return result
}
result.Value = v
}
default:
return nil
}
return &Denial{
Status: status,
Message: reason,
}
return result
}

View file

@ -3,7 +3,6 @@ package evaluator
import (
"context"
"math"
"net/http"
"strings"
"testing"
@ -16,6 +15,7 @@ import (
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/policy"
"github.com/pomerium/pomerium/pkg/policy/criteria"
)
func TestPolicyEvaluator(t *testing.T) {
@ -70,7 +70,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: true,
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificateOrNoneRequired),
}, output)
})
t.Run("invalid cert", func(t *testing.T) {
@ -85,11 +86,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: true,
Deny: &Denial{
Status: 495,
Message: "invalid client certificate",
},
Allow: NewRuleResult(true, criteria.ReasonEmailOK),
Deny: NewRuleResult(true, criteria.ReasonInvalidClientCertificate),
}, output)
})
t.Run("forbidden", func(t *testing.T) {
@ -104,7 +102,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: false,
Allow: NewRuleResult(false, criteria.ReasonEmailUnauthorized, criteria.ReasonNonPomeriumRoute, criteria.ReasonUserUnauthorized),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificateOrNoneRequired),
}, output)
})
t.Run("ppl", func(t *testing.T) {
@ -133,7 +132,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Allow: true,
Allow: NewRuleResult(true, criteria.ReasonAccept),
Deny: NewRuleResult(false, criteria.ReasonValidClientCertificateOrNoneRequired),
}, output)
})
t.Run("deny", func(t *testing.T) {
@ -161,9 +161,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Deny: &Denial{
Status: http.StatusForbidden,
},
Allow: NewRuleResult(false, criteria.ReasonNonPomeriumRoute),
Deny: NewRuleResult(true, criteria.ReasonAccept),
}, output)
})
t.Run("client certificate", func(t *testing.T) {
@ -192,10 +191,8 @@ func TestPolicyEvaluator(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, &PolicyResponse{
Deny: &Denial{
Status: 495,
Message: "invalid client certificate",
},
Allow: NewRuleResult(false, criteria.ReasonNonPomeriumRoute),
Deny: NewRuleResult(true, criteria.ReasonAccept, criteria.ReasonInvalidClientCertificate),
}, output)
})
})

View file

@ -72,25 +72,16 @@ func (a *Authorize) Check(ctx context.Context, in *envoy_service_auth_v3.CheckRe
a.logAuthorizeCheck(ctx, in, out, res, s, u)
}()
denyStatusCode := int32(http.StatusForbidden)
denyStatusText := http.StatusText(http.StatusForbidden)
if res.Deny != nil {
denyStatusCode = int32(res.Deny.Status)
denyStatusText = res.Deny.Message
} else if res.Allow {
return a.okResponse(res), nil
if res.Deny.Value {
return a.handleResultDenied(ctx, in, res)
}
// if we're logged in, don't redirect, deny with forbidden
if req.Session.ID != "" {
return a.deniedResponse(ctx, in, denyStatusCode, denyStatusText, nil)
if res.Allow.Value {
return a.handleResultAllowed(ctx, in, res)
}
if isForwardAuth && hreq.URL.Path == "/verify" {
return a.deniedResponse(ctx, in, http.StatusUnauthorized, "Unauthenticated", nil)
}
return a.requireLoginResponse(ctx, in)
isForwardAuthVerify := isForwardAuth && hreq.URL.Path == "/verify"
return a.handleResultNotAllowed(ctx, in, res, isForwardAuthVerify)
}
func getForwardAuthURL(r *http.Request) *url.URL {

View file

@ -46,8 +46,18 @@ func (a *Authorize) logAuthorizeCheck(
// result
if res != nil {
evt = evt.Bool("allow", res.Allow)
evt = evt.Interface("deny", res.Deny)
evt = evt.Bool("allow", res.Allow.Value)
if res.Allow.Value {
evt = evt.Strs("allow-why-true", res.Allow.Reasons.Strings())
} else {
evt = evt.Strs("allow-why-false", res.Allow.Reasons.Strings())
}
evt = evt.Bool("deny", res.Deny.Value)
if res.Deny.Value {
evt = evt.Strs("deny-why-true", res.Deny.Reasons.Strings())
} else {
evt = evt.Strs("deny-why-false", res.Deny.Reasons.Strings())
}
evt = evt.Str("user", u.GetId())
evt = evt.Str("email", u.GetEmail())
evt = evt.Uint64("databroker_server_version", res.DataBrokerServerVersion)

View file

@ -56,66 +56,126 @@ func TestPolicy_ToPPL(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
default allow = false
default allow = [false, set()]
default deny = false
default deny = [false, set()]
pomerium_routes_0 {
pomerium_routes_0 = [true, {"pomerium-route"}] {
contains(input.http.url, "/.pomerium/")
}
accept_0 = v {
v := true
else = [false, {"non-pomerium-route"}] {
true
}
cors_preflight_0 {
accept_0 = [true, {"accept"}]
cors_preflight_0 = [true, {"cors-request"}] {
input.http.method == "OPTIONS"
count(object.get(input.http.headers, "Access-Control-Request-Method", [])) > 0
count(object.get(input.http.headers, "Origin", [])) > 0
}
authenticated_user_0 {
else = [false, {"non-cors-request"}] {
true
}
authenticated_user_0 = [true, {"user-ok"}] {
session := get_session(input.session.id)
session.user_id != null
session.user_id != ""
}
domains_0 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
domain_0 = [true, {"domain-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "a.example.com"
}
domains_1 {
else = [false, {"domain-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
domain_1 = [true, {"domain-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "b.example.com"
}
domains_2 {
else = [false, {"domain-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
domain_2 = [true, {"domain-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "c.example.com"
}
domains_3 {
else = [false, {"domain-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
domain_3 = [true, {"domain-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "d.example.com"
}
domains_4 {
else = [false, {"domain-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
domain_4 = [true, {"domain-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
domain := split(get_user_email(session, user), "@")[1]
domain == "e.example.com"
}
groups_0 {
else = [false, {"domain-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_0 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
@ -137,7 +197,16 @@ groups_0 {
count([true | some v; v = groups[_0]; v == "group1"]) > 0
}
groups_1 {
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_1 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
@ -159,7 +228,16 @@ groups_1 {
count([true | some v; v = groups[_0]; v == "group2"]) > 0
}
groups_2 {
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_2 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
@ -181,7 +259,16 @@ groups_2 {
count([true | some v; v = groups[_0]; v == "group3"]) > 0
}
groups_3 {
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_3 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
@ -203,7 +290,16 @@ groups_3 {
count([true | some v; v = groups[_0]; v == "group4"]) > 0
}
groups_4 {
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
groups_4 = [true, {"groups-ok"}] {
session := get_session(input.session.id)
directory_user := get_directory_user(session)
group_ids := get_group_ids(session, directory_user)
@ -225,7 +321,16 @@ groups_4 {
count([true | some v; v = groups[_0]; v == "group5"]) > 0
}
claims_0 {
else = [false, {"groups-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
claim_0 = [true, {"claim-ok"}] {
rule_data := "Smith"
rule_path := "family_name"
session := get_session(input.session.id)
@ -237,7 +342,16 @@ claims_0 {
rule_data == values[_0]
}
claims_1 {
else = [false, {"claim-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
claim_1 = [true, {"claim-ok"}] {
rule_data := "Jones"
rule_path := "family_name"
session := get_session(input.session.id)
@ -249,7 +363,16 @@ claims_1 {
rule_data == values[_0]
}
claims_2 {
else = [false, {"claim-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
claim_2 = [true, {"claim-ok"}] {
rule_data := "John"
rule_path := "given_name"
session := get_session(input.session.id)
@ -261,7 +384,16 @@ claims_2 {
rule_data == values[_0]
}
claims_3 {
else = [false, {"claim-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
claim_3 = [true, {"claim-ok"}] {
rule_data := "EST"
rule_path := "timezone"
session := get_session(input.session.id)
@ -273,246 +405,266 @@ claims_3 {
rule_data == values[_0]
}
users_0 {
else = [false, {"claim-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
user_0 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user1"
}
emails_0 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
email_0 = [true, {"email-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user1"
}
users_1 {
else = [false, {"email-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
user_1 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user2"
}
emails_1 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
email_1 = [true, {"email-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user2"
}
users_2 {
else = [false, {"email-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
user_2 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user3"
}
emails_2 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
email_2 = [true, {"email-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user3"
}
users_3 {
else = [false, {"email-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
user_3 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user4"
}
emails_3 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
email_3 = [true, {"email-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user4"
}
users_4 {
else = [false, {"email-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
user_4 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user5"
}
emails_4 {
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = [false, {"user-unauthenticated"}] {
true
}
email_4 = [true, {"email-ok"}] {
session := get_session(input.session.id)
user := get_user(session)
email := get_user_email(session, user)
email == "user5"
}
or_0 = v1 {
v1 := pomerium_routes_0
v1
else = [false, {"email-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
else = v2 {
v2 := accept_0
v2
else = [false, {"user-unauthenticated"}] {
true
}
else = v3 {
v3 := cors_preflight_0
v3
or_0 = v {
results := [pomerium_routes_0, accept_0, cors_preflight_0, authenticated_user_0, domain_0, domain_1, domain_2, domain_3, domain_4, groups_0, groups_1, groups_2, groups_3, groups_4, claim_0, claim_1, claim_2, claim_3, user_0, email_0, user_1, email_1, user_2, email_2, user_3, email_3, user_4, email_4]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
else = v4 {
v4 := authenticated_user_0
v4
}
else = v5 {
v5 := domains_0
v5
}
else = v6 {
v6 := domains_1
v6
}
else = v7 {
v7 := domains_2
v7
}
else = v8 {
v8 := domains_3
v8
}
else = v9 {
v9 := domains_4
v9
}
else = v10 {
v10 := groups_0
v10
}
else = v11 {
v11 := groups_1
v11
}
else = v12 {
v12 := groups_2
v12
}
else = v13 {
v13 := groups_3
v13
}
else = v14 {
v14 := groups_4
v14
}
else = v15 {
v15 := claims_0
v15
}
else = v16 {
v16 := claims_1
v16
}
else = v17 {
v17 := claims_2
v17
}
else = v18 {
v18 := claims_3
v18
}
else = v19 {
v19 := users_0
v19
}
else = v20 {
v20 := emails_0
v20
}
else = v21 {
v21 := users_1
v21
}
else = v22 {
v22 := emails_1
v22
}
else = v23 {
v23 := users_2
v23
}
else = v24 {
v24 := emails_2
v24
}
else = v25 {
v25 := users_3
v25
}
else = v26 {
v26 := emails_3
v26
}
else = v27 {
v27 := users_4
v27
}
else = v28 {
v28 := emails_4
v28
}
users_5 {
user_5 = [true, {"user-ok"}] {
session := get_session(input.session.id)
user_id := session.user_id
user_id == "user6"
}
or_1 = v1 {
v1 := users_5
v1
else = [false, {"user-unauthorized"}] {
session := get_session(input.session.id)
session.id != ""
}
allow = v1 {
v1 := or_0
v1
else = [false, {"user-unauthenticated"}] {
true
}
else = v2 {
v2 := or_1
v2
or_1 = v {
results := [user_5]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invalid_client_certificate_0 = reason {
reason = [495, "invalid client certificate"]
allow = v {
results := [or_0, or_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invalid_client_certificate_0 = [true, {"invalid-client-certificate"}] {
is_boolean(input.is_valid_client_certificate)
not input.is_valid_client_certificate
}
or_2 = v1 {
v1 := invalid_client_certificate_0
v1
else = [false, {"valid-client-certificate-or-none-required"}] {
true
}
deny = v1 {
v1 := or_2
v1
or_2 = v {
results := [invalid_client_certificate_0]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
deny = v {
results := [or_2]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invert_criterion_result(result) = [false, result[1]] {
result[0]
}
else = [true, result[1]] {
not result[0]
}
normalize_criterion_result(result) = v {
is_boolean(result)
v = [result, set()]
}
else = v {
is_array(result)
v = result
}
else = v {
v = [false, set()]
}
merge_with_and(results) = [true, reasons] {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
}
else = [false, reasons] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
}
merge_with_or(results) = [true, reasons] {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
}
else = [false, reasons] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
}
get_session(id) = v {

1
go.sum
View file

@ -1754,6 +1754,7 @@ google.golang.org/api v0.50.0/go.mod h1:4bNT5pAuq5ji4SRZm+5QIkjny9JAyVD/3gaSihNe
google.golang.org/api v0.51.0/go.mod h1:t4HdrdoNgyN5cbEfm7Lum0lcLDLiise1F8qDKX00sOU=
google.golang.org/api v0.54.0/go.mod h1:7C4bFFOvVDGXjfDTAsgGwDgAxRDeQ4X8NvUedIt6z3k=
google.golang.org/api v0.55.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.56.0/go.mod h1:38yMfeP1kfjsl8isn0tliTjIb1rJXcQi4UXlbqivdVE=
google.golang.org/api v0.57.0 h1:4t9zuDlHLcIx0ZEhmXEeFVCRsiOgpgn2QOH9N0MNjPI=
google.golang.org/api v0.57.0/go.mod h1:dVPlbZyBo2/OjBpmvNdpn2GRm6rPy75jyU7bmhdrMgI=
google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM=

View file

@ -7,10 +7,6 @@ import (
"github.com/pomerium/pomerium/pkg/policy/parser"
)
var acceptBody = ast.Body{
ast.MustParseExpr(`v := true`),
}
type acceptCriterion struct {
g *Generator
}
@ -24,9 +20,9 @@ func (acceptCriterion) Name() string {
}
func (c acceptCriterion) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
rule := c.g.NewRule("accept")
rule.Head.Value = ast.VarTerm("v")
rule.Body = acceptBody
rule := c.g.NewRule(c.Name())
rule.Head.Value = NewCriterionTerm(true, ReasonAccept)
rule.Body = ast.Body{ast.NewExpr(ast.BooleanTerm(true))}
return rule, nil, nil
}

View file

@ -13,6 +13,6 @@ allow:
- accept: 1
`, []dataBrokerRecord{}, Input{})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonAccept}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
}

View file

@ -27,8 +27,9 @@ func (authenticatedUserCriterion) Name() string {
}
func (c authenticatedUserCriterion) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
rule := c.g.NewRule("authenticated_user")
rule.Body = authenticatedUserBody
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonUserOK, ReasonUserUnauthorized,
authenticatedUserBody)
return rule, []*ast.Rule{rules.GetSession()}, nil
}

View file

@ -16,8 +16,8 @@ allow:
- authenticated_user: 1
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by domain", func(t *testing.T) {
res, err := evaluate(t, `
@ -33,7 +33,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -45,14 +45,13 @@ func (claimsCriterion) Name() string {
}
func (c claimsCriterion) GenerateRule(subPath string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("claims")
r.Body = append(r.Body,
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonClaimOK, ReasonClaimUnauthorized,
append(ast.Body{
ast.Assign.Expr(ast.VarTerm("rule_data"), ast.NewTerm(data.RegoValue())),
ast.Assign.Expr(ast.VarTerm("rule_path"), ast.NewTerm(ast.MustInterfaceToValue(subPath))),
)
r.Body = append(r.Body, claimsBody...)
return r, []*ast.Rule{
}, claimsBody...))
return rule, []*ast.Rule{
rules.GetSession(),
rules.GetUser(),
rules.ObjectGet(),

View file

@ -18,8 +18,25 @@ allow:
- claim/family_name: Smith
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("no claim", func(t *testing.T) {
res, err := evaluate(t, `
allow:
and:
- claim/family_name: Smith
`,
[]dataBrokerRecord{
&session.Session{
Id: "SESSION_ID",
UserId: "USER_ID",
},
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, A{false, A{ReasonClaimUnauthorized}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by session claim", func(t *testing.T) {
res, err := evaluate(t, `
@ -42,8 +59,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonClaimOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by user claim", func(t *testing.T) {
res, err := evaluate(t, `
@ -66,7 +83,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonClaimOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -26,8 +26,9 @@ func (corsPreflightCriterion) Name() string {
}
func (c corsPreflightCriterion) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
rule := c.g.NewRule("cors_preflight")
rule.Body = corsPreflightBody
rule := NewCriterionRule(c.g, c.Name(),
ReasonCORSRequest, ReasonNonCORSRequest,
corsPreflightBody)
return rule, nil, nil
}

View file

@ -20,8 +20,8 @@ allow:
},
}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{"cors-request"}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("false", func(t *testing.T) {
res, err := evaluate(t, `
@ -32,7 +32,7 @@ allow:
Method: "OPTIONS",
}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{"non-cors-request"}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -4,6 +4,8 @@ package criteria
import (
"sync"
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
)
@ -48,3 +50,79 @@ const (
// CriterionDataTypeStringMatcher indicates the expected data type is a string matcher.
CriterionDataTypeStringMatcher CriterionDataType = "string_matcher"
)
// NewCriterionRule generates a new rule for a criterion.
func NewCriterionRule(
g *generator.Generator,
name string,
passReason, failReason Reason,
body ast.Body,
) *ast.Rule {
r1 := g.NewRule(name)
r1.Head.Value = NewCriterionTerm(true, passReason)
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
}
r1.Else = r2
return r1
}
// NewCriterionSessionRule generates a new rule for a criterion which
// requires a session. If there is no session "user-unauthenticated"
// is returned.
func NewCriterionSessionRule(
g *generator.Generator,
name string,
passReason, failReason Reason,
body ast.Body,
) *ast.Rule {
r1 := g.NewRule(name)
r1.Head.Value = NewCriterionTerm(true, passReason)
r1.Body = body
r2 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, failReason),
},
Body: ast.Body{
ast.MustParseExpr(`session := get_session(input.session.id)`),
ast.MustParseExpr(`session.id != ""`),
},
}
r1.Else = r2
r3 := &ast.Rule{
Head: &ast.Head{
Value: NewCriterionTerm(false, ReasonUserUnauthenticated),
},
Body: ast.Body{
ast.NewExpr(ast.BooleanTerm(true)),
},
}
r2.Else = r3
return r1
}
// NewCriterionTerm creates a new rego term for a criterion:
//
// [true, {"reason"}]
//
func NewCriterionTerm(value bool, reasons ...Reason) *ast.Term {
var terms []*ast.Term
for _, r := range reasons {
terms = append(terms, ast.StringTerm(string(r)))
}
return ast.ArrayTerm(
ast.BooleanTerm(value),
ast.SetTerm(terms...),
)
}

View file

@ -20,6 +20,9 @@ import (
"github.com/pomerium/pomerium/pkg/protoutil"
)
type A = []interface{}
type M = map[string]interface{}
var testingNow = time.Date(2021, 5, 11, 13, 43, 0, 0, time.Local)
type (

View file

@ -0,0 +1,61 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var domainBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user := get_user(session)
`),
ast.MustParseExpr(`
domain := split(get_user_email(session, user), "@")[1]
`),
}
type domainCriterion struct {
g *Generator
}
func (domainCriterion) DataType() CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (domainCriterion) Name() string {
return "domain"
}
func (c domainCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
var body ast.Body
body = append(body, domainBody...)
err := matchString(&body, ast.VarTerm("domain"), data)
if err != nil {
return nil, nil, err
}
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonDomainOK, ReasonDomainUnauthorized,
body)
return rule, []*ast.Rule{
rules.GetSession(),
rules.GetUser(),
rules.GetUserEmail(),
}, nil
}
// Domain returns a Criterion on a user's email address domain.
func Domain(generator *Generator) Criterion {
return domainCriterion{g: generator}
}
func init() {
Register(Domain)
}

View file

@ -18,8 +18,8 @@ allow:
is: example.com
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by domain", func(t *testing.T) {
res, err := evaluate(t, `
@ -40,8 +40,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonDomainOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate email", func(t *testing.T) {
res, err := evaluate(t, `
@ -62,7 +62,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonDomainOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -1,57 +0,0 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var domainsBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user := get_user(session)
`),
ast.MustParseExpr(`
domain := split(get_user_email(session, user), "@")[1]
`),
}
type domainsCriterion struct {
g *Generator
}
func (domainsCriterion) DataType() CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (domainsCriterion) Name() string {
return "domain"
}
func (c domainsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("domains")
r.Body = append(r.Body, domainsBody...)
err := matchString(&r.Body, ast.VarTerm("domain"), data)
if err != nil {
return nil, nil, err
}
return r, []*ast.Rule{
rules.GetSession(),
rules.GetUser(),
rules.GetUserEmail(),
}, nil
}
// Domains returns a Criterion on a user's email address domain.
func Domains(generator *Generator) Criterion {
return domainsCriterion{g: generator}
}
func init() {
Register(Domains)
}

View file

@ -0,0 +1,62 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var emailBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user := get_user(session)
`),
ast.MustParseExpr(`
email := get_user_email(session, user)
`),
}
type emailCriterion struct {
g *Generator
}
func (emailCriterion) DataType() generator.CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (emailCriterion) Name() string {
return "email"
}
func (c emailCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
var body ast.Body
body = append(body, emailBody...)
err := matchString(&body, ast.VarTerm("email"), data)
if err != nil {
return nil, nil, err
}
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonEmailOK, ReasonEmailUnauthorized,
body)
return rule, []*ast.Rule{
rules.GetSession(),
rules.GetUser(),
rules.GetUserEmail(),
}, nil
}
// Email returns a Criterion on a user's email address.
func Email(generator *Generator) Criterion {
return emailCriterion{g: generator}
}
func init() {
Register(Email)
}

View file

@ -19,8 +19,8 @@ allow:
is: test@example.com
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by email", func(t *testing.T) {
res, err := evaluate(t, `
@ -41,8 +41,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonEmailOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate session id", func(t *testing.T) {
res, err := evaluate(t, `
@ -72,7 +72,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION1"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonEmailOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -1,58 +0,0 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var emailsBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user := get_user(session)
`),
ast.MustParseExpr(`
email := get_user_email(session, user)
`),
}
type emailsCriterion struct {
g *Generator
}
func (emailsCriterion) DataType() generator.CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (emailsCriterion) Name() string {
return "email"
}
func (c emailsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("emails")
r.Body = append(r.Body, emailsBody...)
err := matchString(&r.Body, ast.VarTerm("email"), data)
if err != nil {
return nil, nil, err
}
return r, []*ast.Rule{
rules.GetSession(),
rules.GetUser(),
rules.GetUserEmail(),
}, nil
}
// Emails returns a Criterion on a user's email address.
func Emails(generator *Generator) Criterion {
return emailsCriterion{g: generator}
}
func init() {
Register(Emails)
}

View file

@ -52,15 +52,19 @@ func (groupsCriterion) Name() string {
}
func (c groupsCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("groups")
r.Body = append(r.Body, groupsBody...)
var body ast.Body
body = append(body, groupsBody...)
err := matchStringList(&r.Body, ast.VarTerm("groups"), data)
err := matchStringList(&body, ast.VarTerm("groups"), data)
if err != nil {
return nil, nil, err
}
return r, []*ast.Rule{
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonGroupsOK, ReasonGroupsUnauthorized,
body)
return rule, []*ast.Rule{
rules.GetSession(),
rules.GetDirectoryUser(),
rules.GetDirectoryGroup(),

View file

@ -20,8 +20,8 @@ allow:
has: group2
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by id", func(t *testing.T) {
res, err := evaluate(t, `
@ -42,8 +42,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by email", func(t *testing.T) {
res, err := evaluate(t, `
@ -68,8 +68,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by name", func(t *testing.T) {
res, err := evaluate(t, `
@ -94,7 +94,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonGroupsOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -8,7 +8,6 @@ import (
)
var invalidClientCertificateBody = ast.Body{
ast.MustParseExpr(`reason = [495, "invalid client certificate"]`),
ast.MustParseExpr(`is_boolean(input.is_valid_client_certificate)`),
ast.MustParseExpr(`not input.is_valid_client_certificate`),
}
@ -26,9 +25,9 @@ func (invalidClientCertificateCriterion) Name() string {
}
func (c invalidClientCertificateCriterion) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
rule := c.g.NewRule("invalid_client_certificate")
rule.Head.Value = ast.VarTerm("reason")
rule.Body = invalidClientCertificateBody
rule := NewCriterionRule(c.g, c.Name(),
ReasonInvalidClientCertificate, ReasonValidClientCertificateOrNoneRequired,
invalidClientCertificateBody)
return rule, nil, nil
}

View file

@ -26,10 +26,11 @@ func (pomeriumRoutesCriterion) Name() string {
}
func (c pomeriumRoutesCriterion) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("pomerium_routes")
r.Body = append(r.Body, pomeriumRoutesBody...)
rule := NewCriterionRule(c.g, c.Name(),
ReasonPomeriumRoute, ReasonNonPomeriumRoute,
pomeriumRoutesBody)
return r, nil, nil
return rule, nil, nil
}
// PomeriumRoutes returns a Criterion on that allows access to pomerium routes.

View file

@ -0,0 +1,75 @@
package criteria
import "sort"
// A Reason is a reason for why a policy criterion passes or fails.
type Reason string
// Well-known reasons.
const (
ReasonAccept = "accept"
ReasonClaimOK = "claim-ok"
ReasonClaimUnauthorized = "claim-unauthorized"
ReasonCORSRequest = "cors-request"
ReasonDomainOK = "domain-ok"
ReasonDomainUnauthorized = "domain-unauthorized"
ReasonEmailOK = "email-ok"
ReasonEmailUnauthorized = "email-unauthorized"
ReasonGroupsOK = "groups-ok"
ReasonGroupsUnauthorized = "groups-unauthorized"
ReasonInvalidClientCertificate = "invalid-client-certificate"
ReasonNonCORSRequest = "non-cors-request"
ReasonNonPomeriumRoute = "non-pomerium-route"
ReasonPomeriumRoute = "pomerium-route"
ReasonReject = "reject"
ReasonRouteNotFound = "route-not-found"
ReasonUserOK = "user-ok"
ReasonUserUnauthenticated = "user-unauthenticated" // user needs to log in
ReasonUserUnauthorized = "user-unauthorized" // user does not have access
ReasonValidClientCertificateOrNoneRequired = "valid-client-certificate-or-none-required"
)
// Reasons is a collection of reasons.
type Reasons map[Reason]struct{}
// NewReasons creates a new Reasons collection.
func NewReasons(reasons ...Reason) Reasons {
rs := make(Reasons)
for _, r := range reasons {
rs.Add(r)
}
return rs
}
// Add adds a reason to the collection.
func (rs Reasons) Add(r Reason) {
rs[r] = struct{}{}
}
// Has returns true if the reason is found in the collection.
func (rs Reasons) Has(r Reason) bool {
_, ok := rs[r]
return ok
}
// Strings returns the reason collection as a slice of strings.
func (rs Reasons) Strings() []string {
var arr []string
for r := range rs {
arr = append(arr, string(r))
}
sort.Strings(arr)
return arr
}
// Union merges two reason collections together.
func (rs Reasons) Union(other Reasons) Reasons {
merged := make(Reasons)
for r := range rs {
merged.Add(r)
}
for r := range other {
merged.Add(r)
}
return merged
}

View file

@ -7,10 +7,6 @@ import (
"github.com/pomerium/pomerium/pkg/policy/parser"
)
var rejectBody = ast.Body{
ast.MustParseExpr(`v := false`),
}
type rejectMatcher struct {
g *Generator
}
@ -25,8 +21,8 @@ func (rejectMatcher) Name() string {
func (m rejectMatcher) GenerateRule(_ string, _ parser.Value) (*ast.Rule, []*ast.Rule, error) {
rule := m.g.NewRule("reject")
rule.Head.Value = ast.VarTerm("v")
rule.Body = rejectBody
rule.Head.Value = NewCriterionTerm(false, ReasonReject)
rule.Body = ast.Body{ast.NewExpr(ast.BooleanTerm(true))}
return rule, nil, nil
}

View file

@ -13,6 +13,6 @@ allow:
- reject: 1
`, []dataBrokerRecord{}, Input{})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonReject}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
}

View file

@ -0,0 +1,57 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var userBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user_id := session.user_id
`),
}
type userCriterion struct {
g *Generator
}
func (userCriterion) DataType() generator.CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (userCriterion) Name() string {
return "user"
}
func (c userCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
var body ast.Body
body = append(body, userBody...)
err := matchString(&body, ast.VarTerm("user_id"), data)
if err != nil {
return nil, nil, err
}
rule := NewCriterionSessionRule(c.g, c.Name(),
ReasonUserOK, ReasonUserUnauthorized,
body)
return rule, []*ast.Rule{
rules.GetSession(),
}, nil
}
// UserID returns a Criterion on a user's id.
func UserID(generator *Generator) Criterion {
return userCriterion{g: generator}
}
func init() {
Register(UserID)
}

View file

@ -18,8 +18,8 @@ allow:
is: USER_ID
`, []dataBrokerRecord{}, Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, false, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{false, A{ReasonUserUnauthenticated}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by user id", func(t *testing.T) {
res, err := evaluate(t, `
@ -36,8 +36,8 @@ allow:
},
Input{Session: InputSession{ID: "SESSION_ID"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
t.Run("by impersonate session id", func(t *testing.T) {
res, err := evaluate(t, `
@ -59,7 +59,7 @@ allow:
},
Input{Session: InputSession{ID: "SESSION1"}})
require.NoError(t, err)
require.Equal(t, true, res["allow"])
require.Equal(t, false, res["deny"])
require.Equal(t, A{true, A{ReasonUserOK}}, res["allow"])
require.Equal(t, A{false, A{}}, res["deny"])
})
}

View file

@ -1,53 +0,0 @@
package criteria
import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/generator"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
var usersBody = ast.Body{
ast.MustParseExpr(`
session := get_session(input.session.id)
`),
ast.MustParseExpr(`
user_id := session.user_id
`),
}
type usersCriterion struct {
g *Generator
}
func (usersCriterion) DataType() generator.CriterionDataType {
return CriterionDataTypeStringMatcher
}
func (usersCriterion) Name() string {
return "user"
}
func (c usersCriterion) GenerateRule(_ string, data parser.Value) (*ast.Rule, []*ast.Rule, error) {
r := c.g.NewRule("users")
r.Body = append(r.Body, usersBody...)
err := matchString(&r.Body, ast.VarTerm("user_id"), data)
if err != nil {
return nil, nil, err
}
return r, []*ast.Rule{
rules.GetSession(),
}, nil
}
// UserIDs returns a Criterion on a user's id.
func UserIDs(generator *Generator) Criterion {
return usersCriterion{g: generator}
}
func init() {
Register(UserIDs)
}

View file

@ -8,6 +8,27 @@ import (
"github.com/pomerium/pomerium/pkg/policy/parser"
)
var (
andBody = ast.Body{
ast.MustParseExpr(`normalized := [normalize_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`v := merge_with_and(normalized)`),
}
notBody = ast.Body{
ast.MustParseExpr(`normalized := [normalize_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`inverted := [invert_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`v := merge_with_and(inverted)`),
}
orBody = ast.Body{
ast.MustParseExpr(`normalized := [normalize_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`v := merge_with_or(normalized)`),
}
norBody = ast.Body{
ast.MustParseExpr(`normalized := [normalize_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`inverted := [invert_criterion_result(x)|x:=results[i]]`),
ast.MustParseExpr(`v := merge_with_or(inverted)`),
}
)
func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Criterion) (*ast.Rule, error) {
rule := g.NewRule("and")
@ -15,12 +36,16 @@ func (g *Generator) generateAndRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return rule, nil
}
expressions, err := g.generateCriterionRules(dst, policyCriteria)
terms, err := g.generateCriterionRules(dst, policyCriteria)
if err != nil {
return nil, err
}
g.fillViaAnd(rule, expressions)
rule.Head.Value = ast.VarTerm("v")
rule.Body = append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, andBody...)
dst.Add(rule)
return rule, nil
@ -40,7 +65,11 @@ func (g *Generator) generateNotRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return nil, err
}
g.fillViaSetComprehension(rule, terms, true, true)
rule.Head.Value = ast.VarTerm("v")
rule.Body = append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, notBody...)
dst.Add(rule)
return rule, nil
@ -58,7 +87,11 @@ func (g *Generator) generateOrRule(dst *ast.RuleSet, policyCriteria []parser.Cri
return nil, err
}
g.fillViaOr(rule, terms)
rule.Head.Value = ast.VarTerm("v")
rule.Body = append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, orBody...)
dst.Add(rule)
return rule, nil
@ -78,7 +111,11 @@ func (g *Generator) generateNorRule(dst *ast.RuleSet, policyCriteria []parser.Cr
return nil, err
}
g.fillViaSetComprehension(rule, terms, false, true)
rule.Head.Value = ast.VarTerm("v")
rule.Body = append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, norBody...)
dst.Add(rule)
return rule, nil
@ -102,73 +139,3 @@ func (g *Generator) generateCriterionRules(dst *ast.RuleSet, policyCriteria []pa
}
return terms, nil
}
func (g *Generator) fillViaAnd(rule *ast.Rule, terms []*ast.Term) {
currentRule := rule
currentRule.Head.Value = ast.VarTerm("v1")
for i, term := range terms {
nm := fmt.Sprintf("v%d", i+1)
currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term))
expr := ast.NewExpr(ast.VarTerm(nm))
currentRule.Body = append(currentRule.Body, expr)
}
}
func (g *Generator) fillViaOr(rule *ast.Rule, terms []*ast.Term) {
currentRule := rule
for i, term := range terms {
if i > 0 {
currentRule.Else = &ast.Rule{Head: &ast.Head{}}
currentRule = currentRule.Else
}
nm := fmt.Sprintf("v%d", i+1)
currentRule.Head.Value = ast.VarTerm(nm)
currentRule.Body = append(currentRule.Body, ast.Assign.Expr(ast.VarTerm(nm), term))
expr := ast.NewExpr(ast.VarTerm(nm))
currentRule.Body = append(currentRule.Body, expr)
}
}
func (g *Generator) fillViaSetComprehension(rule *ast.Rule, terms []*ast.Term, useIntersection, negated bool) {
sets := make([]*ast.Term, len(terms))
for i, term := range terms {
e := ast.NewExpr(term)
e.Negated = negated
sets[i] = ast.SetComprehensionTerm(ast.NumberTerm("1"), ast.NewBody(e))
}
var builtIn *ast.Builtin
if useIntersection {
builtIn = ast.And
} else {
builtIn = ast.Or
}
rule.Head.Value = ast.VarTerm("v")
rule.Body = ast.NewBody(
ast.Assign.Expr(
ast.VarTerm("v"),
ast.Equal.Call(
ast.Count.Call(
mergeTerms(builtIn, sets...),
),
ast.NumberTerm("1"),
),
),
)
}
func mergeTerms(builtIn *ast.Builtin, terms ...*ast.Term) *ast.Term {
// mergeTerms(AND, A, B, C, D) => AND(AND(A, B), AND(C, D))
switch len(terms) {
case 0:
return ast.NullTerm()
case 1:
return terms[0]
default:
return builtIn.Call(
mergeTerms(builtIn, terms[:len(terms)/2]...),
mergeTerms(builtIn, terms[len(terms)/2:]...),
)
}
}

View file

@ -8,6 +8,7 @@ import (
"github.com/open-policy-agent/opa/ast"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/pkg/policy/rules"
)
// A Generator generates a rego script from a policy.
@ -47,9 +48,13 @@ func (g *Generator) GetCriterion(name string) (Criterion, bool) {
// Generate generates the rego module from a policy.
func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
rules := ast.NewRuleSet()
rules.Add(ast.MustParseRule(`default allow = false`))
rules.Add(ast.MustParseRule(`default deny = false`))
rs := ast.NewRuleSet()
rs.Add(ast.MustParseRule(`default allow = [false, set()]`))
rs.Add(ast.MustParseRule(`default deny = [false, set()]`))
rs.Add(rules.InvertCriterionResult())
rs.Add(rules.NormalizeCriterionResult())
rs.Add(rules.MergeWithAnd())
rs.Add(rules.MergeWithOr())
for _, action := range []parser.Action{parser.ActionAllow, parser.ActionDeny} {
var terms []*ast.Term
@ -59,28 +64,28 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
}
if len(policyRule.And) > 0 {
subRule, err := g.generateAndRule(&rules, policyRule.And)
subRule, err := g.generateAndRule(&rs, policyRule.And)
if err != nil {
return nil, err
}
terms = append(terms, ast.VarTerm(string(subRule.Head.Name)))
}
if len(policyRule.Or) > 0 {
subRule, err := g.generateOrRule(&rules, policyRule.Or)
subRule, err := g.generateOrRule(&rs, policyRule.Or)
if err != nil {
return nil, err
}
terms = append(terms, ast.VarTerm(string(subRule.Head.Name)))
}
if len(policyRule.Not) > 0 {
subRule, err := g.generateNotRule(&rules, policyRule.Not)
subRule, err := g.generateNotRule(&rs, policyRule.Not)
if err != nil {
return nil, err
}
terms = append(terms, ast.VarTerm(string(subRule.Head.Name)))
}
if len(policyRule.Nor) > 0 {
subRule, err := g.generateNorRule(&rules, policyRule.Nor)
subRule, err := g.generateNorRule(&rs, policyRule.Nor)
if err != nil {
return nil, err
}
@ -91,11 +96,13 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
rule := &ast.Rule{
Head: &ast.Head{
Name: ast.Var(action),
Value: ast.VarTerm("v1"),
Value: ast.VarTerm("v"),
},
Body: append(ast.Body{
ast.Assign.Expr(ast.VarTerm("results"), ast.ArrayTerm(terms...)),
}, orBody...),
}
g.fillViaOr(rule, terms)
rules.Add(rule)
rs.Add(rule)
}
}
@ -107,7 +114,7 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
ast.StringTerm("policy"),
},
},
Rules: rules,
Rules: rs,
}
// move functions to the end
@ -125,6 +132,16 @@ func (g *Generator) Generate(policy *parser.Policy) (*ast.Module, error) {
return mod, nil
}
// NewRuleFromTemplate creates a new rule from a template rule.
func (g *Generator) NewRuleFromTemplate(name string, template *ast.Rule) *ast.Rule {
id := g.ids[name]
g.ids[name]++
newRule := template.Copy()
newRule.Head.Name = ast.Var(fmt.Sprintf("%s_%d", name, id))
return newRule
}
// NewRule creates a new rule with a dynamically generated name.
func (g *Generator) NewRule(name string) *ast.Rule {
id := g.ids[name]

View file

@ -63,9 +63,9 @@ func Test(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
default allow = false
default allow = [false, set()]
default deny = false
default deny = [false, set()]
accept_0 {
1 == 1
@ -79,13 +79,10 @@ accept_2 {
1 == 1
}
and_0 = v1 {
v1 := accept_0
v1
v2 := accept_1
v2
v3 := accept_2
v3
and_0 = v {
results := [accept_0, accept_1, accept_2]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
accept_3 {
@ -100,19 +97,10 @@ accept_5 {
1 == 1
}
or_0 = v1 {
v1 := accept_3
v1
}
else = v2 {
v2 := accept_4
v2
}
else = v3 {
v3 := accept_5
v3
or_0 = v {
results := [accept_3, accept_4, accept_5]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_6 {
@ -128,7 +116,10 @@ accept_8 {
}
not_0 = v {
v := count({1 | not accept_6} & ({1 | not accept_7} & {1 | not accept_8})) == 1
results := [accept_6, accept_7, accept_8]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_and(inverted)
}
accept_9 {
@ -144,41 +135,26 @@ accept_11 {
}
nor_0 = v {
v := count({1 | not accept_9} | ({1 | not accept_10} | {1 | not accept_11})) == 1
results := [accept_9, accept_10, accept_11]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
accept_12 {
1 == 1
}
and_1 = v1 {
v1 := accept_12
v1
and_1 = v {
results := [accept_12]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_and(normalized)
}
allow = v1 {
v1 := and_0
v1
}
else = v2 {
v2 := or_0
v2
}
else = v3 {
v3 := not_0
v3
}
else = v4 {
v4 := nor_0
v4
}
else = v5 {
v5 := and_1
v5
allow = v {
results := [and_0, or_0, not_0, nor_0, and_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
accept_13 {
@ -190,12 +166,60 @@ accept_14 {
}
nor_1 = v {
v := count({1 | not accept_13} | {1 | not accept_14}) == 1
results := [accept_13, accept_14]
normalized := [normalize_criterion_result(x) | x := results[i]]
inverted := [invert_criterion_result(x) | x := results[i]]
v := merge_with_or(inverted)
}
deny = v1 {
v1 := nor_1
v1
deny = v {
results := [nor_1]
normalized := [normalize_criterion_result(x) | x := results[i]]
v := merge_with_or(normalized)
}
invert_criterion_result(result) = [false, result[1]] {
result[0]
}
else = [true, result[1]] {
not result[0]
}
normalize_criterion_result(result) = v {
is_boolean(result)
v = [result, set()]
}
else = v {
is_array(result)
v = result
}
else = v {
v = [false, set()]
}
merge_with_and(results) = [true, reasons] {
true_results := [x | x := results[i]; x[0]]
count(true_results) == count(results)
reasons := union({x | x := true_results[i][1]})
}
else = [false, reasons] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
}
merge_with_or(results) = [true, reasons] {
true_results := [x | x := results[i]; x[0]]
count(true_results) > 0
reasons := union({x | x := true_results[i][1]})
}
else = [false, reasons] {
false_results := [x | x := results[i]; not x[0]]
reasons := union({x | x := false_results[i][1]})
}
`, string(format.MustAst(mod)))
}

View file

@ -85,6 +85,61 @@ get_group_ids(session, directory_user) = v {
`)
}
// MergeWithAnd merges criterion results using `and`.
func MergeWithAnd() *ast.Rule {
return ast.MustParseRule(`
merge_with_and(results) = [true, reasons] {
true_results := [x|x:=results[i];x[0]]
count(true_results) == count(results)
reasons := union({x|x:=true_results[i][1]})
} else = [false, reasons] {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
}
`)
}
// MergeWithOr merges criterion results using `or`.
func MergeWithOr() *ast.Rule {
return ast.MustParseRule(`
merge_with_or(results) = [true, reasons] {
true_results := [x|x:=results[i];x[0]]
count(true_results) > 0
reasons := union({x|x:=true_results[i][1]})
} else = [false, reasons] {
false_results := [x|x:=results[i];not x[0]]
reasons := union({x|x:=false_results[i][1]})
}
`)
}
// InvertCriterionResult changes the criterion result's value from false to
// true, or vice-versa.
func InvertCriterionResult() *ast.Rule {
return ast.MustParseRule(`
invert_criterion_result(result) = [false, result[1]] {
result[0]
} else = [true, result[1]] {
not result[0]
}
`)
}
// NormalizeCriterionResult converts a criterion result into a standard form.
func NormalizeCriterionResult() *ast.Rule {
return ast.MustParseRule(`
normalize_criterion_result(result) = v {
is_boolean(result)
v = [result, set()]
} else = v {
is_array(result)
v = result
} else = v {
v = [false, set()]
}
`)
}
// ObjectGet recursively gets a value from an object.
func ObjectGet() *ast.Rule {
return ast.MustParseRule(`