From 4aec741a89a50c421b25a11fb748225f187c0d47 Mon Sep 17 00:00:00 2001 From: Yann Lanthony Date: Thu, 6 Feb 2025 16:46:04 +0100 Subject: [PATCH] [core] Graph: add importGraphContent API Extract the logic of importing the content of a graph within a graph instance from the graph loading logic. Add `Graph.importGraphContent` and `Graph.importGraphContentFromFile` methods. Use the deserialization API to load the content in another temporary graph instance, to handle the renaming of nodes using the Graph API, rather than manipulating entries in a raw dictionnary. --- meshroom/core/graph.py | 96 ++++++++++--- meshroom/ui/commands.py | 11 +- tests/test_graphIO.py | 152 +++++++++++++++++++++ tests/test_nodeAttributeChangedCallback.py | 25 ++++ 4 files changed, 258 insertions(+), 26 deletions(-) create mode 100644 tests/test_graphIO.py diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 1f454b20..305b6415 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -243,7 +243,6 @@ class Graph(BaseObject): self._nodes = DictModel(keyAttrName='name', parent=self) # Edges: use dst attribute as unique key since it can only have one input connection self._edges = DictModel(keyAttrName='dst', parent=self) - self._importedNodes = DictModel(keyAttrName='name', parent=self) self._compatibilityNodes = DictModel(keyAttrName='name', parent=self) self.cacheDir = meshroom.core.defaultCacheFolder self._filepath = '' @@ -251,15 +250,17 @@ class Graph(BaseObject): self.header = {} def clear(self): + self._clearGraphContent() self.header.clear() - self._compatibilityNodes.clear() + self._unsetFilepath() + + def _clearGraphContent(self): self._edges.clear() # Tell QML nodes are going to be deleted for node in self._nodes: node.alive = False - self._importedNodes.clear() self._nodes.clear() - self._unsetFilepath() + self._compatibilityNodes.clear() @property def fileFeatures(self): @@ -330,7 +331,7 @@ class Graph(BaseObject): for nodeName, nodeData in sorted( graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0]) ): - self._deserializeNode(nodeData, nodeName) + self._deserializeNode(nodeData, nodeName, self) # Create graph edges by resolving attributes expressions self._applyExpr() @@ -374,14 +375,14 @@ class Graph(BaseObject): return graphContent - def _deserializeNode(self, nodeData: dict, nodeName: str): + def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"): # Retrieve version from # 1. nodeData: node saved from a CompatibilityNode # 2. nodesVersion in file header: node saved from a Node # 3. fallback behavior: default to "0.0" if "version" not in nodeData: - nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") - inTemplate = self.header.get("template", False) + nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0") + inTemplate = fromGraph.header.get("template", False) node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate) self._addNode(node, nodeName) return node @@ -555,6 +556,58 @@ class Graph(BaseObject): return attributes + def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]: + """Import the content (nodes and edges) of another Graph file into this Graph instance. + + Args: + filepath: The path to the Graph file to import. + + Returns: + The list of newly created Nodes. + """ + graph = loadGraph(filepath) + return self.importGraphContent(graph) + + @blockNodeCallbacks + def importGraphContent(self, graph: "Graph") -> list[Node]: + """ + Import the content (node and edges) of another `graph` into this Graph instance. + + Nodes are imported with their original names if possible, otherwise a new unique name is generated + from their node type. + + Args: + graph: The graph to import. + + Returns: + The list of newly created Nodes. + """ + + def _renameClashingNodes(): + if not self.nodes: + return + unavailableNames = set(self.nodes.keys()) + for node in graph.nodes: + if node._name in unavailableNames: + node._name = self._createUniqueNodeName(node.nodeType, unavailableNames) + unavailableNames.add(node._name) + + def _importNodeAndEdges() -> list[Node]: + importedNodes = [] + # If we import the content of the graph within itself, + # iterate over a copy of the nodes as the graph is modified during the iteration. + nodes = graph.nodes if graph is not self else list(graph.nodes) + with GraphModification(self): + for srcNode in nodes: + node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph) + importedNodes.append(node) + self._applyExpr() + return importedNodes + + _renameClashingNodes() + importedNodes = _importNodeAndEdges() + return importedNodes + @property def updateEnabled(self): return self._updateEnabled @@ -766,8 +819,6 @@ class Graph(BaseObject): node.alive = False self._nodes.remove(node) - if node in self._importedNodes: - self._importedNodes.remove(node) self.update() return inEdges, outEdges, outListAttributes @@ -792,13 +843,21 @@ class Graph(BaseObject): n.updateInternals() return n - def _createUniqueNodeName(self, inputName): - i = 1 - while i: - newName = "{name}_{index}".format(name=inputName, index=i) - if newName not in self._nodes.objects: + def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None): + """Create a unique node name based on the input name. + + Args: + inputName: The desired node name. + existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes. + """ + existingNodeNames = existingNames or set(self._nodes.objects.keys()) + + idx = 1 + while idx: + newName = f"{inputName}_{idx}" + if newName not in existingNodeNames: return newName - i += 1 + idx += 1 def node(self, nodeName): return self._nodes.get(nodeName) @@ -1635,11 +1694,6 @@ class Graph(BaseObject): def edges(self): return self._edges - @property - def importedNodes(self): - """" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """ - return self._importedNodes - @property def cacheDir(self): return self._cacheDir diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index 52e3151e..7d8ccc1f 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -6,9 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack from PySide6.QtCore import Property, Signal from meshroom.core.attribute import ListAttribute, Attribute -from meshroom.core.graph import GraphModification +from meshroom.core.graph import Graph, GraphModification from meshroom.core.node import Position from meshroom.core.nodeFactory import nodeFactory +from meshroom.core.typing import PathLike class UndoCommand(QUndoCommand): @@ -232,7 +233,8 @@ class ImportProjectCommand(GraphCommand): """ Handle the import of a project into a Graph. """ - def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None): + + def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None): super(ImportProjectCommand, self).__init__(graph, parent) self.filepath = filepath self.importedNames = [] @@ -240,9 +242,8 @@ class ImportProjectCommand(GraphCommand): self.yOffset = yOffset def redoImpl(self): - status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True) - importedNodes = self.graph.importedNodes - self.setText("Import Project ({} nodes)".format(importedNodes.count)) + importedNodes = self.graph.importGraphContentFromFile(self.filepath) + self.setText(f"Import Project ({len(importedNodes)} nodes)") lowestY = 0 for node in self.graph.nodes: diff --git a/tests/test_graphIO.py b/tests/test_graphIO.py new file mode 100644 index 00000000..01d3a89e --- /dev/null +++ b/tests/test_graphIO.py @@ -0,0 +1,152 @@ +from meshroom.core import desc +from meshroom.core.graph import Graph + +from .utils import registeredNodeTypes + + +class SimpleNode(desc.Node): + inputs = [ + desc.File(name="input", label="Input", description="", value=""), + ] + outputs = [ + desc.File(name="output", label="Output", description="", value=""), + ] + + +def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool: + """Returns whether the content (node and deges) of two graphs are considered identical. + + Similar nodes: nodes with the same name, type and compatibility status. + Similar edges: edges with the same source and destination attribute names. + """ + + def _buildNodesSet(graph: Graph): + return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes]) + + def _buildEdgesSet(graph: Graph): + return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges]) + + return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet( + graphB + ) + + +class TestImportGraphContent: + def test_importEmptyGraph(self): + graph = Graph("") + + otherGraph = Graph("") + nodes = otherGraph.importGraphContent(graph) + + assert len(nodes) == 0 + assert len(graph.nodes) == 0 + + def test_importGraphWithSingleNode(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphWithSeveralNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert compareGraphsContent(graph, otherGraph) + + def test_importingGraphWithNodesAndEdges(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + assert compareGraphsContent(graph, otherGraph) + + def test_edgeRemappingOnImportingGraphSeveralTimes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + otherGraph.importGraphContent(graph) + + def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + graph.addNewNode(SimpleNode.__name__) + + otherGraph = Graph("") + importedNode = otherGraph.importGraphContent(graph) + + assert len(importedNode) == 1 + assert importedNode[0].isCompatibilityNode + + def test_importGraphContentInPlace(self): + graph = Graph("") + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + + graph.importGraphContent(graph) + + assert len(graph.nodes) == 4 + + def test_importGraphContentFromFile(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + + assert compareGraphsContent(graph, otherGraph) + + def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk): + graph: Graph = graphSavedOnDisk + + with registeredNodeTypes([SimpleNode]): + nodeA_1 = graph.addNewNode(SimpleNode.__name__) + nodeA_2 = graph.addNewNode(SimpleNode.__name__) + + graph.addEdge(nodeA_1.output, nodeA_2.input) + graph.save() + + otherGraph = Graph("") + nodes = otherGraph.importGraphContentFromFile(graph.filepath) + + assert len(nodes) == 2 + assert len(otherGraph.compatibilityNodes) == 2 + assert not compareGraphsContent(graph, otherGraph) + + diff --git a/tests/test_nodeAttributeChangedCallback.py b/tests/test_nodeAttributeChangedCallback.py index edd14bc8..faee0e00 100644 --- a/tests/test_nodeAttributeChangedCallback.py +++ b/tests/test_nodeAttributeChangedCallback.py @@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs: assert nodeB.affectedInput.value == 0 +class TestAttributeCallbackBehaviorOnGraphImport: + @classmethod + def setup_class(cls): + registerNodeType(NodeWithAttributeChangedCallback) + + @classmethod + def teardown_class(cls): + unregisterNodeType(NodeWithAttributeChangedCallback) + + def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self): + graph = Graph("") + + nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__) + + graph.addEdge(nodeA.affectedInput, nodeB.input) + + nodeA.input.value = 5 + nodeB.affectedInput.value = 2 + + otherGraph = Graph("") + otherGraph.importGraphContent(graph) + + assert otherGraph.node(nodeB.name).affectedInput.value == 2 +