diff --git a/config/constants.go b/config/constants.go index 2b7293e78..06e6202ee 100644 --- a/config/constants.go +++ b/config/constants.go @@ -36,4 +36,5 @@ var ViperPolicyHooks = viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( DecodePolicyBase64Hook(), decodeJWTClaimHeadersHookFunc(), decodeCodecTypeHookFunc(), + decodePPLPolicyHookFunc(), )) diff --git a/config/custom.go b/config/custom.go index 788c68f4f..204adc61f 100644 --- a/config/custom.go +++ b/config/custom.go @@ -17,6 +17,8 @@ import ( "google.golang.org/protobuf/proto" "gopkg.in/yaml.v3" + "github.com/pomerium/pomerium/pkg/policy/parser" + "github.com/pomerium/pomerium/internal/httputil" "github.com/pomerium/pomerium/internal/urlutil" ) @@ -313,6 +315,53 @@ func (urls WeightedURLs) Flatten() ([]string, []uint32, error) { return str, wghts, nil } +// PPLPolicy is a policy defined using PPL. +type PPLPolicy struct { + *parser.Policy +} + +// UnmarshalJSON parses JSON into a PPL policy. +func (ppl *PPLPolicy) UnmarshalJSON(data []byte) error { + var err error + ppl.Policy, err = parser.ParseJSON(bytes.NewReader(data)) + if err != nil { + return err + } + return nil +} + +// UnmarshalYAML parses YAML into a PPL policy. +func (ppl *PPLPolicy) UnmarshalYAML(unmarshal func(interface{}) error) error { + var i interface{} + err := unmarshal(&i) + if err != nil { + return err + } + bs, err := json.Marshal(i) + if err != nil { + return err + } + return ppl.UnmarshalJSON(bs) +} + +func decodePPLPolicyHookFunc() mapstructure.DecodeHookFunc { + return func(f, t reflect.Type, data interface{}) (interface{}, error) { + if t != reflect.TypeOf(&PPLPolicy{}) { + return data, nil + } + bs, err := json.Marshal(data) + if err != nil { + return nil, err + } + var ppl PPLPolicy + err = json.Unmarshal(bs, &ppl) + if err != nil { + return nil, err + } + return &ppl, nil + } +} + // DecodePolicyBase64Hook returns a mapstructure decode hook for base64 data. func DecodePolicyBase64Hook() mapstructure.DecodeHookFunc { return func(f, t reflect.Type, data interface{}) (interface{}, error) { diff --git a/config/custom_test.go b/config/custom_test.go index 6b5767693..a85af5cd0 100644 --- a/config/custom_test.go +++ b/config/custom_test.go @@ -10,6 +10,8 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gopkg.in/yaml.v3" + + "github.com/pomerium/pomerium/pkg/policy/parser" ) func TestJWTClaimHeaders_UnmarshalJSON(t *testing.T) { @@ -190,3 +192,39 @@ func TestWeightedStringSlice(t *testing.T) { assert.Equal(t, tc.Weights, weights, name) } } + +func TestDecodePPLPolicyHookFunc(t *testing.T) { + var withPolicy struct { + Policy *PPLPolicy `mapstructure:"policy"` + } + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: decodePPLPolicyHookFunc(), + Result: &withPolicy, + }) + require.NoError(t, err) + + err = decoder.Decode(map[string]interface{}{ + "policy": map[string]interface{}{ + "allow": map[string]interface{}{ + "or": []map[string]interface{}{ + {"email": map[string]interface{}{ + "is": "user1@example.com", + }}, + }, + }, + }, + }) + assert.NoError(t, err) + assert.Equal(t, &PPLPolicy{ + Policy: &parser.Policy{ + Rules: []parser.Rule{{ + Action: parser.ActionAllow, + Or: []parser.Criterion{{ + Name: "email", Data: parser.Object{ + "is": parser.String("user1@example.com"), + }, + }}, + }}, + }, + }, withPolicy.Policy) +} diff --git a/config/options.go b/config/options.go index 55fdb8aee..5d4c6f030 100644 --- a/config/options.go +++ b/config/options.go @@ -114,6 +114,7 @@ type Options struct { // Policies define per-route configuration and access control policies. Policies []Policy `mapstructure:"policy"` PolicyFile string `mapstructure:"policy_file" yaml:"policy_file,omitempty"` + Routes []Policy `mapstructure:"routes"` // AdditionalPolicies are any additional policies added to the options. AdditionalPolicies []Policy `yaml:"-"` @@ -428,6 +429,15 @@ func (o *Options) parsePolicy() error { if len(policies) != 0 { o.Policies = policies } + + var routes []Policy + if err := o.viper.UnmarshalKey("routes", &routes, ViperPolicyHooks); err != nil { + return err + } + if len(routes) != 0 { + o.Routes = routes + } + // Finish initializing policies for i := range o.Policies { p := &o.Policies[i] @@ -435,6 +445,12 @@ func (o *Options) parsePolicy() error { return err } } + for i := range o.Routes { + p := &o.Routes[i] + if err := p.Validate(); err != nil { + return err + } + } for i := range o.AdditionalPolicies { p := &o.AdditionalPolicies[i] if err := p.Validate(); err != nil { @@ -861,8 +877,9 @@ func (o *Options) GetAllPolicies() []Policy { if o == nil { return nil } - policies := make([]Policy, 0, len(o.Policies)+len(o.AdditionalPolicies)) + policies := make([]Policy, 0, len(o.Policies)+len(o.Routes)+len(o.AdditionalPolicies)) policies = append(policies, o.Policies...) + policies = append(policies, o.Routes...) policies = append(policies, o.AdditionalPolicies...) return policies } diff --git a/config/policy.go b/config/policy.go index 6f0794525..7c2998e3d 100644 --- a/config/policy.go +++ b/config/policy.go @@ -161,6 +161,8 @@ type Policy struct { // SetResponseHeaders sets response headers. SetResponseHeaders map[string]string `mapstructure:"set_response_headers" yaml:"set_response_headers,omitempty"` + + Policy *PPLPolicy `mapstructure:"policy" yaml:"policy,omitempty" json:"policy,omitempty"` } // RewriteHeader is a policy configuration option to rewrite an HTTP header. diff --git a/config/policy_ppl.go b/config/policy_ppl.go index dbb49c6ad..0ffb4607e 100644 --- a/config/policy_ppl.go +++ b/config/policy_ppl.go @@ -99,5 +99,10 @@ func (p *Policy) ToPPL() *parser.Policy { }) ppl.Rules = append(ppl.Rules, denyRule) + // append embedded PPL policy rules + if p.Policy != nil && p.Policy.Policy != nil { + ppl.Rules = append(ppl.Rules, p.Policy.Policy.Rules...) + } + return ppl } diff --git a/config/policy_ppl_test.go b/config/policy_ppl_test.go index 2a76a9f26..915254aaa 100644 --- a/config/policy_ppl_test.go +++ b/config/policy_ppl_test.go @@ -7,6 +7,7 @@ import ( "github.com/stretchr/testify/require" "github.com/pomerium/pomerium/pkg/policy" + "github.com/pomerium/pomerium/pkg/policy/parser" ) func TestPolicy_ToPPL(t *testing.T) { @@ -38,6 +39,19 @@ func TestPolicy_ToPPL(t *testing.T) { }, }, }, + Policy: &PPLPolicy{ + Policy: &parser.Policy{ + Rules: []parser.Rule{{ + Action: parser.ActionAllow, + Or: []parser.Criterion{{ + Name: "user", + Data: parser.Object{ + "is": parser.String("user6"), + }, + }}, + }}, + }, + }, }).ToPPL()) require.NoError(t, err) assert.Equal(t, `package pomerium.policy @@ -469,24 +483,41 @@ else = v28 { v28 } +users_5 { + session := get_session(input.session.id) + user := get_user(session) + user_id := user.id + user_id == "user6" +} + +or_1 = v1 { + v1 := users_5 + v1 +} + allow = v1 { v1 := or_0 v1 } +else = v2 { + v2 := or_1 + v2 +} + invalid_client_certificate_0 = reason { reason = [495, "invalid client certificate"] is_boolean(input.is_valid_client_certificate) not input.is_valid_client_certificate } -or_1 = v1 { +or_2 = v1 { v1 := invalid_client_certificate_0 v1 } deny = v1 { - v1 := or_1 + v1 := or_2 v1 } diff --git a/internal/httputil/reproxy/reproxy.go b/internal/httputil/reproxy/reproxy.go index 2e8ba8401..965bfd537 100644 --- a/internal/httputil/reproxy/reproxy.go +++ b/internal/httputil/reproxy/reproxy.go @@ -29,13 +29,13 @@ type Handler struct { mu sync.RWMutex key []byte options *config.Options - policies map[uint64]*config.Policy + policies map[uint64]config.Policy } // New creates a new Handler. func New() *Handler { h := new(Handler) - h.policies = make(map[uint64]*config.Policy) + h.policies = make(map[uint64]config.Policy) return h } @@ -120,7 +120,7 @@ func (h *Handler) Middleware(next http.Handler) http.Handler { h := stdhttputil.NewSingleHostReverseProxy(&dst) h.ErrorLog = stdlog.New(log.Logger(), "", 0) - h.Transport = config.NewPolicyHTTPTransport(options, policy, disableHTTP2) + h.Transport = config.NewPolicyHTTPTransport(options, &policy, disableHTTP2) h.ServeHTTP(w, r) return nil }) @@ -133,14 +133,14 @@ func (h *Handler) Update(ctx context.Context, cfg *config.Config) { h.key, _ = cfg.Options.GetSharedKey() h.options = cfg.Options - h.policies = make(map[uint64]*config.Policy) - for i, p := range cfg.Options.Policies { + h.policies = make(map[uint64]config.Policy) + for _, p := range cfg.Options.GetAllPolicies() { id, err := p.RouteID() if err != nil { log.Warn(ctx).Err(err).Msg("reproxy: error getting route id") continue } - h.policies[id] = &cfg.Options.Policies[i] + h.policies[id] = p } }