config: add support for embedded PPL policy (#2401)

This commit is contained in:
Caleb Doxsey 2021-07-27 13:44:10 -06:00 committed by GitHub
parent c34118360d
commit 0620cfdc50
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
8 changed files with 152 additions and 9 deletions

View file

@ -36,4 +36,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
DecodePolicyBase64Hook(),
decodeJWTClaimHeadersHookFunc(),
decodeCodecTypeHookFunc(),
decodePPLPolicyHookFunc(),
))

View file

@ -17,6 +17,8 @@ import (
"google.golang.org/protobuf/proto"
"gopkg.in/yaml.v3"
"github.com/pomerium/pomerium/pkg/policy/parser"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/urlutil"
)
@ -313,6 +315,53 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
return str, wghts, nil
}
// PPLPolicy is a policy defined using PPL.
type PPLPolicy struct {
*parser.Policy
}
// UnmarshalJSON parses JSON into a PPL policy.
func (ppl *PPLPolicy) UnmarshalJSON(data []byte) error {
var err error
ppl.Policy, err = parser.ParseJSON(bytes.NewReader(data))
if err != nil {
return err
}
return nil
}
// UnmarshalYAML parses YAML into a PPL policy.
func (ppl *PPLPolicy) UnmarshalYAML(unmarshal func(interface{}) error) error {
var i interface{}
err := unmarshal(&i)
if err != nil {
return err
}
bs, err := json.Marshal(i)
if err != nil {
return err
}
return ppl.UnmarshalJSON(bs)
}
func decodePPLPolicyHookFunc() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
if t != reflect.TypeOf(&PPLPolicy{}) {
return data, nil
}
bs, err := json.Marshal(data)
if err != nil {
return nil, err
}
var ppl PPLPolicy
err = json.Unmarshal(bs, &ppl)
if err != nil {
return nil, err
}
return &ppl, nil
}
}
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
return func(f, t reflect.Type, data interface{}) (interface{}, error) {

View file

@ -10,6 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v3"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
func TestJWTClaimHeaders_UnmarshalJSON(t *testing.T) {
@ -190,3 +192,39 @@ func TestWeightedStringSlice(t *testing.T) {
assert.Equal(t, tc.Weights, weights, name)
}
}
func TestDecodePPLPolicyHookFunc(t *testing.T) {
var withPolicy struct {
Policy *PPLPolicy `mapstructure:"policy"`
}
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: decodePPLPolicyHookFunc(),
Result: &withPolicy,
})
require.NoError(t, err)
err = decoder.Decode(map[string]interface{}{
"policy": map[string]interface{}{
"allow": map[string]interface{}{
"or": []map[string]interface{}{
{"email": map[string]interface{}{
"is": "user1@example.com",
}},
},
},
},
})
assert.NoError(t, err)
assert.Equal(t, &PPLPolicy{
Policy: &parser.Policy{
Rules: []parser.Rule{{
Action: parser.ActionAllow,
Or: []parser.Criterion{{
Name: "email", Data: parser.Object{
"is": parser.String("user1@example.com"),
},
}},
}},
},
}, withPolicy.Policy)
}

View file

@ -114,6 +114,7 @@ type Options struct {
// Policies define per-route configuration and access control policies.
Policies []Policy `mapstructure:"policy"`
PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"`
Routes []Policy `mapstructure:"routes"`
// AdditionalPolicies are any additional policies added to the options.
AdditionalPolicies []Policy `yaml:"-"`
@ -428,6 +429,15 @@ func (o *Options) parsePolicy() error {
if len(policies) != 0 {
o.Policies = policies
}
var routes []Policy
if err := o.viper.UnmarshalKey("routes", &routes, ViperPolicyHooks); err != nil {
return err
}
if len(routes) != 0 {
o.Routes = routes
}
// Finish initializing policies
for i := range o.Policies {
p := &o.Policies[i]
@ -435,6 +445,12 @@ func (o *Options) parsePolicy() error {
return err
}
}
for i := range o.Routes {
p := &o.Routes[i]
if err := p.Validate(); err != nil {
return err
}
}
for i := range o.AdditionalPolicies {
p := &o.AdditionalPolicies[i]
if err := p.Validate(); err != nil {
@ -861,8 +877,9 @@ func (o *Options) GetAllPolicies() []Policy {
if o == nil {
return nil
}
policies := make([]Policy, 0, len(o.Policies)+len(o.AdditionalPolicies))
policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies))
policies = append(policies, o.Policies...)
policies = append(policies, o.Routes...)
policies = append(policies, o.AdditionalPolicies...)
return policies
}

View file

@ -161,6 +161,8 @@ type Policy struct {
// SetResponseHeaders sets response headers.
SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"`
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
}
// RewriteHeader is a policy configuration option to rewrite an HTTP header.

View file

@ -99,5 +99,10 @@ func (p *Policy) ToPPL() *parser.Policy {
})
ppl.Rules = append(ppl.Rules, denyRule)
// append embedded PPL policy rules
if p.Policy != nil && p.Policy.Policy != nil {
ppl.Rules = append(ppl.Rules, p.Policy.Policy.Rules...)
}
return ppl
}

View file

@ -7,6 +7,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/pomerium/pomerium/pkg/policy"
"github.com/pomerium/pomerium/pkg/policy/parser"
)
func TestPolicy_ToPPL(t *testing.T) {
@ -38,6 +39,19 @@ func TestPolicy_ToPPL(t *testing.T) {
},
},
},
Policy: &PPLPolicy{
Policy: &parser.Policy{
Rules: []parser.Rule{{
Action: parser.ActionAllow,
Or: []parser.Criterion{{
Name: "user",
Data: parser.Object{
"is": parser.String("user6"),
},
}},
}},
},
},
}).ToPPL())
require.NoError(t, err)
assert.Equal(t, `package pomerium.policy
@ -469,24 +483,41 @@ else = v28 {
v28
}
users_5 {
session := get_session(input.session.id)
user := get_user(session)
user_id := user.id
user_id == "user6"
}
or_1 = v1 {
v1 := users_5
v1
}
allow = v1 {
v1 := or_0
v1
}
else = v2 {
v2 := or_1
v2
}
invalid_client_certificate_0 = reason {
reason = [495, "invalid client certificate"]
is_boolean(input.is_valid_client_certificate)
not input.is_valid_client_certificate
}
or_1 = v1 {
or_2 = v1 {
v1 := invalid_client_certificate_0
v1
}
deny = v1 {
v1 := or_1
v1 := or_2
v1
}

View file

@ -29,13 +29,13 @@ type Handler struct {
mu sync.RWMutex
key []byte
options *config.Options
policies map[uint64]*config.Policy
policies map[uint64]config.Policy
}
// New creates a new Handler.
func New() *Handler {
h := new(Handler)
h.policies = make(map[uint64]*config.Policy)
h.policies = make(map[uint64]config.Policy)
return h
}
@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
h := stdhttputil.NewSingleHostReverseProxy(&dst)
h.ErrorLog = stdlog.New(log.Logger(), "", 0)
h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2)
h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2)
h.ServeHTTP(w, r)
return nil
})
@ -133,14 +133,14 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
h.key, _ = cfg.Options.GetSharedKey()
h.options = cfg.Options
h.policies = make(map[uint64]*config.Policy)
for i, p := range cfg.Options.Policies {
h.policies = make(map[uint64]config.Policy)
for _, p := range cfg.Options.GetAllPolicies() {
id, err := p.RouteID()
if err != nil {
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
continue
}
h.policies[id] = &cfg.Options.Policies[i]
h.policies[id] = p
}
}