[core] Graph: add importGraphContent API

Extract the logic of importing the content of a graph within a graph instance from
the graph loading logic.
Add `Graph.importGraphContent` and `Graph.importGraphContentFromFile`
methods.
Use the deserialization API to load the content in another temporary graph instance,
to handle the renaming of nodes using the Graph API, rather than manipulating
entries in a raw dictionnary.
This commit is contained in:
Yann Lanthony 2025-02-06 16:46:04 +01:00
parent 7eab289d30
commit 4aec741a89
4 changed files with 258 additions and 26 deletions

View file

@ -243,7 +243,6 @@ class Graph(BaseObject):
self._nodes = DictModel(keyAttrName='name', parent=self)
# Edges: use dst attribute as unique key since it can only have one input connection
self._edges = DictModel(keyAttrName='dst', parent=self)
self._importedNodes = DictModel(keyAttrName='name', parent=self)
self._compatibilityNodes = DictModel(keyAttrName='name', parent=self)
self.cacheDir = meshroom.core.defaultCacheFolder
self._filepath = ''
@ -251,15 +250,17 @@ class Graph(BaseObject):
self.header = {}
def clear(self):
self._clearGraphContent()
self.header.clear()
self._compatibilityNodes.clear()
self._unsetFilepath()
def _clearGraphContent(self):
self._edges.clear()
# Tell QML nodes are going to be deleted
for node in self._nodes:
node.alive = False
self._importedNodes.clear()
self._nodes.clear()
self._unsetFilepath()
self._compatibilityNodes.clear()
@property
def fileFeatures(self):
@ -330,7 +331,7 @@ class Graph(BaseObject):
for nodeName, nodeData in sorted(
graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])
):
self._deserializeNode(nodeData, nodeName)
self._deserializeNode(nodeData, nodeName, self)
# Create graph edges by resolving attributes expressions
self._applyExpr()
@ -374,14 +375,14 @@ class Graph(BaseObject):
return graphContent
def _deserializeNode(self, nodeData: dict, nodeName: str):
def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"):
# Retrieve version from
# 1. nodeData: node saved from a CompatibilityNode
# 2. nodesVersion in file header: node saved from a Node
# 3. fallback behavior: default to "0.0"
if "version" not in nodeData:
nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0")
inTemplate = self.header.get("template", False)
nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0")
inTemplate = fromGraph.header.get("template", False)
node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate)
self._addNode(node, nodeName)
return node
@ -555,6 +556,58 @@ class Graph(BaseObject):
return attributes
def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]:
"""Import the content (nodes and edges) of another Graph file into this Graph instance.
Args:
filepath: The path to the Graph file to import.
Returns:
The list of newly created Nodes.
"""
graph = loadGraph(filepath)
return self.importGraphContent(graph)
@blockNodeCallbacks
def importGraphContent(self, graph: "Graph") -> list[Node]:
"""
Import the content (node and edges) of another `graph` into this Graph instance.
Nodes are imported with their original names if possible, otherwise a new unique name is generated
from their node type.
Args:
graph: The graph to import.
Returns:
The list of newly created Nodes.
"""
def _renameClashingNodes():
if not self.nodes:
return
unavailableNames = set(self.nodes.keys())
for node in graph.nodes:
if node._name in unavailableNames:
node._name = self._createUniqueNodeName(node.nodeType, unavailableNames)
unavailableNames.add(node._name)
def _importNodeAndEdges() -> list[Node]:
importedNodes = []
# If we import the content of the graph within itself,
# iterate over a copy of the nodes as the graph is modified during the iteration.
nodes = graph.nodes if graph is not self else list(graph.nodes)
with GraphModification(self):
for srcNode in nodes:
node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph)
importedNodes.append(node)
self._applyExpr()
return importedNodes
_renameClashingNodes()
importedNodes = _importNodeAndEdges()
return importedNodes
@property
def updateEnabled(self):
return self._updateEnabled
@ -766,8 +819,6 @@ class Graph(BaseObject):
node.alive = False
self._nodes.remove(node)
if node in self._importedNodes:
self._importedNodes.remove(node)
self.update()
return inEdges, outEdges, outListAttributes
@ -792,13 +843,21 @@ class Graph(BaseObject):
n.updateInternals()
return n
def _createUniqueNodeName(self, inputName):
i = 1
while i:
newName = "{name}_{index}".format(name=inputName, index=i)
if newName not in self._nodes.objects:
def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None):
"""Create a unique node name based on the input name.
Args:
inputName: The desired node name.
existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes.
"""
existingNodeNames = existingNames or set(self._nodes.objects.keys())
idx = 1
while idx:
newName = f"{inputName}_{idx}"
if newName not in existingNodeNames:
return newName
i += 1
idx += 1
def node(self, nodeName):
return self._nodes.get(nodeName)
@ -1635,11 +1694,6 @@ class Graph(BaseObject):
def edges(self):
return self._edges
@property
def importedNodes(self):
"""" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """
return self._importedNodes
@property
def cacheDir(self):
return self._cacheDir

