[core] Graph: add importGraphContent API

Extract the logic of importing the content of a graph within a graph instance from
the graph loading logic.
Add `Graph.importGraphContent` and `Graph.importGraphContentFromFile`
methods.
Use the deserialization API to load the content in another temporary graph instance,
to handle the renaming of nodes using the Graph API, rather than manipulating
entries in a raw dictionnary.
This commit is contained in:
Yann Lanthony 2025-02-06 16:46:04 +01:00
parent 7eab289d30
commit 4aec741a89
4 changed files with 258 additions and 26 deletions

View file

@ -243,7 +243,6 @@ class Graph(BaseObject):
self._nodes = DictModel(keyAttrName='name', parent=self) self._nodes = DictModel(keyAttrName='name', parent=self)
# Edges: use dst attribute as unique key since it can only have one input connection # Edges: use dst attribute as unique key since it can only have one input connection
self._edges = DictModel(keyAttrName='dst', parent=self) self._edges = DictModel(keyAttrName='dst', parent=self)
self._importedNodes = DictModel(keyAttrName='name', parent=self)
self._compatibilityNodes = DictModel(keyAttrName='name', parent=self) self._compatibilityNodes = DictModel(keyAttrName='name', parent=self)
self.cacheDir = meshroom.core.defaultCacheFolder self.cacheDir = meshroom.core.defaultCacheFolder
self._filepath = '' self._filepath = ''
@ -251,15 +250,17 @@ class Graph(BaseObject):
self.header = {} self.header = {}
def clear(self): def clear(self):
self._clearGraphContent()
self.header.clear() self.header.clear()
self._compatibilityNodes.clear() self._unsetFilepath()
def _clearGraphContent(self):
self._edges.clear() self._edges.clear()
# Tell QML nodes are going to be deleted # Tell QML nodes are going to be deleted
for node in self._nodes: for node in self._nodes:
node.alive = False node.alive = False
self._importedNodes.clear()
self._nodes.clear() self._nodes.clear()
self._unsetFilepath() self._compatibilityNodes.clear()
@property @property
def fileFeatures(self): def fileFeatures(self):
@ -330,7 +331,7 @@ class Graph(BaseObject):
for nodeName, nodeData in sorted( for nodeName, nodeData in sorted(
graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0]) graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])
): ):
self._deserializeNode(nodeData, nodeName) self._deserializeNode(nodeData, nodeName, self)
# Create graph edges by resolving attributes expressions # Create graph edges by resolving attributes expressions
self._applyExpr() self._applyExpr()
@ -374,14 +375,14 @@ class Graph(BaseObject):
return graphContent return graphContent
def _deserializeNode(self, nodeData: dict, nodeName: str): def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"):
# Retrieve version from # Retrieve version from
# 1. nodeData: node saved from a CompatibilityNode # 1. nodeData: node saved from a CompatibilityNode
# 2. nodesVersion in file header: node saved from a Node # 2. nodesVersion in file header: node saved from a Node
# 3. fallback behavior: default to "0.0" # 3. fallback behavior: default to "0.0"
if "version" not in nodeData: if "version" not in nodeData:
nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0")
inTemplate = self.header.get("template", False) inTemplate = fromGraph.header.get("template", False)
node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate)
self._addNode(node, nodeName) self._addNode(node, nodeName)
return node return node
@ -555,6 +556,58 @@ class Graph(BaseObject):
return attributes return attributes
def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]:
"""Import the content (nodes and edges) of another Graph file into this Graph instance.
Args:
filepath: The path to the Graph file to import.
Returns:
The list of newly created Nodes.
"""
graph = loadGraph(filepath)
return self.importGraphContent(graph)
@blockNodeCallbacks
def importGraphContent(self, graph: "Graph") -> list[Node]:
"""
Import the content (node and edges) of another `graph` into this Graph instance.
Nodes are imported with their original names if possible, otherwise a new unique name is generated
from their node type.
Args:
graph: The graph to import.
Returns:
The list of newly created Nodes.
"""
def _renameClashingNodes():
if not self.nodes:
return
unavailableNames = set(self.nodes.keys())
for node in graph.nodes:
if node._name in unavailableNames:
node._name = self._createUniqueNodeName(node.nodeType, unavailableNames)
unavailableNames.add(node._name)
def _importNodeAndEdges() -> list[Node]:
importedNodes = []
# If we import the content of the graph within itself,
# iterate over a copy of the nodes as the graph is modified during the iteration.
nodes = graph.nodes if graph is not self else list(graph.nodes)
with GraphModification(self):
for srcNode in nodes:
node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph)
importedNodes.append(node)
self._applyExpr()
return importedNodes
_renameClashingNodes()
importedNodes = _importNodeAndEdges()
return importedNodes
@property @property
def updateEnabled(self): def updateEnabled(self):
return self._updateEnabled return self._updateEnabled
@ -766,8 +819,6 @@ class Graph(BaseObject):
node.alive = False node.alive = False
self._nodes.remove(node) self._nodes.remove(node)
if node in self._importedNodes:
self._importedNodes.remove(node)
self.update() self.update()
return inEdges, outEdges, outListAttributes return inEdges, outEdges, outListAttributes
@ -792,13 +843,21 @@ class Graph(BaseObject):
n.updateInternals() n.updateInternals()
return n return n
def _createUniqueNodeName(self, inputName): def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None):
i = 1 """Create a unique node name based on the input name.
while i:
newName = "{name}_{index}".format(name=inputName, index=i) Args:
if newName not in self._nodes.objects: inputName: The desired node name.
existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes.
"""
existingNodeNames = existingNames or set(self._nodes.objects.keys())
idx = 1
while idx:
newName = f"{inputName}_{idx}"
if newName not in existingNodeNames:
return newName return newName
i += 1 idx += 1
def node(self, nodeName): def node(self, nodeName):
return self._nodes.get(nodeName) return self._nodes.get(nodeName)
@ -1635,11 +1694,6 @@ class Graph(BaseObject):
def edges(self): def edges(self):
return self._edges return self._edges
@property
def importedNodes(self):
"""" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """
return self._importedNodes
@property @property
def cacheDir(self): def cacheDir(self):
return self._cacheDir return self._cacheDir

