From 61847b1fdcb396d2878780acced84b5bfa360801 Mon Sep 17 00:00:00 2001 From: "backport-actions-token[bot]" <87506591+backport-actions-token[bot]@users.noreply.github.com> Date: Mon, 23 Sep 2024 13:52:34 -0600 Subject: [PATCH] core/proxy: fix is-enterprise check (#5297) * core/proxy: fix is-enterprise check (#5295) * add testutil --------- Co-authored-by: Caleb Doxsey --- internal/testutil/grpc.go | 44 ++++++++++++++++++ pkg/grpc/config/config.go | 4 ++ proxy/data.go | 5 ++- proxy/data_test.go | 94 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 146 insertions(+), 1 deletion(-) create mode 100644 internal/testutil/grpc.go create mode 100644 proxy/data_test.go diff --git a/internal/testutil/grpc.go b/internal/testutil/grpc.go new file mode 100644 index 000000000..3b34f4575 --- /dev/null +++ b/internal/testutil/grpc.go @@ -0,0 +1,44 @@ +package testutil + +import ( + "context" + "errors" + "net" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" +) + +// NewGRPCServer starts a gRPC server and returns a client connection to it. +func NewGRPCServer(t testing.TB, register func(s *grpc.Server)) *grpc.ClientConn { + t.Helper() + + li := bufconn.Listen(1024 * 1024) + s := grpc.NewServer() + register(s) + go func() { + err := s.Serve(li) + if errors.Is(err, grpc.ErrServerStopped) { + err = nil + } + require.NoError(t, err) + }() + t.Cleanup(func() { + s.Stop() + }) + + cc, err := grpc.NewClient("passthrough://bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return li.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + t.Cleanup(func() { + cc.Close() + }) + + return cc +} diff --git a/pkg/grpc/config/config.go b/pkg/grpc/config/config.go index 4e4c11bc5..f3a15dd7f 100644 --- a/pkg/grpc/config/config.go +++ b/pkg/grpc/config/config.go @@ -16,3 +16,7 @@ func (rr *RouteRedirect) IsSet() bool { rr.SchemeRedirect != nil || rr.HttpsRedirect != nil } + +func (x *Config) GetId() string { //nolint + return x.Name +} diff --git a/proxy/data.go b/proxy/data.go index 288647a58..8abfc7205 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -64,7 +64,10 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { client := p.state.Load().dataBrokerClient - res, _ := client.Get(ctx, &databroker.GetRequest{Type: "type.googleapis.com/pomerium.config.Config", Id: "dashboard"}) + res, _ := client.Get(ctx, &databroker.GetRequest{ + Type: "type.googleapis.com/pomerium.config.Config", + Id: "dashboard-settings", + }) data.IsEnterprise = res.GetRecord() != nil if !data.IsEnterprise { return diff --git a/proxy/data_test.go b/proxy/data_test.go new file mode 100644 index 000000000..7b2e090f4 --- /dev/null +++ b/proxy/data_test.go @@ -0,0 +1,94 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/testutil" + configpb "github.com/pomerium/pomerium/pkg/grpc/config" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func Test_getUserInfoData(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + + opts := testOptions(t) + proxy, err := New(&config.Config{Options: opts}) + require.NoError(t, err) + proxy.state.Load().dataBrokerClient = client + + require.NoError(t, databrokerpb.PutMulti(ctx, client, + makeRecord(&session.Session{ + Id: "S1", + UserId: "U1", + }), + makeRecord(&user.User{ + Id: "U1", + }), + makeRecord(&configpb.Config{ + Name: "dashboard-settings", + }), + makeStructRecord(directory.UserRecordType, "U1", map[string]any{ + "group_ids": []any{"G1", "G2", "G3"}, + }))) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ + ID: "S1", + })) + data := proxy.getUserInfoData(r) + assert.Equal(t, "S1", data.Session.Id) + assert.Equal(t, "U1", data.User.Id) + assert.True(t, data.IsEnterprise) + assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) +} + +func makeRecord(object interface { + proto.Message + GetId() string +}, +) *databrokerpb.Record { + a := protoutil.NewAny(object) + return &databrokerpb.Record{ + Type: a.GetTypeUrl(), + Id: object.GetId(), + Data: a, + ModifiedAt: timestamppb.Now(), + } +} + +func makeStructRecord(recordType, recordID string, object any) *databrokerpb.Record { + s := protoutil.ToStruct(object).GetStructValue() + return &databrokerpb.Record{ + Type: recordType, + Id: recordID, + Data: protoutil.NewAny(s), + ModifiedAt: timestamppb.Now(), + } +}