View file

@ -6,9 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack
from PySide6.QtCore import Property, Signal
from meshroom.core.attribute import ListAttribute, Attribute
from meshroom.core.graph import GraphModification
from meshroom.core.graph import Graph, GraphModification
from meshroom.core.node import Position
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
class UndoCommand(QUndoCommand):
@ -232,7 +233,8 @@ class ImportProjectCommand(GraphCommand):
"""
Handle the import of a project into a Graph.
"""
def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None):
def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None):
super(ImportProjectCommand, self).__init__(graph, parent)
self.filepath = filepath
self.importedNames = []
@ -240,9 +242,8 @@ class ImportProjectCommand(GraphCommand):
self.yOffset = yOffset
def redoImpl(self):
status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True)
importedNodes = self.graph.importedNodes
self.setText("Import Project ({} nodes)".format(importedNodes.count))
importedNodes = self.graph.importGraphContentFromFile(self.filepath)
self.setText(f"Import Project ({len(importedNodes)} nodes)")
lowestY = 0
for node in self.graph.nodes:

152
tests/test_graphIO.py Normal file
View file

@ -0,0 +1,152 @@
from meshroom.core import desc
from meshroom.core.graph import Graph
from .utils import registeredNodeTypes
class SimpleNode(desc.Node):
inputs = [
desc.File(name="input", label="Input", description="", value=""),
]
outputs = [
desc.File(name="output", label="Output", description="", value=""),
]
def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
"""Returns whether the content (node and deges) of two graphs are considered identical.
Similar nodes: nodes with the same name, type and compatibility status.
Similar edges: edges with the same source and destination attribute names.
"""
def _buildNodesSet(graph: Graph):
return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes])
def _buildEdgesSet(graph: Graph):
return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges])
return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet(
graphB
)
class TestImportGraphContent:
def test_importEmptyGraph(self):
graph = Graph("")
otherGraph = Graph("")
nodes = otherGraph.importGraphContent(graph)
assert len(nodes) == 0
assert len(graph.nodes) == 0
def test_importGraphWithSingleNode(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importGraphWithSeveralNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importingGraphWithNodesAndEdges(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_edgeRemappingOnImportingGraphSeveralTimes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
otherGraph.importGraphContent(graph)
def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
importedNode = otherGraph.importGraphContent(graph)
assert len(importedNode) == 1
assert importedNode[0].isCompatibilityNode
def test_importGraphContentInPlace(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.importGraphContent(graph)
assert len(graph.nodes) == 4
def test_importGraphContentFromFile(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert compareGraphsContent(graph, otherGraph)
def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert len(otherGraph.compatibilityNodes) == 2
assert not compareGraphsContent(graph, otherGraph)

View file

@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
assert nodeB.affectedInput.value == 0
class TestAttributeCallbackBehaviorOnGraphImport:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.affectedInput, nodeB.input)
nodeA.input.value = 5
nodeB.affectedInput.value = 2
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert otherGraph.node(nodeB.name).affectedInput.value == 2