View file

@ -6,9 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack
from PySide6.QtCore import Property, Signal from PySide6.QtCore import Property, Signal
from meshroom.core.attribute import ListAttribute, Attribute from meshroom.core.attribute import ListAttribute, Attribute
from meshroom.core.graph import GraphModification from meshroom.core.graph import Graph, GraphModification
from meshroom.core.node import Position from meshroom.core.node import Position
from meshroom.core.nodeFactory import nodeFactory from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
class UndoCommand(QUndoCommand): class UndoCommand(QUndoCommand):
@ -232,7 +233,8 @@ class ImportProjectCommand(GraphCommand):
""" """
Handle the import of a project into a Graph. Handle the import of a project into a Graph.
""" """
def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None):
def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None):
super(ImportProjectCommand, self).__init__(graph, parent) super(ImportProjectCommand, self).__init__(graph, parent)
self.filepath = filepath self.filepath = filepath
self.importedNames = [] self.importedNames = []
@ -240,9 +242,8 @@ class ImportProjectCommand(GraphCommand):
self.yOffset = yOffset self.yOffset = yOffset
def redoImpl(self): def redoImpl(self):
status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True) importedNodes = self.graph.importGraphContentFromFile(self.filepath)
importedNodes = self.graph.importedNodes self.setText(f"Import Project ({len(importedNodes)} nodes)")
self.setText("Import Project ({} nodes)".format(importedNodes.count))
lowestY = 0 lowestY = 0
for node in self.graph.nodes: for node in self.graph.nodes:

152
tests/test_graphIO.py Normal file
View file

@ -0,0 +1,152 @@
from meshroom.core import desc
from meshroom.core.graph import Graph
from .utils import registeredNodeTypes
class SimpleNode(desc.Node):
inputs = [
desc.File(name="input", label="Input", description="", value=""),
]
outputs = [
desc.File(name="output", label="Output", description="", value=""),
]
def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
"""Returns whether the content (node and deges) of two graphs are considered identical.
Similar nodes: nodes with the same name, type and compatibility status.
Similar edges: edges with the same source and destination attribute names.
"""
def _buildNodesSet(graph: Graph):
return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes])
def _buildEdgesSet(graph: Graph):
return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges])
return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet(
graphB
)
class TestImportGraphContent:
def test_importEmptyGraph(self):
graph = Graph("")
otherGraph = Graph("")
nodes = otherGraph.importGraphContent(graph)
assert len(nodes) == 0
assert len(graph.nodes) == 0
def test_importGraphWithSingleNode(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importGraphWithSeveralNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importingGraphWithNodesAndEdges(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_edgeRemappingOnImportingGraphSeveralTimes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
otherGraph.importGraphContent(graph)
def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
importedNode = otherGraph.importGraphContent(graph)
assert len(importedNode) == 1
assert importedNode[0].isCompatibilityNode
def test_importGraphContentInPlace(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.importGraphContent(graph)
assert len(graph.nodes) == 4
def test_importGraphContentFromFile(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert compareGraphsContent(graph, otherGraph)
def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert len(otherGraph.compatibilityNodes) == 2
assert not compareGraphsContent(graph, otherGraph)

View file

@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
assert nodeB.affectedInput.value == 0 assert nodeB.affectedInput.value == 0
class TestAttributeCallbackBehaviorOnGraphImport:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.affectedInput, nodeB.input)
nodeA.input.value = 5
nodeB.affectedInput.value = 2
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert otherGraph.node(nodeB.name).affectedInput.value == 2