authorize: add filter options for JWT groups (#5417)

Add a new option for filtering to a subset of directory groups in the
Pomerium JWT and Impersonate-Group headers. Add a JWTGroupsFilter field
to both the Options struct (for a global filter) and to the Policy
struct (for per-route filter). These will be populated only from the
config protos, and not from a config file.

If either filter is set, then for each of a user's groups, the group
name or group ID will be added to the JWT groups claim only if it is an
exact string match with one of the elements of either filter.
This commit is contained in:
Kenneth Jenkins 2025-01-08 13:57:57 -08:00 committed by GitHub
parent 95d4a24271
commit 21b9e7890c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 834 additions and 620 deletions

View file

@ -134,6 +134,7 @@ func newPolicyEvaluator(
evaluator.WithAuthenticateURL(authenticateURL.String()),
evaluator.WithGoogleCloudServerlessAuthenticationServiceAccount(opts.GetGoogleCloudServerlessAuthenticationServiceAccount()),
evaluator.WithJWTClaimsHeaders(opts.JWTClaimsHeaders),
evaluator.WithJWTGroupsFilter(opts.JWTGroupsFilter),
)
}

View file

@ -15,6 +15,7 @@ type evaluatorConfig struct {
AuthenticateURL string
GoogleCloudServerlessAuthenticationServiceAccount string
JWTClaimsHeaders config.JWTClaimHeaders
JWTGroupsFilter config.JWTGroupsFilter
}
// cacheKey() returns a hash over the configuration, except for the policies.
@ -97,3 +98,10 @@ func WithJWTClaimsHeaders(headers config.JWTClaimHeaders) Option {
cfg.JWTClaimsHeaders = headers
}
}
// WithJWTGroupsFilter sets the JWT groups filter in the config.
func WithJWTGroupsFilter(groups config.JWTGroupsFilter) Option {
return func(cfg *evaluatorConfig) {
cfg.JWTGroupsFilter = groups
}
}

View file

@ -327,6 +327,7 @@ func updateStore(ctx context.Context, store *store.Store, cfg *evaluatorConfig)
cfg.GoogleCloudServerlessAuthenticationServiceAccount,
)
store.UpdateJWTClaimHeaders(cfg.JWTClaimsHeaders)
store.UpdateJWTGroupsFilter(cfg.JWTGroupsFilter)
store.UpdateRoutePolicies(cfg.Policies)
store.UpdateSigningKey(jwk)

View file

