add tests

This commit is contained in:
Denis Mishin 2025-04-24 13:38:25 -04:00
parent f7f0769473
commit 2af3538f54
2 changed files with 67 additions and 2 deletions

View file

@ -50,14 +50,15 @@ func (storage *Storage) GetClientByID(
ctx context.Context, ctx context.Context,
id string, id string,
) (*rfc7591v1.ClientMetadata, error) { ) (*rfc7591v1.ClientMetadata, error) {
v := new(rfc7591v1.ClientMetadata)
rec, err := storage.client.Get(ctx, &databroker.GetRequest{ rec, err := storage.client.Get(ctx, &databroker.GetRequest{
Id: id, Type: protoutil.GetTypeURL(v),
Id: id,
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get client by ID: %w", err) return nil, fmt.Errorf("failed to get client by ID: %w", err)
} }
v := new(rfc7591v1.ClientMetadata)
err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{})
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err) return nil, fmt.Errorf("failed to unmarshal client registration request: %w", err)

View file

@ -0,0 +1,64 @@
package mcp_test
import (
"context"
"net"
"testing"
"time"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel/trace/noop"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/test/bufconn"
"github.com/pomerium/pomerium/internal/databroker"
"github.com/pomerium/pomerium/internal/mcp"
rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591"
"github.com/pomerium/pomerium/internal/testutil"
databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker"
)
func TestStorage(t *testing.T) {
t.Parallel()
ctx := testutil.GetContext(t, time.Minute*5)
list := bufconn.Listen(1024 * 1024)
t.Cleanup(func() {
list.Close()
})
srv := databroker.New(ctx, noop.NewTracerProvider())
grpcServer := grpc.NewServer()
databroker_grpc.RegisterDataBrokerServiceServer(grpcServer, srv)
go func() {
if err := grpcServer.Serve(list); err != nil {
t.Errorf("failed to serve: %v", err)
}
}()
t.Cleanup(func() {
grpcServer.Stop()
})
conn, err := grpc.DialContext(ctx, "bufnet",
grpc.WithContextDialer(func(context.Context, string) (net.Conn, error) {
return list.Dial()
}),
grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
client := databroker_grpc.NewDataBrokerServiceClient(conn)
t.Run("client registration", func(t *testing.T) {
storage := mcp.NewStorage(client)
id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{})
require.NoError(t, err)
require.NotEmpty(t, id)
_, err = storage.GetClientByID(ctx, id)
require.NoError(t, err)
})
}