mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-04 11:52:53 +02:00
config: add support for embedded PPL policy (#2401)
This commit is contained in:
parent
c34118360d
commit
0620cfdc50
8 changed files with 152 additions and 9 deletions
|
@ -36,4 +36,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
|||
DecodePolicyBase64Hook(),
|
||||
decodeJWTClaimHeadersHookFunc(),
|
||||
decodeCodecTypeHookFunc(),
|
||||
decodePPLPolicyHookFunc(),
|
||||
))
|
||||
|
|
|
@ -17,6 +17,8 @@ import (
|
|||
"google.golang.org/protobuf/proto"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
|
||||
"github.com/pomerium/pomerium/internal/httputil"
|
||||
"github.com/pomerium/pomerium/internal/urlutil"
|
||||
)
|
||||
|
@ -313,6 +315,53 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
|
|||
return str, wghts, nil
|
||||
}
|
||||
|
||||
// PPLPolicy is a policy defined using PPL.
|
||||
type PPLPolicy struct {
|
||||
*parser.Policy
|
||||
}
|
||||
|
||||
// UnmarshalJSON parses JSON into a PPL policy.
|
||||
func (ppl *PPLPolicy) UnmarshalJSON(data []byte) error {
|
||||
var err error
|
||||
ppl.Policy, err = parser.ParseJSON(bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// UnmarshalYAML parses YAML into a PPL policy.
|
||||
func (ppl *PPLPolicy) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var i interface{}
|
||||
err := unmarshal(&i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
bs, err := json.Marshal(i)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ppl.UnmarshalJSON(bs)
|
||||
}
|
||||
|
||||
func decodePPLPolicyHookFunc() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
if t != reflect.TypeOf(&PPLPolicy{}) {
|
||||
return data, nil
|
||||
}
|
||||
bs, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var ppl PPLPolicy
|
||||
err = json.Unmarshal(bs, &ppl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ppl, nil
|
||||
}
|
||||
}
|
||||
|
||||
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
|
||||
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||
|
|
|
@ -10,6 +10,8 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
func TestJWTClaimHeaders_UnmarshalJSON(t *testing.T) {
|
||||
|
@ -190,3 +192,39 @@ func TestWeightedStringSlice(t *testing.T) {
|
|||
assert.Equal(t, tc.Weights, weights, name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
||||
var withPolicy struct {
|
||||
Policy *PPLPolicy `mapstructure:"policy"`
|
||||
}
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: decodePPLPolicyHookFunc(),
|
||||
Result: &withPolicy,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = decoder.Decode(map[string]interface{}{
|
||||
"policy": map[string]interface{}{
|
||||
"allow": map[string]interface{}{
|
||||
"or": []map[string]interface{}{
|
||||
{"email": map[string]interface{}{
|
||||
"is": "user1@example.com",
|
||||
}},
|
||||
},
|
||||
},
|
||||
},
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, &PPLPolicy{
|
||||
Policy: &parser.Policy{
|
||||
Rules: []parser.Rule{{
|
||||
Action: parser.ActionAllow,
|
||||
Or: []parser.Criterion{{
|
||||
Name: "email", Data: parser.Object{
|
||||
"is": parser.String("user1@example.com"),
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}, withPolicy.Policy)
|
||||
}
|
||||
|
|
|
@ -114,6 +114,7 @@ type Options struct {
|
|||
// Policies define per-route configuration and access control policies.
|
||||
Policies []Policy `mapstructure:"policy"`
|
||||
PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"`
|
||||
Routes []Policy `mapstructure:"routes"`
|
||||
|
||||
// AdditionalPolicies are any additional policies added to the options.
|
||||
AdditionalPolicies []Policy `yaml:"-"`
|
||||
|
@ -428,6 +429,15 @@ func (o *Options) parsePolicy() error {
|
|||
if len(policies) != 0 {
|
||||
o.Policies = policies
|
||||
}
|
||||
|
||||
var routes []Policy
|
||||
if err := o.viper.UnmarshalKey("routes", &routes, ViperPolicyHooks); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(routes) != 0 {
|
||||
o.Routes = routes
|
||||
}
|
||||
|
||||
// Finish initializing policies
|
||||
for i := range o.Policies {
|
||||
p := &o.Policies[i]
|
||||
|
@ -435,6 +445,12 @@ func (o *Options) parsePolicy() error {
|
|||
return err
|
||||
}
|
||||
}
|
||||
for i := range o.Routes {
|
||||
p := &o.Routes[i]
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
for i := range o.AdditionalPolicies {
|
||||
p := &o.AdditionalPolicies[i]
|
||||
if err := p.Validate(); err != nil {
|
||||
|
@ -861,8 +877,9 @@ func (o *Options) GetAllPolicies() []Policy {
|
|||
if o == nil {
|
||||
return nil
|
||||
}
|
||||
policies := make([]Policy, 0, len(o.Policies)+len(o.AdditionalPolicies))
|
||||
policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies))
|
||||
policies = append(policies, o.Policies...)
|
||||
policies = append(policies, o.Routes...)
|
||||
policies = append(policies, o.AdditionalPolicies...)
|
||||
return policies
|
||||
}
|
||||
|
|
|
@ -161,6 +161,8 @@ type Policy struct {
|
|||
|
||||
// SetResponseHeaders sets response headers.
|
||||
SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"`
|
||||
|
||||
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
|
||||
}
|
||||
|
||||
// RewriteHeader is a policy configuration option to rewrite an HTTP header.
|
||||
|
|
|
@ -99,5 +99,10 @@ func (p *Policy) ToPPL() *parser.Policy {
|
|||
})
|
||||
ppl.Rules = append(ppl.Rules, denyRule)
|
||||
|
||||
// append embedded PPL policy rules
|
||||
if p.Policy != nil && p.Policy.Policy != nil {
|
||||
ppl.Rules = append(ppl.Rules, p.Policy.Policy.Rules...)
|
||||
}
|
||||
|
||||
return ppl
|
||||
}
|
||||
|
|
|
@ -7,6 +7,7 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/pomerium/pomerium/pkg/policy"
|
||||
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||
)
|
||||
|
||||
func TestPolicy_ToPPL(t *testing.T) {
|
||||
|
@ -38,6 +39,19 @@ func TestPolicy_ToPPL(t *testing.T) {
|
|||
},
|
||||
},
|
||||
},
|
||||
Policy: &PPLPolicy{
|
||||
Policy: &parser.Policy{
|
||||
Rules: []parser.Rule{{
|
||||
Action: parser.ActionAllow,
|
||||
Or: []parser.Criterion{{
|
||||
Name: "user",
|
||||
Data: parser.Object{
|
||||
"is": parser.String("user6"),
|
||||
},
|
||||
}},
|
||||
}},
|
||||
},
|
||||
},
|
||||
}).ToPPL())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, `package pomerium.policy
|
||||
|
@ -469,24 +483,41 @@ else = v28 {
|
|||
v28
|
||||
}
|
||||
|
||||
users_5 {
|
||||
session := get_session(input.session.id)
|
||||
user := get_user(session)
|
||||
user_id := user.id
|
||||
user_id == "user6"
|
||||
}
|
||||
|
||||
or_1 = v1 {
|
||||
v1 := users_5
|
||||
v1
|
||||
}
|
||||
|
||||
allow = v1 {
|
||||
v1 := or_0
|
||||
v1
|
||||
}
|
||||
|
||||
else = v2 {
|
||||
v2 := or_1
|
||||
v2
|
||||
}
|
||||
|
||||
invalid_client_certificate_0 = reason {
|
||||
reason = [495, "invalid client certificate"]
|
||||
is_boolean(input.is_valid_client_certificate)
|
||||
not input.is_valid_client_certificate
|
||||
}
|
||||
|
||||
or_1 = v1 {
|
||||
or_2 = v1 {
|
||||
v1 := invalid_client_certificate_0
|
||||
v1
|
||||
}
|
||||
|
||||
deny = v1 {
|
||||
v1 := or_1
|
||||
v1 := or_2
|
||||
v1
|
||||
}
|
||||
|
||||
|
|
|
@ -29,13 +29,13 @@ type Handler struct {
|
|||
mu sync.RWMutex
|
||||
key []byte
|
||||
options *config.Options
|
||||
policies map[uint64]*config.Policy
|
||||
policies map[uint64]config.Policy
|
||||
}
|
||||
|
||||
// New creates a new Handler.
|
||||
func New() *Handler {
|
||||
h := new(Handler)
|
||||
h.policies = make(map[uint64]*config.Policy)
|
||||
h.policies = make(map[uint64]config.Policy)
|
||||
return h
|
||||
}
|
||||
|
||||
|
@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
|
|||
|
||||
h := stdhttputil.NewSingleHostReverseProxy(&dst)
|
||||
h.ErrorLog = stdlog.New(log.Logger(), "", 0)
|
||||
h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2)
|
||||
h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2)
|
||||
h.ServeHTTP(w, r)
|
||||
return nil
|
||||
})
|
||||
|
@ -133,14 +133,14 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
|
|||
|
||||
h.key, _ = cfg.Options.GetSharedKey()
|
||||
h.options = cfg.Options
|
||||
h.policies = make(map[uint64]*config.Policy)
|
||||
for i, p := range cfg.Options.Policies {
|
||||
h.policies = make(map[uint64]config.Policy)
|
||||
for _, p := range cfg.Options.GetAllPolicies() {
|
||||
id, err := p.RouteID()
|
||||
if err != nil {
|
||||
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
|
||||
continue
|
||||
}
|
||||
h.policies[id] = &cfg.Options.Policies[i]
|
||||
h.policies[id] = p
|
||||
}
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue