diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 6fb3b30b4..7d988a845 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -50,14 +50,15 @@ func (storage *Storage) GetClientByID( ctx context.Context, id string, ) (*rfc7591v1.ClientMetadata, error) { + v := new(rfc7591v1.ClientMetadata) rec, err := storage.client.Get(ctx, &databroker.GetRequest{ - Id: id, + Type: protoutil.GetTypeURL(v), + Id: id, }) if err != nil { return nil, fmt.Errorf("failed to get client by ID: %w", err) } - v := new(rfc7591v1.ClientMetadata) err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) if err != nil { return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err) diff --git a/internal/mcp/storage_test.go b/internal/mcp/storage_test.go new file mode 100644 index 000000000..28b709165 --- /dev/null +++ b/internal/mcp/storage_test.go @@ -0,0 +1,64 @@ +package mcp_test + +import ( + "context" + "net" + "testing" + "time" + + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel/trace/noop" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" + "google.golang.org/grpc/test/bufconn" + + "github.com/pomerium/pomerium/internal/databroker" + "github.com/pomerium/pomerium/internal/mcp" + rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" + "github.com/pomerium/pomerium/internal/testutil" + databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker" +) + +func TestStorage(t *testing.T) { + t.Parallel() + + ctx := testutil.GetContext(t, time.Minute*5) + + list := bufconn.Listen(1024 * 1024) + t.Cleanup(func() { + list.Close() + }) + + srv := databroker.New(ctx, noop.NewTracerProvider()) + grpcServer := grpc.NewServer() + databroker_grpc.RegisterDataBrokerServiceServer(grpcServer, srv) + + go func() { + if err := grpcServer.Serve(list); err != nil { + t.Errorf("failed to serve: %v", err) + } + }() + t.Cleanup(func() { + grpcServer.Stop() + }) + + conn, err := grpc.DialContext(ctx, "bufnet", + grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) { + return list.Dial() + }), + grpc.WithTransportCredentials(insecure.NewCredentials())) + require.NoError(t, err) + + client := databroker_grpc.NewDataBrokerServiceClient(conn) + + t.Run("client registration", func(t *testing.T) { + storage := mcp.NewStorage(client) + + id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{}) + require.NoError(t, err) + require.NotEmpty(t, id) + + _, err = storage.GetClientByID(ctx, id) + require.NoError(t, err) + }) +}