authorize: filter only by group ID (#5452)

authorize: filter only by group ID (#5437)

Change the JWT groups filtering behavior:
- to filter only by group ID (not group name)
- and only for groups sourced from directory sync (groups from a 
  "groups" claim will not be filtered)

This avoids the need to fetch all group names up front, which should 
improve performance in specific circumstances.

Co-authored-by: Kenneth Jenkins <51246568+kenjenkins@users.noreply.github.com>
This commit is contained in:
backport-actions-token[bot] 2025-01-28 12:16:45 -08:00 committed by GitHub
parent 69d9bc0145
commit 1815dea9f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 38 additions and 27 deletions

View file

@ -272,8 +272,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) str
} }
func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string { func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string {
groups := e.getFilteredGroups(ctx) groups := e.getGroups(ctx)
if groups == nil { if groups == nil {
// If there are no groups, marshal this claim as an empty list rather than a JSON null, // If there are no groups, marshal this claim as an empty list rather than a JSON null,
// for better compatibility with third-party libraries. // for better compatibility with third-party libraries.
@ -283,9 +282,22 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []
return groups return groups
} }
func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []string { func (e *headersEvaluatorEvaluation) getGroups(ctx context.Context) []string {
groups := e.getAllGroups(ctx) groupIDs := e.getGroupIDs(ctx)
if len(groupIDs) > 0 {
groupIDs = e.filterGroups(groupIDs)
groups := make([]string, 0, len(groupIDs)*2)
groups = append(groups, groupIDs...)
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...)
return groups
}
s, _ := e.getSessionOrServiceAccount(ctx)
groups, _ := getClaimStringSlice(s, "groups")
return groups
}
func (e *headersEvaluatorEvaluation) filterGroups(groups []string) []string {
// Apply the global groups filter or the per-route groups filter, if either is enabled. // Apply the global groups filter or the per-route groups filter, if either is enabled.
filters := make([]config.JWTGroupsFilter, 0, 2) filters := make([]config.JWTGroupsFilter, 0, 2)
if f := e.evaluator.store.GetJWTGroupsFilter(); f.Enabled() { if f := e.evaluator.store.GetJWTGroupsFilter(); f.Enabled() {
@ -308,21 +320,6 @@ func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []st
}) })
} }
// getAllGroups returns the full group names/IDs list (without any filtering).
func (e *headersEvaluatorEvaluation) getAllGroups(ctx context.Context) []string {
groupIDs := e.getGroupIDs(ctx)
if len(groupIDs) > 0 {
groups := make([]string, 0, len(groupIDs)*2)
groups = append(groups, groupIDs...)
groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...)
return groups
}
s, _ := e.getSessionOrServiceAccount(ctx)
groups, _ := getClaimStringSlice(s, "groups")
return groups
}
func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string { func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string {
return e.request.Session.ID return e.request.Session.ID
} }

View file

@ -551,6 +551,12 @@ func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) {
newDirectoryUserRecord(directory.User{ID: id, GroupIDs: groups}), newDirectoryUserRecord(directory.User{ID: id, GroupIDs: groups}),
) )
} }
// Also add a user session with an upstream "groups" claim from the IdP.
records = append(records,
&session.Session{Id: "SESSION-11", UserId: "USER-11", Claims: map[string]*structpb.ListValue{
"groups": newList("foo", "bar", "baz"),
}},
)
cases := []struct { cases := []struct {
name string name string
@ -559,18 +565,21 @@ func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) {
sessionID string sessionID string
expected []any expected []any
}{ }{
{"global filter 1", []string{"42", "1", "GROUP-12"}, nil, "SESSION-1", []any{"1", "42", "GROUP-12"}}, {"global filter 1", []string{"42", "1"}, nil, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}},
{"global filter 2", []string{"42", "1", "GROUP-12"}, nil, "SESSION-2", []any{"42", "GROUP-12"}}, {"global filter 2", []string{"42", "1"}, nil, "SESSION-2", []any{"42", "GROUP-42"}},
{"route filter 1", nil, []string{"42", "1", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}}, {"route filter 1", nil, []string{"42", "1"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}},
{"route filter 2", nil, []string{"42", "1", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}}, {"route filter 2", nil, []string{"42", "1"}, "SESSION-2", []any{"42", "GROUP-42"}},
{"both filters 1", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}}, {"both filters 1", []string{"1"}, []string{"42"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}},
{"both filters 2", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}}, {"both filters 2", []string{"1"}, []string{"42"}, "SESSION-2", []any{"42", "GROUP-42"}},
{"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1"}}, {"cannot filter by name", []string{"GROUP-1"}, nil, "SESSION-1", []any{}},
{"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3"}}, {"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1", "GROUP-1"}},
{"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3", "GROUP-1", "GROUP-2", "GROUP-3"}},
{ {
"no filtering", nil, nil, "SESSION-10", "no filtering", nil, nil, "SESSION-10",
[]any{"10", "20", "30", "40", "50", "GROUP-10", "GROUP-20", "GROUP-30", "GROUP-40", "GROUP-50"}, []any{"10", "20", "30", "40", "50", "GROUP-10", "GROUP-20", "GROUP-30", "GROUP-40", "GROUP-50"},
}, },
// filtering has no effect on groups from an IdP "groups" claim
{"groups claim", []string{"foo", "quux"}, nil, "SESSION-11", []any{"foo", "bar", "baz"}},
} }
ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(records...)) ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(records...))
@ -647,3 +656,8 @@ func newDirectoryUserRecord(directoryUser directory.User) *databroker.Record {
s, _ := structpb.NewStruct(m) s, _ := structpb.NewStruct(m)
return storage.NewStaticRecord(directory.UserRecordType, s) return storage.NewStaticRecord(directory.UserRecordType, s)
} }
func newList(v ...any) *structpb.ListValue {
lv, _ := structpb.NewList(v)
return lv
}