From ef12fda55c446a981adf1be97c722a05bcd7f704 Mon Sep 17 00:00:00 2001 From: Caleb Doxsey Date: Thu, 7 Nov 2024 09:22:35 -0700 Subject: [PATCH] authorize: additional header evaluator tests (#5363) * authorize: additional header evaluator tests * add groups to jwt test --- authorize/evaluator/headers_evaluator_test.go | 206 +++++++++++++++++- pkg/storage/querier.go | 56 +++-- 2 files changed, 245 insertions(+), 17 deletions(-) diff --git a/authorize/evaluator/headers_evaluator_test.go b/authorize/evaluator/headers_evaluator_test.go index 08d4fe189..285e4910c 100644 --- a/authorize/evaluator/headers_evaluator_test.go +++ b/authorize/evaluator/headers_evaluator_test.go @@ -20,14 +20,74 @@ import ( "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/pomerium/datasource/pkg/directory" "github.com/pomerium/pomerium/authorize/internal/store" "github.com/pomerium/pomerium/config" "github.com/pomerium/pomerium/pkg/cryptutil" + "github.com/pomerium/pomerium/pkg/grpc/databroker" "github.com/pomerium/pomerium/pkg/grpc/session" "github.com/pomerium/pomerium/pkg/grpc/user" "github.com/pomerium/pomerium/pkg/storage" ) +func BenchmarkHeadersEvaluator(b *testing.B) { + ctx := context.Background() + + signingKey, err := cryptutil.NewSigningKey() + require.NoError(b, err) + encodedSigningKey, err := cryptutil.EncodePrivateKey(signingKey) + require.NoError(b, err) + privateJWK, err := cryptutil.PrivateJWKFromBytes(encodedSigningKey) + require.NoError(b, err) + iat := time.Unix(1686870680, 0) + + ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier([]proto.Message{ + &session.Session{Id: "s1", ImpersonateSessionId: proto.String("s2"), UserId: "u1"}, + &session.Session{Id: "s2", UserId: "u2", Claims: map[string]*structpb.ListValue{ + "name": {Values: []*structpb.Value{ + structpb.NewStringValue("n1"), + }}, + }, IssuedAt: timestamppb.New(iat)}, + &user.User{Id: "u2", Name: "USER#2"}, + newDirectoryUserRecord(directory.User{ID: "u2", GroupIDs: []string{"g1", "g2", "g3", "g4"}}), + newDirectoryGroupRecord(directory.Group{ID: "g1", Name: "GROUP1"}), + newDirectoryGroupRecord(directory.Group{ID: "g2", Name: "GROUP2"}), + newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3"}), + newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4"}), + }...)) + + s := store.New() + s.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) + s.UpdateSigningKey(privateJWK) + + e, err := NewHeadersEvaluator(ctx, s, rego.Time(iat)) + require.NoError(b, err) + + req := &HeadersRequest{ + EnableRoutingKey: true, + Issuer: "from.example.com", + Audience: "from.example.com", + KubernetesServiceAccountToken: "KUBERNETES_SERVICE_ACCOUNT_TOKEN", + ToAudience: "to.example.com", + Session: RequestSession{ + ID: "s1", + }, + SetRequestHeaders: map[string]string{ + "X-Custom-Header": "CUSTOM_VALUE", + "X-ID-Token": "${pomerium.id_token}", + "X-Access-Token": "${pomerium.access_token}", + "Client-Cert-Fingerprint": "${pomerium.client_cert_fingerprint}", + "Authorization": "Bearer ${pomerium.jwt}", + }, + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + res, err := e.Evaluate(ctx, req, rego.EvalTime(iat)) + require.NoError(b, err) + _ = res + } +} + func TestNewHeadersRequestFromPolicy(t *testing.T) { req, _ := NewHeadersRequestFromPolicy(&config.Policy{ EnableGoogleCloudServerlessAuthentication: true, @@ -142,7 +202,7 @@ func TestHeadersEvaluator(t *testing.T) { ctx := context.Background() ctx = storage.WithQuerier(ctx, storage.NewStaticQuerier(data...)) store := store.New() - store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("email", "groups", "user", "CUSTOM_KEY")) + store.UpdateJWTClaimHeaders(config.NewJWTClaimHeaders("name", "email", "groups", "user", "CUSTOM_KEY")) store.UpdateSigningKey(privateJWK) e, err := NewHeadersEvaluator(ctx, store, rego.Time(iat)) require.NoError(t, err) @@ -157,7 +217,22 @@ func TestHeadersEvaluator(t *testing.T) { "name": {Values: []*structpb.Value{ structpb.NewStringValue("n1"), }}, + "CUSTOM_KEY": {Values: []*structpb.Value{ + structpb.NewStringValue("v1"), + structpb.NewStringValue("v2"), + structpb.NewStringValue("v3"), + }}, }, IssuedAt: timestamppb.New(iat)}, + &user.User{Id: "u2", Claims: map[string]*structpb.ListValue{ + "name": {Values: []*structpb.Value{ + structpb.NewStringValue("n1"), + }}, + }}, + newDirectoryUserRecord(directory.User{ID: "u2", GroupIDs: []string{"g1", "g2", "g3", "g4"}}), + newDirectoryGroupRecord(directory.Group{ID: "g1", Name: "GROUP1", Email: "g1@example.com"}), + newDirectoryGroupRecord(directory.Group{ID: "g2", Name: "GROUP2", Email: "g2@example.com"}), + newDirectoryGroupRecord(directory.Group{ID: "g3", Name: "GROUP3", Email: "g3@example.com"}), + newDirectoryGroupRecord(directory.Group{ID: "g4", Name: "GROUP4", Email: "g4@example.com"}), }, &HeadersRequest{ Issuer: "from.example.com", @@ -204,6 +279,8 @@ func TestHeadersEvaluator(t *testing.T) { assert.Equal(t, "u2", claims["sub"], "should set subject to user id") assert.Equal(t, "u2", claims["user"], "should set user to user id") assert.Equal(t, "n1", claims["name"], "should set name") + assert.Equal(t, "v1,v2,v3", claims["CUSTOM_KEY"], "should set CUSTOM_KEY") + assert.Equal(t, []any{"g1", "g2", "g3", "g4", "GROUP1", "GROUP2", "GROUP3", "GROUP4"}, claims["groups"]) }) t.Run("set_request_headers", func(t *testing.T) { @@ -317,6 +394,18 @@ func TestHeadersEvaluator(t *testing.T) { []protoreflect.ProtoMessage{ &session.Session{Id: "s1", UserId: "u1"}, &user.User{Id: "u1", Email: "u1@example.com"}, + newDirectoryUserRecord(directory.User{ + ID: "u1", + GroupIDs: []string{"g1", "g2", "g3"}, + }), + newDirectoryGroupRecord(directory.Group{ + ID: "g1", + Name: "GROUP1", + }), + newDirectoryGroupRecord(directory.Group{ + ID: "g2", + Name: "GROUP2", + }), }, &HeadersRequest{ Issuer: "from.example.com", @@ -328,7 +417,102 @@ func TestHeadersEvaluator(t *testing.T) { require.NoError(t, err) assert.Equal(t, "Bearer TOKEN", output.Headers.Get("Authorization")) assert.Equal(t, "u1@example.com", output.Headers.Get("Impersonate-User")) - assert.Empty(t, output.Headers["Impersonate-Group"]) + assert.Equal(t, "g1,g2,g3,GROUP1,GROUP2", output.Headers.Get("Impersonate-Group")) + }) + + t.Run("routing key", func(t *testing.T) { + t.Parallel() + + output, err := eval(t, + []protoreflect.ProtoMessage{}, + &HeadersRequest{ + EnableRoutingKey: false, + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Empty(t, output.Headers.Get("X-Pomerium-Routing-Key")) + + output, err = eval(t, + []protoreflect.ProtoMessage{}, + &HeadersRequest{ + EnableRoutingKey: true, + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Equal(t, "e8bc163c82eee18733288c7d4ac636db3a6deb013ef2d37b68322be20edc45cc", output.Headers.Get("X-Pomerium-Routing-Key")) + }) + + t.Run("jwt payload email", func(t *testing.T) { + t.Parallel() + + output, err := eval(t, + []protoreflect.ProtoMessage{ + &session.Session{Id: "s1", UserId: "u1"}, + &user.User{Id: "u1", Email: "user@example.com"}, + }, + &HeadersRequest{ + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Equal(t, "user@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) + + output, err = eval(t, + []protoreflect.ProtoMessage{ + &session.Session{Id: "s1", UserId: "u1"}, + newDirectoryUserRecord(directory.User{ID: "u1", Email: "directory-user@example.com"}), + }, + &HeadersRequest{ + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Equal(t, "directory-user@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) + }) + t.Run("jwt payload name", func(t *testing.T) { + t.Parallel() + + output, err := eval(t, + []protoreflect.ProtoMessage{ + &session.Session{Id: "s1", UserId: "u1", Claims: map[string]*structpb.ListValue{ + "name": {Values: []*structpb.Value{ + structpb.NewStringValue("NAME_FROM_SESSION"), + }}, + }}, + }, + &HeadersRequest{ + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Equal(t, "NAME_FROM_SESSION", output.Headers.Get("X-Pomerium-Claim-Name")) + + output, err = eval(t, + []protoreflect.ProtoMessage{ + &session.Session{Id: "s1", UserId: "u1"}, + &user.User{Id: "u1", Claims: map[string]*structpb.ListValue{ + "name": {Values: []*structpb.Value{ + structpb.NewStringValue("NAME_FROM_USER"), + }}, + }}, + }, + &HeadersRequest{ + Session: RequestSession{ID: "s1"}, + }) + require.NoError(t, err) + assert.Equal(t, "NAME_FROM_USER", output.Headers.Get("X-Pomerium-Claim-Name")) + }) + + t.Run("service account", func(t *testing.T) { + t.Parallel() + + output, err := eval(t, + []protoreflect.ProtoMessage{ + &user.ServiceAccount{Id: "sa1", UserId: "u1"}, + &user.User{Id: "u1", Email: "u1@example.com"}, + }, + &HeadersRequest{ + Session: RequestSession{ID: "sa1"}, + }) + require.NoError(t, err) + assert.Equal(t, "u1@example.com", output.Headers.Get("X-Pomerium-Claim-Email")) }) } @@ -339,8 +523,24 @@ func decodeJWSPayload(t *testing.T, jws string) []byte { // separated by a '.' character. The payload is the middle one of these. // cf. https://www.rfc-editor.org/rfc/rfc7515#section-7.1 parts := strings.Split(jws, ".") - require.Equal(t, 3, len(parts)) + require.Equal(t, 3, len(parts), "jws should have 3 parts: %s", jws) payload, err := base64.RawURLEncoding.DecodeString(parts[1]) require.NoError(t, err) return payload } + +func newDirectoryGroupRecord(directoryGroup directory.Group) *databroker.Record { + m := map[string]any{} + bs, _ := json.Marshal(directoryGroup) + _ = json.Unmarshal(bs, &m) + s, _ := structpb.NewStruct(m) + return storage.NewStaticRecord(directory.GroupRecordType, s) +} + +func newDirectoryUserRecord(directoryUser directory.User) *databroker.Record { + m := map[string]any{} + bs, _ := json.Marshal(directoryUser) + _ = json.Unmarshal(bs, &m) + s, _ := structpb.NewStruct(m) + return storage.NewStaticRecord(directory.UserRecordType, s) +} diff --git a/pkg/storage/querier.go b/pkg/storage/querier.go index 8e05c44f8..1de0b195d 100644 --- a/pkg/storage/querier.go +++ b/pkg/storage/querier.go @@ -2,6 +2,7 @@ package storage import ( "context" + "encoding/json" "strconv" "sync" @@ -9,6 +10,7 @@ import ( grpc "google.golang.org/grpc" "google.golang.org/grpc/codes" status "google.golang.org/grpc/status" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" timestamppb "google.golang.org/protobuf/types/known/timestamppb" @@ -57,26 +59,52 @@ type staticQuerier struct { func NewStaticQuerier(msgs ...proto.Message) Querier { getter := &staticQuerier{} for _, msg := range msgs { - data := protoutil.NewAny(msg) - record := new(databroker.Record) - record.ModifiedAt = timestamppb.Now() - record.Version = cryptutil.NewRandomUInt64() - record.Id = uuid.New().String() - record.Data = data - record.Type = data.TypeUrl - if hasID, ok := msg.(interface{ GetId() string }); ok { - record.Id = hasID.GetId() - } - if hasVersion, ok := msg.(interface{ GetVersion() string }); ok { - if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil { - record.Version = v - } + record, ok := msg.(*databroker.Record) + if !ok { + record = NewStaticRecord(protoutil.NewAny(msg).TypeUrl, msg) } getter.records = append(getter.records, record) } return getter } +// NewStaticRecord creates a new databroker Record from a protobuf message. +func NewStaticRecord(typeURL string, msg proto.Message) *databroker.Record { + data := protoutil.NewAny(msg) + record := new(databroker.Record) + record.ModifiedAt = timestamppb.Now() + record.Version = cryptutil.NewRandomUInt64() + record.Id = uuid.New().String() + record.Data = data + record.Type = typeURL + if hasID, ok := msg.(interface{ GetId() string }); ok { + record.Id = hasID.GetId() + } + if hasVersion, ok := msg.(interface{ GetVersion() string }); ok { + if v, err := strconv.ParseUint(hasVersion.GetVersion(), 10, 64); err == nil { + record.Version = v + } + } + + var jsonData struct { + ID string `json:"id"` + Version string `json:"version"` + } + bs, _ := protojson.Marshal(msg) + _ = json.Unmarshal(bs, &jsonData) + + if jsonData.ID != "" { + record.Id = jsonData.ID + } + if jsonData.Version != "" { + if v, err := strconv.ParseUint(jsonData.Version, 10, 64); err == nil { + record.Version = v + } + } + + return record +} + func (q *staticQuerier) InvalidateCache(_ context.Context, _ *databroker.QueryRequest) {} // Query queries for records.