mirror of
https://github.com/pomerium/pomerium.git
synced 2025-06-06 21:04:39 +02:00
config: add support for embedded PPL policy (#2401)
This commit is contained in:
parent
c34118360d
commit
0620cfdc50
8 changed files with 152 additions and 9 deletions
|
@ -36,4 +36,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc(
|
||||||
DecodePolicyBase64Hook(),
|
DecodePolicyBase64Hook(),
|
||||||
decodeJWTClaimHeadersHookFunc(),
|
decodeJWTClaimHeadersHookFunc(),
|
||||||
decodeCodecTypeHookFunc(),
|
decodeCodecTypeHookFunc(),
|
||||||
|
decodePPLPolicyHookFunc(),
|
||||||
))
|
))
|
||||||
|
|
|
@ -17,6 +17,8 @@ import (
|
||||||
"google.golang.org/protobuf/proto"
|
"google.golang.org/protobuf/proto"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/internal/httputil"
|
"github.com/pomerium/pomerium/internal/httputil"
|
||||||
"github.com/pomerium/pomerium/internal/urlutil"
|
"github.com/pomerium/pomerium/internal/urlutil"
|
||||||
)
|
)
|
||||||
|
@ -313,6 +315,53 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) {
|
||||||
return str, wghts, nil
|
return str, wghts, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PPLPolicy is a policy defined using PPL.
|
||||||
|
type PPLPolicy struct {
|
||||||
|
*parser.Policy
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON parses JSON into a PPL policy.
|
||||||
|
func (ppl *PPLPolicy) UnmarshalJSON(data []byte) error {
|
||||||
|
var err error
|
||||||
|
ppl.Policy, err = parser.ParseJSON(bytes.NewReader(data))
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalYAML parses YAML into a PPL policy.
|
||||||
|
func (ppl *PPLPolicy) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||||
|
var i interface{}
|
||||||
|
err := unmarshal(&i)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
bs, err := json.Marshal(i)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return ppl.UnmarshalJSON(bs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodePPLPolicyHookFunc() mapstructure.DecodeHookFunc {
|
||||||
|
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||||
|
if t != reflect.TypeOf(&PPLPolicy{}) {
|
||||||
|
return data, nil
|
||||||
|
}
|
||||||
|
bs, err := json.Marshal(data)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var ppl PPLPolicy
|
||||||
|
err = json.Unmarshal(bs, &ppl)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &ppl, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
|
// DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data.
|
||||||
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc {
|
||||||
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
return func(f, t reflect.Type, data interface{}) (interface{}, error) {
|
||||||
|
|
|
@ -10,6 +10,8 @@ import (
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"gopkg.in/yaml.v3"
|
"gopkg.in/yaml.v3"
|
||||||
|
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestJWTClaimHeaders_UnmarshalJSON(t *testing.T) {
|
func TestJWTClaimHeaders_UnmarshalJSON(t *testing.T) {
|
||||||
|
@ -190,3 +192,39 @@ func TestWeightedStringSlice(t *testing.T) {
|
||||||
assert.Equal(t, tc.Weights, weights, name)
|
assert.Equal(t, tc.Weights, weights, name)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestDecodePPLPolicyHookFunc(t *testing.T) {
|
||||||
|
var withPolicy struct {
|
||||||
|
Policy *PPLPolicy `mapstructure:"policy"`
|
||||||
|
}
|
||||||
|
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||||
|
DecodeHook: decodePPLPolicyHookFunc(),
|
||||||
|
Result: &withPolicy,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = decoder.Decode(map[string]interface{}{
|
||||||
|
"policy": map[string]interface{}{
|
||||||
|
"allow": map[string]interface{}{
|
||||||
|
"or": []map[string]interface{}{
|
||||||
|
{"email": map[string]interface{}{
|
||||||
|
"is": "user1@example.com",
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, &PPLPolicy{
|
||||||
|
Policy: &parser.Policy{
|
||||||
|
Rules: []parser.Rule{{
|
||||||
|
Action: parser.ActionAllow,
|
||||||
|
Or: []parser.Criterion{{
|
||||||
|
Name: "email", Data: parser.Object{
|
||||||
|
"is": parser.String("user1@example.com"),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
}, withPolicy.Policy)
|
||||||
|
}
|
||||||
|
|
|
@ -114,6 +114,7 @@ type Options struct {
|
||||||
// Policies define per-route configuration and access control policies.
|
// Policies define per-route configuration and access control policies.
|
||||||
Policies []Policy `mapstructure:"policy"`
|
Policies []Policy `mapstructure:"policy"`
|
||||||
PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"`
|
PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"`
|
||||||
|
Routes []Policy `mapstructure:"routes"`
|
||||||
|
|
||||||
// AdditionalPolicies are any additional policies added to the options.
|
// AdditionalPolicies are any additional policies added to the options.
|
||||||
AdditionalPolicies []Policy `yaml:"-"`
|
AdditionalPolicies []Policy `yaml:"-"`
|
||||||
|
@ -428,6 +429,15 @@ func (o *Options) parsePolicy() error {
|
||||||
if len(policies) != 0 {
|
if len(policies) != 0 {
|
||||||
o.Policies = policies
|
o.Policies = policies
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var routes []Policy
|
||||||
|
if err := o.viper.UnmarshalKey("routes", &routes, ViperPolicyHooks); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if len(routes) != 0 {
|
||||||
|
o.Routes = routes
|
||||||
|
}
|
||||||
|
|
||||||
// Finish initializing policies
|
// Finish initializing policies
|
||||||
for i := range o.Policies {
|
for i := range o.Policies {
|
||||||
p := &o.Policies[i]
|
p := &o.Policies[i]
|
||||||
|
@ -435,6 +445,12 @@ func (o *Options) parsePolicy() error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
for i := range o.Routes {
|
||||||
|
p := &o.Routes[i]
|
||||||
|
if err := p.Validate(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
for i := range o.AdditionalPolicies {
|
for i := range o.AdditionalPolicies {
|
||||||
p := &o.AdditionalPolicies[i]
|
p := &o.AdditionalPolicies[i]
|
||||||
if err := p.Validate(); err != nil {
|
if err := p.Validate(); err != nil {
|
||||||
|
@ -861,8 +877,9 @@ func (o *Options) GetAllPolicies() []Policy {
|
||||||
if o == nil {
|
if o == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
policies := make([]Policy, 0, len(o.Policies)+len(o.AdditionalPolicies))
|
policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies))
|
||||||
policies = append(policies, o.Policies...)
|
policies = append(policies, o.Policies...)
|
||||||
|
policies = append(policies, o.Routes...)
|
||||||
policies = append(policies, o.AdditionalPolicies...)
|
policies = append(policies, o.AdditionalPolicies...)
|
||||||
return policies
|
return policies
|
||||||
}
|
}
|
||||||
|
|
|
@ -161,6 +161,8 @@ type Policy struct {
|
||||||
|
|
||||||
// SetResponseHeaders sets response headers.
|
// SetResponseHeaders sets response headers.
|
||||||
SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"`
|
SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"`
|
||||||
|
|
||||||
|
Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// RewriteHeader is a policy configuration option to rewrite an HTTP header.
|
// RewriteHeader is a policy configuration option to rewrite an HTTP header.
|
||||||
|
|
|
@ -99,5 +99,10 @@ func (p *Policy) ToPPL() *parser.Policy {
|
||||||
})
|
})
|
||||||
ppl.Rules = append(ppl.Rules, denyRule)
|
ppl.Rules = append(ppl.Rules, denyRule)
|
||||||
|
|
||||||
|
// append embedded PPL policy rules
|
||||||
|
if p.Policy != nil && p.Policy.Policy != nil {
|
||||||
|
ppl.Rules = append(ppl.Rules, p.Policy.Policy.Rules...)
|
||||||
|
}
|
||||||
|
|
||||||
return ppl
|
return ppl
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,6 +7,7 @@ import (
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/pomerium/pomerium/pkg/policy"
|
"github.com/pomerium/pomerium/pkg/policy"
|
||||||
|
"github.com/pomerium/pomerium/pkg/policy/parser"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPolicy_ToPPL(t *testing.T) {
|
func TestPolicy_ToPPL(t *testing.T) {
|
||||||
|
@ -38,6 +39,19 @@ func TestPolicy_ToPPL(t *testing.T) {
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
Policy: &PPLPolicy{
|
||||||
|
Policy: &parser.Policy{
|
||||||
|
Rules: []parser.Rule{{
|
||||||
|
Action: parser.ActionAllow,
|
||||||
|
Or: []parser.Criterion{{
|
||||||
|
Name: "user",
|
||||||
|
Data: parser.Object{
|
||||||
|
"is": parser.String("user6"),
|
||||||
|
},
|
||||||
|
}},
|
||||||
|
}},
|
||||||
|
},
|
||||||
|
},
|
||||||
}).ToPPL())
|
}).ToPPL())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.Equal(t, `package pomerium.policy
|
assert.Equal(t, `package pomerium.policy
|
||||||
|
@ -469,24 +483,41 @@ else = v28 {
|
||||||
v28
|
v28
|
||||||
}
|
}
|
||||||
|
|
||||||
|
users_5 {
|
||||||
|
session := get_session(input.session.id)
|
||||||
|
user := get_user(session)
|
||||||
|
user_id := user.id
|
||||||
|
user_id == "user6"
|
||||||
|
}
|
||||||
|
|
||||||
|
or_1 = v1 {
|
||||||
|
v1 := users_5
|
||||||
|
v1
|
||||||
|
}
|
||||||
|
|
||||||
allow = v1 {
|
allow = v1 {
|
||||||
v1 := or_0
|
v1 := or_0
|
||||||
v1
|
v1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
else = v2 {
|
||||||
|
v2 := or_1
|
||||||
|
v2
|
||||||
|
}
|
||||||
|
|
||||||
invalid_client_certificate_0 = reason {
|
invalid_client_certificate_0 = reason {
|
||||||
reason = [495, "invalid client certificate"]
|
reason = [495, "invalid client certificate"]
|
||||||
is_boolean(input.is_valid_client_certificate)
|
is_boolean(input.is_valid_client_certificate)
|
||||||
not input.is_valid_client_certificate
|
not input.is_valid_client_certificate
|
||||||
}
|
}
|
||||||
|
|
||||||
or_1 = v1 {
|
or_2 = v1 {
|
||||||
v1 := invalid_client_certificate_0
|
v1 := invalid_client_certificate_0
|
||||||
v1
|
v1
|
||||||
}
|
}
|
||||||
|
|
||||||
deny = v1 {
|
deny = v1 {
|
||||||
v1 := or_1
|
v1 := or_2
|
||||||
v1
|
v1
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -29,13 +29,13 @@ type Handler struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
key []byte
|
key []byte
|
||||||
options *config.Options
|
options *config.Options
|
||||||
policies map[uint64]*config.Policy
|
policies map[uint64]config.Policy
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates a new Handler.
|
// New creates a new Handler.
|
||||||
func New() *Handler {
|
func New() *Handler {
|
||||||
h := new(Handler)
|
h := new(Handler)
|
||||||
h.policies = make(map[uint64]*config.Policy)
|
h.policies = make(map[uint64]config.Policy)
|
||||||
return h
|
return h
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler {
|
||||||
|
|
||||||
h := stdhttputil.NewSingleHostReverseProxy(&dst)
|
h := stdhttputil.NewSingleHostReverseProxy(&dst)
|
||||||
h.ErrorLog = stdlog.New(log.Logger(), "", 0)
|
h.ErrorLog = stdlog.New(log.Logger(), "", 0)
|
||||||
h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2)
|
h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2)
|
||||||
h.ServeHTTP(w, r)
|
h.ServeHTTP(w, r)
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
|
@ -133,14 +133,14 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) {
|
||||||
|
|
||||||
h.key, _ = cfg.Options.GetSharedKey()
|
h.key, _ = cfg.Options.GetSharedKey()
|
||||||
h.options = cfg.Options
|
h.options = cfg.Options
|
||||||
h.policies = make(map[uint64]*config.Policy)
|
h.policies = make(map[uint64]config.Policy)
|
||||||
for i, p := range cfg.Options.Policies {
|
for _, p := range cfg.Options.GetAllPolicies() {
|
||||||
id, err := p.RouteID()
|
id, err := p.RouteID()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
|
log.Warn(ctx).Err(err).Msg("reproxy: error getting route id")
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
h.policies[id] = &cfg.Options.Policies[i]
|
h.policies[id] = p
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue