diff --git a/authorize/evaluator/headers_evaluator_evaluation.go b/authorize/evaluator/headers_evaluator_evaluation.go index 8f099f88f..dfb531af7 100644 --- a/authorize/evaluator/headers_evaluator_evaluation.go +++ b/authorize/evaluator/headers_evaluator_evaluation.go @@ -272,8 +272,7 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadEmail(ctx context.Context) str } func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) []string { - groups := e.getFilteredGroups(ctx) - + groups := e.getGroups(ctx) if groups == nil { // If there are no groups, marshal this claim as an empty list rather than a JSON null, // for better compatibility with third-party libraries. @@ -283,9 +282,22 @@ func (e *headersEvaluatorEvaluation) getJWTPayloadGroups(ctx context.Context) [] return groups } -func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []string { - groups := e.getAllGroups(ctx) +func (e *headersEvaluatorEvaluation) getGroups(ctx context.Context) []string { + groupIDs := e.getGroupIDs(ctx) + if len(groupIDs) > 0 { + groupIDs = e.filterGroups(groupIDs) + groups := make([]string, 0, len(groupIDs)*2) + groups = append(groups, groupIDs...) + groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...) + return groups + } + s, _ := e.getSessionOrServiceAccount(ctx) + groups, _ := getClaimStringSlice(s, "groups") + return groups +} + +func (e *headersEvaluatorEvaluation) filterGroups(groups []string) []string { // Apply the global groups filter or the per-route groups filter, if either is enabled. filters := make([]config.JWTGroupsFilter, 0, 2) if f := e.evaluator.store.GetJWTGroupsFilter(); f.Enabled() { @@ -308,21 +320,6 @@ func (e *headersEvaluatorEvaluation) getFilteredGroups(ctx context.Context) []st }) } -// getAllGroups returns the full group names/IDs list (without any filtering). -func (e *headersEvaluatorEvaluation) getAllGroups(ctx context.Context) []string { - groupIDs := e.getGroupIDs(ctx) - if len(groupIDs) > 0 { - groups := make([]string, 0, len(groupIDs)*2) - groups = append(groups, groupIDs...) - groups = append(groups, e.getDataBrokerGroupNames(ctx, groupIDs)...) - return groups - } - - s, _ := e.getSessionOrServiceAccount(ctx) - groups, _ := getClaimStringSlice(s, "groups") - return groups -} - func (e *headersEvaluatorEvaluation) getJWTPayloadSID() string { return e.request.Session.ID } diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 979111d1b..5108b34d3 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -551,6 +551,12 @@ func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) { newDirectoryUserRecord(directory.User{ID: id, GroupIDs: groups}), ) } + // Also add a user session with an upstream "groups" claim from the IdP. + records = append(records, + &session.Session{Id: "SESSION-11", UserId: "USER-11", Claims: map[string]*structpb.ListValue{ + "groups": newList("foo", "bar", "baz"), + }}, + ) cases := []struct { name string @@ -559,18 +565,21 @@ func TestHeadersEvaluator_JWTGroupsFilter(t *testing.T) { sessionID string expected []any }{ - {"global filter 1", []string{"42", "1", "GROUP-12"}, nil, "SESSION-1", []any{"1", "42", "GROUP-12"}}, - {"global filter 2", []string{"42", "1", "GROUP-12"}, nil, "SESSION-2", []any{"42", "GROUP-12"}}, - {"route filter 1", nil, []string{"42", "1", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}}, - {"route filter 2", nil, []string{"42", "1", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}}, - {"both filters 1", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-1", []any{"1", "42", "GROUP-12"}}, - {"both filters 2", []string{"1"}, []string{"42", "GROUP-12"}, "SESSION-2", []any{"42", "GROUP-12"}}, - {"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1"}}, - {"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3"}}, + {"global filter 1", []string{"42", "1"}, nil, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}}, + {"global filter 2", []string{"42", "1"}, nil, "SESSION-2", []any{"42", "GROUP-42"}}, + {"route filter 1", nil, []string{"42", "1"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}}, + {"route filter 2", nil, []string{"42", "1"}, "SESSION-2", []any{"42", "GROUP-42"}}, + {"both filters 1", []string{"1"}, []string{"42"}, "SESSION-1", []any{"1", "42", "GROUP-1", "GROUP-42"}}, + {"both filters 2", []string{"1"}, []string{"42"}, "SESSION-2", []any{"42", "GROUP-42"}}, + {"cannot filter by name", []string{"GROUP-1"}, nil, "SESSION-1", []any{}}, + {"overlapping", []string{"1"}, []string{"1"}, "SESSION-1", []any{"1", "GROUP-1"}}, + {"empty route filter", []string{"1", "2", "3"}, []string{}, "SESSION-1", []any{"1", "2", "3", "GROUP-1", "GROUP-2", "GROUP-3"}}, { "no filtering", nil, nil, "SESSION-10", []any{"10", "20", "30", "40", "50", "GROUP-10", "GROUP-20", "GROUP-30", "GROUP-40", "GROUP-50"}, }, + // filtering has no effect on groups from an IdP "groups" claim + {"groups claim", []string{"foo", "quux"}, nil, "SESSION-11", []any{"foo", "bar", "baz"}}, } ctx := storage.WithQuerier(context.Background(), storage.NewStaticQuerier(records...)) @@ -647,3 +656,8 @@ func newDirectoryUserRecord(directoryUser directory.User) *databroker.Record { s, _ := structpb.NewStruct(m) return storage.NewStaticRecord(directory.UserRecordType, s) } + +func newList(v ...any) *structpb.ListValue { + lv, _ := structpb.NewList(v) + return lv +}