@ -708,6 +708,9 @@ func TestPolicyEvaluatorReuse(t *testing.T) {
t.Run("JWTClaimsHeaders changed", func(t *testing.T) {
assertNoneReused(t, WithJWTClaimsHeaders(config.JWTClaimHeaders{"dummy": "header"}))
})
t.Run("JWTGroupsFilter changed", func(t *testing.T) {
assertNoneReused(t, WithJWTGroupsFilter(config.NewJWTGroupsFilter([]string{"group1", "group2"})))
})
// If some policies have changed, but the evaluatorConfig is otherwise
// identical, only evaluators for the changed policies should be updated.

View file

@ -18,10 +18,12 @@ import (
"google.golang.org/protobuf/types/known/structpb"
"github.com/pomerium/datasource/pkg/directory"
"github.com/pomerium/pomerium/config"
"github.com/pomerium/pomerium/internal/log"
"github.com/pomerium/pomerium/pkg/cryptutil"
"github.com/pomerium/pomerium/pkg/grpc/session"
"github.com/pomerium/pomerium/pkg/grpc/user"
"github.com/pomerium/pomerium/pkg/slices"
)
// A headersEvaluatorEvaluation is a single evaluation of the headers evaluator.
@ -310,6 +312,44 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) str
}
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string {
groups := e.getFilteredGroups(ctx)
if groups == nil {
// If there are no groups, marshal this claim as an empty list rather than a JSON null,
// for better compatibility with third-party libraries.
// See https://github.com/pomerium/pomerium/issues/5393 for one example.
groups = []string{}
}
return groups
}
func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []string {
groups := e.getAllGroups(ctx)
// Apply the global groups filter or the per-route groups filter, if either is enabled.
filters := make([]config.JWTGroupsFilter, 0, 2)
if f := e.evaluator.store.GetJWTGroupsFilter(); f.Enabled() {
filters = append(filters, f)
}
if e.request.Policy != nil && e.request.Policy.JWTGroupsFilter.Enabled() {
filters = append(filters, e.request.Policy.JWTGroupsFilter)
}
if len(filters) == 0 {
return groups
}
return slices.Filter(groups, func(g string) bool {
// A group should be included if it appears in either the global or the route-level filter list.
for _, f := range filters {
if f.IsAllowed(g) {
return true
}
}
return false
})
}
// getAllGroups returns the full group names/IDs list (without any filtering).
func (e *headersEvaluatorEvaluation) getAllGroups(ctx context.Context) []string {
groupIDs := e.getGroupIDs(ctx)
if len(groupIDs) > 0 {
groups := make([]string, 0, len(groupIDs)*2)
@ -320,12 +360,6 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []
s, _ := e.getSessionOrServiceAccount(ctx)
groups, _ := getClaimStringSlice(s, "groups")
if groups == nil {
// If there are no groups, marshal this claim as an empty list rather than a JSON null,
// for better compatibility with third-party libraries.
// See https://github.com/pomerium/pomerium/issues/5393 for one example.
groups = []string{}
}
return groups
}

View file

@ -13,6 +13,7 @@ import (
"time"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
"github.com/go-jose/go-jose/v3"
"github.com/go-jose/go-jose/v3/jwt"
"github.com/open-policy-agent/opa/rego"
"github.com/stretchr/testify/assert"
@ -35,12 +36,8 @@ import (
func BenchmarkHeadersEvaluator(b *testing.B) {
ctx := context.Background()
signingKey, err := cryptutil.NewSigningKey()
require.NoError(b, err)
encodedSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
require.NoError(b, err)
privateJWK, err := cryptutil.PrivateJWKFromBytes(encodedSigningKey)
require.NoError(b, err)
privateJWK, _ := newJWK(b)
iat := time.Unix(1686870680, 0)
ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier([]proto.Message{
@ -96,14 +93,7 @@ func TestHeadersEvaluator(t *testing.T) {
type A = []any
type M = map[string]any
signingKey, err := cryptutil.NewSigningKey()
require.NoError(t, err)
encodedSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
require.NoError(t, err)
privateJWK, err := cryptutil.PrivateJWKFromBytes(encodedSigningKey)
require.NoError(t, err)
publicJWK, err := cryptutil.PublicJWKFromBytes(encodedSigningKey)
require.NoError(t, err)
privateJWK, publicJWK := newJWK(t)
iat := time.Unix(1686870680, 0)
@ -476,6 +466,86 @@ func TestHeadersEvaluator(t *testing.T) {
})
}
func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) {
t.Parallel()
privateJWK, _ := newJWK(t)
// Create some user and groups data.
var records []proto.Message
groupsCount := 50
for i := 1; i <= groupsCount; i++ {
id := fmt.Sprint(i)
records = append(records, newDirectoryGroupRecord(directory.Group{ID: id, Name: "GROUP-" + id}))
}
for i := 1; i <= 10; i++ {
id := fmt.Sprintf("USER-%d", i)
// User 1 will be in every group, user 2 in every other group, user 3 in every third group, etc.
var groups []string
for j := i; j <= groupsCount; j += i {
groups = append(groups, fmt.Sprint(j))
}
records = append(records,
&session.Session{Id: fmt.Sprintf("SESSION-%d", i), UserId: id},
newDirectoryUserRecord(directory.User{ID: id, GroupIDs: groups}),
)
}
cases := []struct {
name string
globalFilter []string
routeFilter []string
sessionID string
expected []any
}{
{"global filter 1", []string{"42", "1", "GROUP-12"}, nil, "SESSION-1", []any{"1", "42", "GROUP-12"}},
{"global filter 2", []string{"42", "1", "GROUP-12"}, nil, "SESSION-2", []any{"42", "GROUP-12"}},
{"route filter 1", nil, []string{"42", "1", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}},
{"route filter 2", nil, []string{"42", "1", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}},
{"both filters 1", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}},
{"both filters 2", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}},
{"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1"}},
{"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3"}},
{
"no filtering", nil, nil, "SESSION-10",
[]any{"10", "20", "30", "40", "50", "GROUP-10", "GROUP-20", "GROUP-30", "GROUP-40", "GROUP-50"},
},
}
ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(records...))
for _, c := range cases {
t.Run(c.name, func(t *testing.T) {
store := store.New()
store.UpdateSigningKey(privateJWK)
store.UpdateJWTGroupsFilter(config.NewJWTGroupsFilter(c.globalFilter))
req := &Request{Session: RequestSession{ID: c.sessionID}}
if c.routeFilter != nil {
req.Policy = &config.Policy{
JWTGroupsFilter: config.NewJWTGroupsFilter(c.routeFilter),
}
}
e := NewHeadersEvaluator(store)
resp, err := e.Evaluate(ctx, req)
require.NoError(t, err)
decoded := decodeJWTAssertion(t, resp.Headers)
assert.Equal(t, c.expected, decoded["groups"])
})
}
}
func newJWK(t testing.TB) (privateJWK, publicJWK *jose.JSONWebKey) {
t.Helper()
signingKey, err := cryptutil.NewSigningKey()
require.NoError(t, err)
encodedSigningKey, err := cryptutil.EncodePrivateKey(signingKey)
require.NoError(t, err)
privateJWK, err = cryptutil.PrivateJWKFromBytes(encodedSigningKey)
require.NoError(t, err)
publicJWK, err = cryptutil.PublicJWKFromBytes(encodedSigningKey)
require.NoError(t, err)
return
}
func decodeJWTAssertion(t *testing.T, headers http.Header) map[string]any {
jwtHeader := headers.Get("X-Pomerium-Jwt-Assertion")
// Make sure the 'iat' and 'exp' claims can be parsed as an integer. We

View file

@ -32,6 +32,7 @@ type Store struct {
googleCloudServerlessAuthenticationServiceAccount atomic.Pointer[string]
jwtClaimHeaders atomic.Pointer[map[string]string]
jwtGroupsFilter atomic.Pointer[config.JWTGroupsFilter]
signingKey atomic.Pointer[jose.JSONWebKey]
}
@ -58,6 +59,13 @@ func (s *Store) GetJWTClaimHeaders() map[string]string {
return *m
}
func (s *Store) GetJWTGroupsFilter() config.JWTGroupsFilter {
if f := s.jwtGroupsFilter.Load(); f != nil {
return *f
}
return config.JWTGroupsFilter{}
}
func (s *Store) GetSigningKey() *jose.JSONWebKey {
return s.signingKey.Load()
}
@ -75,6 +83,12 @@ func (s *Store) UpdateJWTClaimHeaders(jwtClaimHeaders map[string]string) {
s.jwtClaimHeaders.Store(&jwtClaimHeaders)
}
// UpdateJWTGroupsFilter updates the JWT groups filter in the store.
func (s *Store) UpdateJWTGroupsFilter(groups config.JWTGroupsFilter) {
// This isn't used by the Rego code, so we don't need to write it to the opastorage.Store instance.
s.jwtGroupsFilter.Store(&groups)
}
// UpdateRoutePolicies updates the route policies in the store.
func (s *Store) UpdateRoutePolicies(routePolicies []*config.Policy) {
s.write("/route_policies", routePolicies)

View file

@ -7,17 +7,20 @@ import (
"fmt"
"net/url"
"reflect"
"slices"
"strconv"
"strings"
"unicode"
envoy_config_cluster_v3 "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3"
goset "github.com/hashicorp/go-set/v3"
"github.com/mitchellh/mapstructure"
"github.com/volatiletech/null/v9"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"gopkg.in/yaml.v3"
"github.com/pomerium/pomerium/internal/hashutil"
"github.com/pomerium/pomerium/internal/httputil"
"github.com/pomerium/pomerium/internal/urlutil"
"github.com/pomerium/pomerium/pkg/policy/parser"
@ -574,3 +577,43 @@ func serializable(in any) (any, error) {
return in, nil
}
}
type JWTGroupsFilter struct {
set *goset.Set[string]
}
func NewJWTGroupsFilter(groups []string) JWTGroupsFilter {
var s *goset.Set[string]
if len(groups) > 0 {
s = goset.From(groups)
}
return JWTGroupsFilter{s}
}
func (f JWTGroupsFilter) Enabled() bool {
return f.set != nil
}
func (f JWTGroupsFilter) IsAllowed(group string) bool {
return f.set == nil || f.set.Contains(group)
}
func (f JWTGroupsFilter) ToSlice() []string {
if f.set == nil {
return nil
}
return slices.Sorted(f.set.Items())
}
func (f JWTGroupsFilter) Hash() (uint64, error) {
return hashutil.Hash(f.ToSlice())
}
func (f JWTGroupsFilter) Equal(other JWTGroupsFilter) bool {
if f.set == nil && other.set == nil {
return true
} else if f.set == nil || other.set == nil {
return false
}
return f.set.Equal(other.set)
}

View file

@ -193,6 +193,9 @@ type Options struct {
// List of JWT claims to insert as x-pomerium-claim-* headers on proxied requests
JWTClaimsHeaders JWTClaimHeaders `mapstructure:"jwt_claims_headers" yaml:"jwt_claims_headers,omitempty"`
// Allowlist of group names/IDs to include in the Pomerium JWT.
JWTGroupsFilter JWTGroupsFilter
DefaultUpstreamTimeout time.Duration `mapstructure:"default_upstream_timeout" yaml:"default_upstream_timeout,omitempty"`
// Address/Port to bind to for prometheus metrics
@ -1510,6 +1513,7 @@ func (o *Options) ApplySettings(ctx context.Context, certsIndex *cryptutil.Certi
set(&o.SigningKey, settings.SigningKey)
setMap(&o.SetResponseHeaders, settings.SetResponseHeaders)
setMap(&o.JWTClaimsHeaders, settings.JwtClaimsHeaders)
o.JWTGroupsFilter = NewJWTGroupsFilter(settings.JwtGroupsFilter)
setDuration(&o.DefaultUpstreamTimeout, settings.DefaultUpstreamTimeout)
set(&o.MetricsAddr, settings.MetricsAddress)
set(&o.MetricsBasicAuth, settings.MetricsBasicAuth)
@ -1599,6 +1603,7 @@ func (o *Options) ToProto() *config.Config {
copySrcToOptionalDest(&settings.SigningKey, valueOrFromFileBase64(o.SigningKey, o.SigningKeyFile))
settings.SetResponseHeaders = o.SetResponseHeaders
settings.JwtClaimsHeaders = o.JWTClaimsHeaders
settings.JwtGroupsFilter = o.JWTGroupsFilter.ToSlice()
copyOptionalDuration(&settings.DefaultUpstreamTimeout, o.DefaultUpstreamTimeout)
copySrcToOptionalDest(&settings.MetricsAddress, &o.MetricsAddr)
copySrcToOptionalDest(&settings.MetricsBasicAuth, &o.MetricsBasicAuth)

View file

@ -1436,6 +1436,9 @@ func TestRoute_FromToProto(t *testing.T) {
pb.From = "https://" + randomDomain()
// EnvoyOpts is set to an empty non-nil message during conversion, if nil
pb.EnvoyOpts = &envoy_config_cluster_v3.Cluster{}
// JWT groups filter order is not significant. Upon conversion back to
// a protobuf the JWT groups will be sorted.
slices.Sort(pb.JwtGroupsFilter)
switch mathrand.IntN(3) {
case 0:
@ -1536,6 +1539,9 @@ func TestOptions_FromToProto(t *testing.T) {
unsetFalseOptionalBoolFields(settings)
fixZeroValuedEnums(settings)
generateCertificates(t, settings)
// JWT groups filter order is not significant. Upon conversion back to
// a protobuf the JWT groups will be sorted.
slices.Sort(settings.JwtGroupsFilter)
return settings
}

View file

@ -165,6 +165,10 @@ type Policy struct {
// - "uri": Issuer strings will be a complete URI, including the scheme and ending with a trailing slash.
JWTIssuerFormat string `mapstructure:"jwt_issuer_format" yaml:"jwt_issuer_format,omitempty"`
// Allowlist of group names/IDs to include in the Pomerium JWT.
// This expands on any global allowlist set in the main Options.
JWTGroupsFilter JWTGroupsFilter
SubPolicies []SubPolicy `mapstructure:"sub_policies" yaml:"sub_policies,omitempty" json:"sub_policies,omitempty"`
EnvoyOpts *envoy_config_cluster_v3.Cluster `mapstructure:"_envoy_opts" yaml:"-" json:"-"`
@ -290,6 +294,7 @@ func NewPolicyFromProto(pb *configpb.Route) (*Policy, error) {
IdleTimeout: idleTimeout,
IDPClientID: pb.GetIdpClientId(),
IDPClientSecret: pb.GetIdpClientSecret(),
JWTGroupsFilter: NewJWTGroupsFilter(pb.JwtGroupsFilter),
KubernetesServiceAccountToken: pb.GetKubernetesServiceAccountToken(),
KubernetesServiceAccountTokenFile: pb.GetKubernetesServiceAccountTokenFile(),
PassIdentityHeaders: pb.PassIdentityHeaders,
@ -432,6 +437,7 @@ func (p *Policy) ToProto() (*configpb.Route, error) {
From: p.From,
Id: p.ID,
IdleTimeout: idleTimeout,
JwtGroupsFilter: p.JWTGroupsFilter.ToSlice(),
KubernetesServiceAccountToken: p.KubernetesServiceAccountToken,
KubernetesServiceAccountTokenFile: p.KubernetesServiceAccountTokenFile,
Name: fmt.Sprint(p.RouteID()),

File diff suppressed because it is too large Load diff

View file

@ -45,7 +45,7 @@ enum IssuerFormat {
IssuerURI = 1;
}
// Next ID: 66.
// Next ID: 67.
message Route {
string name = 1;
@ -113,6 +113,7 @@ message Route {
string kubernetes_service_account_token_file = 64;
bool enable_google_cloud_serverless_authentication = 42;
IssuerFormat jwt_issuer_format = 65;
repeated string jwt_groups_filter = 66;
envoy.config.cluster.v3.Cluster envoy_opts = 36;
@ -145,7 +146,7 @@ message Policy {
string remediation = 9;
}
// Next ID: 119.
// Next ID: 120.
message Settings {
message Certificate {
bytes cert_bytes = 3;
@ -198,6 +199,7 @@ message Settings {
map<string, string> set_response_headers = 69;
// repeated string jwt_claims_headers = 37;
map<string, string> jwt_claims_headers = 63;
repeated string jwt_groups_filter = 119;
optional google.protobuf.Duration default_upstream_timeout = 39;
optional string metrics_address = 40;
optional string metrics_basic_auth = 64;