diff --git a/internal/mcp/handler_authorization.go b/internal/mcp/handler_authorization.go index 0021d904e..de4613ce0 100644 --- a/internal/mcp/handler_authorization.go +++ b/internal/mcp/handler_authorization.go @@ -41,7 +41,7 @@ func (srv *Handler) Authorize(w http.ResponseWriter, r *http.Request) { return } - client, err := srv.storage.GetClientByID(ctx, v.ClientId) + client, err := srv.storage.GetClient(ctx, v.ClientId) if err != nil && status.Code(err) == codes.NotFound { log.Ctx(ctx).Error().Err(err).Str("id", v.ClientId).Msg("client id not found") oauth21.ErrorResponse(w, http.StatusUnauthorized, oauth21.InvalidClient) diff --git a/internal/mcp/storage.go b/internal/mcp/storage.go index 7d988a845..3fd3d0e4f 100644 --- a/internal/mcp/storage.go +++ b/internal/mcp/storage.go @@ -46,7 +46,7 @@ func (storage *Storage) RegisterClient( return id, nil } -func (storage *Storage) GetClientByID( +func (storage *Storage) GetClient( ctx context.Context, id string, ) (*rfc7591v1.ClientMetadata, error) { @@ -85,3 +85,24 @@ func (storage *Storage) CreateAuthorizationRequest( } return id, nil } + +func (storage *Storage) GetAuthorizationRequest( + ctx context.Context, + id string, +) (*oauth21proto.AuthorizationRequest, error) { + v := new(oauth21proto.AuthorizationRequest) + rec, err := storage.client.Get(ctx, &databroker.GetRequest{ + Type: protoutil.GetTypeURL(v), + Id: id, + }) + if err != nil { + return nil, fmt.Errorf("failed to get authorization request by ID: %w", err) + } + + err = anypb.UnmarshalTo(rec.Record.Data, v, proto.UnmarshalOptions{}) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal authorization request: %w", err) + } + + return v, nil +} diff --git a/internal/mcp/storage_test.go b/internal/mcp/storage_test.go index 28b709165..f036bdd6e 100644 --- a/internal/mcp/storage_test.go +++ b/internal/mcp/storage_test.go @@ -14,6 +14,7 @@ import ( "github.com/pomerium/pomerium/internal/databroker" "github.com/pomerium/pomerium/internal/mcp" + oauth21proto "github.com/pomerium/pomerium/internal/oauth21/gen" rfc7591v1 "github.com/pomerium/pomerium/internal/rfc7591" "github.com/pomerium/pomerium/internal/testutil" databroker_grpc "github.com/pomerium/pomerium/pkg/grpc/databroker" @@ -50,15 +51,26 @@ func TestStorage(t *testing.T) { require.NoError(t, err) client := databroker_grpc.NewDataBrokerServiceClient(conn) + storage := mcp.NewStorage(client) t.Run("client registration", func(t *testing.T) { - storage := mcp.NewStorage(client) + t.Parallel() id, err := storage.RegisterClient(ctx, &rfc7591v1.ClientMetadata{}) require.NoError(t, err) require.NotEmpty(t, id) - _, err = storage.GetClientByID(ctx, id) + _, err = storage.GetClient(ctx, id) + require.NoError(t, err) + }) + + t.Run("authorization request", func(t *testing.T) { + t.Parallel() + + id, err := storage.CreateAuthorizationRequest(ctx, &oauth21proto.AuthorizationRequest{}) + require.NoError(t, err) + + _, err = storage.GetAuthorizationRequest(ctx, id) require.NoError(t, err) }) }