mirror of
https://github.com/pomerium/pomerium.git
synced 2025-05-29 17:07:24 +02:00
authorize: fix custom rego panic (#2226)
* fix custom rego panic * fix type cast
This commit is contained in:
parent
cc4c400140
commit
897e7202bb
3 changed files with 48 additions and 20 deletions
|
@ -53,7 +53,7 @@ func (ce *CustomEvaluator) Evaluate(ctx context.Context, req *CustomEvaluatorReq
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resultSet, err := q.Eval(ctx, rego.EvalInput(struct {
|
resultSet, err := safeEval(ctx, q, rego.EvalInput(struct {
|
||||||
HTTP RequestHTTP `json:"http"`
|
HTTP RequestHTTP `json:"http"`
|
||||||
Session RequestSession `json:"session"`
|
Session RequestSession `json:"session"`
|
||||||
}{HTTP: req.HTTP, Session: req.Session}))
|
}{HTTP: req.HTTP, Session: req.Session}))
|
||||||
|
@ -61,27 +61,25 @@ func (ce *CustomEvaluator) Evaluate(ctx context.Context, req *CustomEvaluatorReq
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
vars, ok := resultSet[0].Bindings.WithoutWildcards()["result"].(map[string]interface{})
|
vars := ce.getVars(resultSet)
|
||||||
if !ok {
|
|
||||||
vars = make(map[string]interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
res := &CustomEvaluatorResponse{
|
res := &CustomEvaluatorResponse{
|
||||||
Headers: getHeadersVar(resultSet[0].Bindings.WithoutWildcards()),
|
Headers: getHeadersVar(vars),
|
||||||
}
|
}
|
||||||
res.Allowed, _ = vars["allow"].(bool)
|
if result, ok := vars["result"].(map[string]interface{}); ok {
|
||||||
if v, ok := vars["deny"]; ok {
|
res.Allowed, _ = result["allow"].(bool)
|
||||||
// support `deny = true`
|
if v, ok := result["deny"]; ok {
|
||||||
if b, ok := v.(bool); ok {
|
// support `deny = true`
|
||||||
res.Denied = b
|
if b, ok := v.(bool); ok {
|
||||||
}
|
res.Denied = b
|
||||||
|
}
|
||||||
|
|
||||||
// support `deny[reason] = true`
|
// support `deny[reason] = true`
|
||||||
if m, ok := v.(map[string]interface{}); ok {
|
if m, ok := v.(map[string]interface{}); ok {
|
||||||
for mk, mv := range m {
|
for mk, mv := range m {
|
||||||
if b, ok := mv.(bool); ok {
|
if b, ok := mv.(bool); ok {
|
||||||
res.Denied = b
|
res.Denied = b
|
||||||
res.Reason = mk
|
res.Reason = mk
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -122,3 +120,10 @@ func (ce *CustomEvaluator) getPreparedEvalQuery(ctx context.Context, src string)
|
||||||
ce.queries[src] = q
|
ce.queries[src] = q
|
||||||
return q, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (ce *CustomEvaluator) getVars(resultSet rego.ResultSet) rego.Vars {
|
||||||
|
if len(resultSet) == 0 {
|
||||||
|
return make(rego.Vars)
|
||||||
|
}
|
||||||
|
return resultSet[0].Bindings.WithoutWildcards()
|
||||||
|
}
|
||||||
|
|
|
@ -53,4 +53,17 @@ func TestCustomEvaluator(t *testing.T) {
|
||||||
}
|
}
|
||||||
assert.NotNil(t, res)
|
assert.NotNil(t, res)
|
||||||
})
|
})
|
||||||
|
t.Run("invalid package", func(t *testing.T) {
|
||||||
|
ce := NewCustomEvaluator(store)
|
||||||
|
res, err := ce.Evaluate(ctx, &CustomEvaluatorRequest{
|
||||||
|
RegoPolicy: `package custom_ext_authz.rego
|
||||||
|
allow {
|
||||||
|
true
|
||||||
|
}`,
|
||||||
|
})
|
||||||
|
if !assert.NoError(t, err) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NotNil(t, res)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -78,7 +78,7 @@ func (e *Evaluator) Evaluate(ctx context.Context, req *Request) (*Result, error)
|
||||||
return nil, fmt.Errorf("error validating client certificate: %w", err)
|
return nil, fmt.Errorf("error validating client certificate: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := e.query.Eval(ctx, rego.EvalInput(e.newInput(req, isValid)))
|
res, err := safeEval(ctx, e.query, rego.EvalInput(e.newInput(req, isValid)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error evaluating rego policy: %w", err)
|
return nil, fmt.Errorf("error evaluating rego policy: %w", err)
|
||||||
}
|
}
|
||||||
|
@ -185,3 +185,13 @@ func (e *Evaluator) newInput(req *Request, isValidClientCertificate bool) *input
|
||||||
i.IsValidClientCertificate = isValidClientCertificate
|
i.IsValidClientCertificate = isValidClientCertificate
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func safeEval(ctx context.Context, q rego.PreparedEvalQuery, options ...rego.EvalOption) (resultSet rego.ResultSet, err error) {
|
||||||
|
defer func() {
|
||||||
|
if e := recover(); e != nil {
|
||||||
|
err = fmt.Errorf("%v", e)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
resultSet, err = q.Eval(ctx, options...)
|
||||||
|
return resultSet, err
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue