diff --git a/pkg/grpc/config/config.go b/pkg/grpc/config/config.go index 4e4c11bc5..f3a15dd7f 100644 --- a/pkg/grpc/config/config.go +++ b/pkg/grpc/config/config.go @@ -16,3 +16,7 @@ func (rr *RouteRedirect) IsSet() bool { rr.SchemeRedirect != nil || rr.HttpsRedirect != nil } + +func (x *Config) GetId() string { //nolint + return x.Name +} diff --git a/proxy/data.go b/proxy/data.go index 288647a58..8abfc7205 100644 --- a/proxy/data.go +++ b/proxy/data.go @@ -64,7 +64,10 @@ func (p *Proxy) getUserInfoData(r *http.Request) handlers.UserInfoData { func (p *Proxy) fillEnterpriseUserInfoData(ctx context.Context, data *handlers.UserInfoData) { client := p.state.Load().dataBrokerClient - res, _ := client.Get(ctx, &databroker.GetRequest{Type: "type.googleapis.com/pomerium.config.Config", Id: "dashboard"}) + res, _ := client.Get(ctx, &databroker.GetRequest{ + Type: "type.googleapis.com/pomerium.config.Config", + Id: "dashboard-settings", + }) data.IsEnterprise = res.GetRecord() != nil if !data.IsEnterprise { return diff --git a/proxy/data_test.go b/proxy/data_test.go new file mode 100644 index 000000000..7b2e090f4 --- /dev/null +++ b/proxy/data_test.go @@ -0,0 +1,94 @@ +package proxy + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/pomerium/datasource/pkg/directory" + "github.com/pomerium/pomerium/config" + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/sessions" + "github.com/pomerium/pomerium/internal/testutil" + configpb "github.com/pomerium/pomerium/pkg/grpc/config" + databrokerpb "github.com/pomerium/pomerium/pkg/grpc/databroker" + "github.com/pomerium/pomerium/pkg/grpc/session" + "github.com/pomerium/pomerium/pkg/grpc/user" + "github.com/pomerium/pomerium/pkg/protoutil" +) + +func Test_getUserInfoData(t *testing.T) { + t.Parallel() + + ctx, clearTimeout := context.WithTimeout(context.Background(), time.Second*10) + defer clearTimeout() + + cc := testutil.NewGRPCServer(t, func(srv *grpc.Server) { + databrokerpb.RegisterDataBrokerServiceServer(srv, databroker.New()) + }) + t.Cleanup(func() { cc.Close() }) + + client := databrokerpb.NewDataBrokerServiceClient(cc) + + opts := testOptions(t) + proxy, err := New(&config.Config{Options: opts}) + require.NoError(t, err) + proxy.state.Load().dataBrokerClient = client + + require.NoError(t, databrokerpb.PutMulti(ctx, client, + makeRecord(&session.Session{ + Id: "S1", + UserId: "U1", + }), + makeRecord(&user.User{ + Id: "U1", + }), + makeRecord(&configpb.Config{ + Name: "dashboard-settings", + }), + makeStructRecord(directory.UserRecordType, "U1", map[string]any{ + "group_ids": []any{"G1", "G2", "G3"}, + }))) + + r := httptest.NewRequest(http.MethodGet, "/.pomerium/", nil) + r.Header.Set("Authorization", "Bearer Pomerium-"+encodeSession(t, opts, &sessions.State{ + ID: "S1", + })) + data := proxy.getUserInfoData(r) + assert.Equal(t, "S1", data.Session.Id) + assert.Equal(t, "U1", data.User.Id) + assert.True(t, data.IsEnterprise) + assert.Equal(t, []string{"G1", "G2", "G3"}, data.DirectoryUser.GroupIDs) +} + +func makeRecord(object interface { + proto.Message + GetId() string +}, +) *databrokerpb.Record { + a := protoutil.NewAny(object) + return &databrokerpb.Record{ + Type: a.GetTypeUrl(), + Id: object.GetId(), + Data: a, + ModifiedAt: timestamppb.Now(), + } +} + +func makeStructRecord(recordType, recordID string, object any) *databrokerpb.Record { + s := protoutil.ToStruct(object).GetStructValue() + return &databrokerpb.Record{ + Type: recordType, + Id: recordID, + Data: protoutil.NewAny(s), + ModifiedAt: timestamppb.Now(), + } +}