mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 17:07:24 +02:00
authorize: fix custom rego panic (#2226)
* fix custom rego panic * fix type cast
This commit is contained in:
parent
cc4c400140
commit
897e7202bb
3 changed files with 48 additions and 20 deletions
|
@ -53,7 +53,7 @@ func (ce *CustomEvaluator) Evaluate(ctx context.Context, req *CustomEvaluatorReq
|
|||
return nil, err
|
||||
}
|
||||
|
||||
resultSet, err := q.Eval(ctx, rego.EvalInput(struct {
|
||||
resultSet, err := safeEval(ctx, q, rego.EvalInput(struct {
|
||||
HTTP RequestHTTP `json:"http"`
|
||||
Session RequestSession `json:"session"`
|
||||
}{HTTP: req.HTTP, Session: req.Session}))
|
||||
|
@ -61,27 +61,25 @@ func (ce *CustomEvaluator) Evaluate(ctx context.Context, req *CustomEvaluatorReq
|
|||
return nil, err
|
||||
}
|
||||
|
||||
vars, ok := resultSet[0].Bindings.WithoutWildcards()["result"].(map[string]interface{})
|
||||
if !ok {
|
||||
vars = make(map[string]interface{})
|
||||
}
|
||||
|
||||
vars := ce.getVars(resultSet)
|
||||
res := &CustomEvaluatorResponse{
|
||||
Headers: getHeadersVar(resultSet[0].Bindings.WithoutWildcards()),
|
||||
Headers: getHeadersVar(vars),
|
||||
}
|
||||
res.Allowed, _ = vars["allow"].(bool)
|
||||
if v, ok := vars["deny"]; ok {
|
||||
// support `deny = true`
|
||||
if b, ok := v.(bool); ok {
|
||||
res.Denied = b
|
||||
}
|
||||
if result, ok := vars["result"].(map[string]interface{}); ok {
|
||||
res.Allowed, _ = result["allow"].(bool)
|
||||
if v, ok := result["deny"]; ok {
|
||||
// support `deny = true`
|
||||
if b, ok := v.(bool); ok {
|
||||
res.Denied = b
|
||||
}
|
||||
|
||||
// support `deny[reason] = true`
|
||||
if m, ok := v.(map[string]interface{}); ok {
|
||||
for mk, mv := range m {
|
||||
if b, ok := mv.(bool); ok {
|
||||
res.Denied = b
|
||||
res.Reason = mk
|
||||
// support `deny[reason] = true`
|
||||
if m, ok := v.(map[string]interface{}); ok {
|
||||
for mk, mv := range m {
|
||||
if b, ok := mv.(bool); ok {
|
||||
res.Denied = b
|
||||
res.Reason = mk
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -122,3 +120,10 @@ func (ce *CustomEvaluator) getPreparedEvalQuery(ctx context.Context, src string)
|
|||
ce.queries[src] = q
|
||||
return q, nil
|
||||
}
|
||||
|
||||
func (ce *CustomEvaluator) getVars(resultSet rego.ResultSet) rego.Vars {
|
||||
if len(resultSet) == 0 {
|
||||
return make(rego.Vars)
|
||||
}
|
||||
return resultSet[0].Bindings.WithoutWildcards()
|
||||
}
|
||||
|
|
|
@ -53,4 +53,17 @@ func TestCustomEvaluator(t *testing.T) {
|
|||
}
|
||||
assert.NotNil(t, res)
|
||||
})
|
||||
t.Run("invalid package", func(t *testing.T) {
|
||||
ce := NewCustomEvaluator(store)
|
||||
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
||||
RegoPolicy: `package custom_ext_authz.rego
|
||||
allow {
|
||||
true
|
||||
}`,
|
||||
})
|
||||
if !assert.NoError(t, err) {
|
||||
return
|
||||
}
|
||||
assert.NotNil(t, res)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
|||
return nil, fmt.Errorf("error validating client certificate: %w", err)
|
||||
}
|
||||
|
||||
res, err := e.query.Eval(ctx, rego.EvalInput(e.newInput(req, isValid)))
|
||||
res, err := safeEval(ctx, e.query, rego.EvalInput(e.newInput(req, isValid)))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error evaluating rego policy: %w", err)
|
||||
}
|
||||
|
@ -185,3 +185,13 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
|
|||
i.IsValidClientCertificate = isValidClientCertificate
|
||||
return i
|
||||
}
|
||||
|
||||
func safeEval(ctx context.Context, q rego.PreparedEvalQuery, options ...rego.EvalOption) (resultSet rego.ResultSet, err error) {
|
||||
defer func() {
|
||||
if e := recover(); e != nil {
|
||||
err = fmt.Errorf("%v", e)
|
||||
}
|
||||
}()
|
||||
resultSet, err = q.Eval(ctx, options...)
|
||||
return resultSet, err
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue