mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-04-28 09:47:20 +02:00
[core] Graph: add importGraphContent API
Extract the logic of importing the content of a graph within a graph instance from the graph loading logic. Add `Graph.importGraphContent` and `Graph.importGraphContentFromFile` methods. Use the deserialization API to load the content in another temporary graph instance, to handle the renaming of nodes using the Graph API, rather than manipulating entries in a raw dictionnary.
This commit is contained in:
parent
7eab289d30
commit
4aec741a89
4 changed files with 258 additions and 26 deletions
|
@ -243,7 +243,6 @@ class Graph(BaseObject):
|
|||
self._nodes = DictModel(keyAttrName='name', parent=self)
|
||||
# Edges: use dst attribute as unique key since it can only have one input connection
|
||||
self._edges = DictModel(keyAttrName='dst', parent=self)
|
||||
self._importedNodes = DictModel(keyAttrName='name', parent=self)
|
||||
self._compatibilityNodes = DictModel(keyAttrName='name', parent=self)
|
||||
self.cacheDir = meshroom.core.defaultCacheFolder
|
||||
self._filepath = ''
|
||||
|
@ -251,15 +250,17 @@ class Graph(BaseObject):
|
|||
self.header = {}
|
||||
|
||||
def clear(self):
|
||||
self._clearGraphContent()
|
||||
self.header.clear()
|
||||
self._compatibilityNodes.clear()
|
||||
self._unsetFilepath()
|
||||
|
||||
def _clearGraphContent(self):
|
||||
self._edges.clear()
|
||||
# Tell QML nodes are going to be deleted
|
||||
for node in self._nodes:
|
||||
node.alive = False
|
||||
self._importedNodes.clear()
|
||||
self._nodes.clear()
|
||||
self._unsetFilepath()
|
||||
self._compatibilityNodes.clear()
|
||||
|
||||
@property
|
||||
def fileFeatures(self):
|
||||
|
@ -330,7 +331,7 @@ class Graph(BaseObject):
|
|||
for nodeName, nodeData in sorted(
|
||||
graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])
|
||||
):
|
||||
self._deserializeNode(nodeData, nodeName)
|
||||
self._deserializeNode(nodeData, nodeName, self)
|
||||
|
||||
# Create graph edges by resolving attributes expressions
|
||||
self._applyExpr()
|
||||
|
@ -374,14 +375,14 @@ class Graph(BaseObject):
|
|||
|
||||
return graphContent
|
||||
|
||||
def _deserializeNode(self, nodeData: dict, nodeName: str):
|
||||
def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"):
|
||||
# Retrieve version from
|
||||
# 1. nodeData: node saved from a CompatibilityNode
|
||||
# 2. nodesVersion in file header: node saved from a Node
|
||||
# 3. fallback behavior: default to "0.0"
|
||||
if "version" not in nodeData:
|
||||
nodeData["version"] = self._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0")
|
||||
inTemplate = self.header.get("template", False)
|
||||
nodeData["version"] = fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"], "0.0")
|
||||
inTemplate = fromGraph.header.get("template", False)
|
||||
node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate)
|
||||
self._addNode(node, nodeName)
|
||||
return node
|
||||
|
@ -555,6 +556,58 @@ class Graph(BaseObject):
|
|||
|
||||
return attributes
|
||||
|
||||
def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]:
|
||||
"""Import the content (nodes and edges) of another Graph file into this Graph instance.
|
||||
|
||||
Args:
|
||||
filepath: The path to the Graph file to import.
|
||||
|
||||
Returns:
|
||||
The list of newly created Nodes.
|
||||
"""
|
||||
graph = loadGraph(filepath)
|
||||
return self.importGraphContent(graph)
|
||||
|
||||
@blockNodeCallbacks
|
||||
def importGraphContent(self, graph: "Graph") -> list[Node]:
|
||||
"""
|
||||
Import the content (node and edges) of another `graph` into this Graph instance.
|
||||
|
||||
Nodes are imported with their original names if possible, otherwise a new unique name is generated
|
||||
from their node type.
|
||||
|
||||
Args:
|
||||
graph: The graph to import.
|
||||
|
||||
Returns:
|
||||
The list of newly created Nodes.
|
||||
"""
|
||||
|
||||
def _renameClashingNodes():
|
||||
if not self.nodes:
|
||||
return
|
||||
unavailableNames = set(self.nodes.keys())
|
||||
for node in graph.nodes:
|
||||
if node._name in unavailableNames:
|
||||
node._name = self._createUniqueNodeName(node.nodeType, unavailableNames)
|
||||
unavailableNames.add(node._name)
|
||||
|
||||
def _importNodeAndEdges() -> list[Node]:
|
||||
importedNodes = []
|
||||
# If we import the content of the graph within itself,
|
||||
# iterate over a copy of the nodes as the graph is modified during the iteration.
|
||||
nodes = graph.nodes if graph is not self else list(graph.nodes)
|
||||
with GraphModification(self):
|
||||
for srcNode in nodes:
|
||||
node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph)
|
||||
importedNodes.append(node)
|
||||
self._applyExpr()
|
||||
return importedNodes
|
||||
|
||||
_renameClashingNodes()
|
||||
importedNodes = _importNodeAndEdges()
|
||||
return importedNodes
|
||||
|
||||
@property
|
||||
def updateEnabled(self):
|
||||
return self._updateEnabled
|
||||
|
@ -766,8 +819,6 @@ class Graph(BaseObject):
|
|||
|
||||
node.alive = False
|
||||
self._nodes.remove(node)
|
||||
if node in self._importedNodes:
|
||||
self._importedNodes.remove(node)
|
||||
self.update()
|
||||
|
||||
return inEdges, outEdges, outListAttributes
|
||||
|
@ -792,13 +843,21 @@ class Graph(BaseObject):
|
|||
n.updateInternals()
|
||||
return n
|
||||
|
||||
def _createUniqueNodeName(self, inputName):
|
||||
i = 1
|
||||
while i:
|
||||
newName = "{name}_{index}".format(name=inputName, index=i)
|
||||
if newName not in self._nodes.objects:
|
||||
def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None):
|
||||
"""Create a unique node name based on the input name.
|
||||
|
||||
Args:
|
||||
inputName: The desired node name.
|
||||
existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes.
|
||||
"""
|
||||
existingNodeNames = existingNames or set(self._nodes.objects.keys())
|
||||
|
||||
idx = 1
|
||||
while idx:
|
||||
newName = f"{inputName}_{idx}"
|
||||
if newName not in existingNodeNames:
|
||||
return newName
|
||||
i += 1
|
||||
idx += 1
|
||||
|
||||
def node(self, nodeName):
|
||||
return self._nodes.get(nodeName)
|
||||
|
@ -1635,11 +1694,6 @@ class Graph(BaseObject):
|
|||
def edges(self):
|
||||
return self._edges
|
||||
|
||||
@property
|
||||
def importedNodes(self):
|
||||
"""" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """
|
||||
return self._importedNodes
|
||||
|
||||
@property
|
||||
def cacheDir(self):
|
||||
return self._cacheDir
|
||||
|
|
|
@ -6,9 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack
|
|||
from PySide6.QtCore import Property, Signal
|
||||
|
||||
from meshroom.core.attribute import ListAttribute, Attribute
|
||||
from meshroom.core.graph import GraphModification
|
||||
from meshroom.core.graph import Graph, GraphModification
|
||||
from meshroom.core.node import Position
|
||||
from meshroom.core.nodeFactory import nodeFactory
|
||||
from meshroom.core.typing import PathLike
|
||||
|
||||
|
||||
class UndoCommand(QUndoCommand):
|
||||
|
@ -232,7 +233,8 @@ class ImportProjectCommand(GraphCommand):
|
|||
"""
|
||||
Handle the import of a project into a Graph.
|
||||
"""
|
||||
def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None):
|
||||
|
||||
def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None):
|
||||
super(ImportProjectCommand, self).__init__(graph, parent)
|
||||
self.filepath = filepath
|
||||
self.importedNames = []
|
||||
|
@ -240,9 +242,8 @@ class ImportProjectCommand(GraphCommand):
|
|||
self.yOffset = yOffset
|
||||
|
||||
def redoImpl(self):
|
||||
status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True)
|
||||
importedNodes = self.graph.importedNodes
|
||||
self.setText("Import Project ({} nodes)".format(importedNodes.count))
|
||||
importedNodes = self.graph.importGraphContentFromFile(self.filepath)
|
||||
self.setText(f"Import Project ({len(importedNodes)} nodes)")
|
||||
|
||||
lowestY = 0
|
||||
for node in self.graph.nodes:
|
||||
|
|
152
tests/test_graphIO.py
Normal file
152
tests/test_graphIO.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
from meshroom.core import desc
|
||||
from meshroom.core.graph import Graph
|
||||
|
||||
from .utils import registeredNodeTypes
|
||||
|
||||
|
||||
class SimpleNode(desc.Node):
|
||||
inputs = [
|
||||
desc.File(name="input", label="Input", description="", value=""),
|
||||
]
|
||||
outputs = [
|
||||
desc.File(name="output", label="Output", description="", value=""),
|
||||
]
|
||||
|
||||
|
||||
def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
|
||||
"""Returns whether the content (node and deges) of two graphs are considered identical.
|
||||
|
||||
Similar nodes: nodes with the same name, type and compatibility status.
|
||||
Similar edges: edges with the same source and destination attribute names.
|
||||
"""
|
||||
|
||||
def _buildNodesSet(graph: Graph):
|
||||
return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes])
|
||||
|
||||
def _buildEdgesSet(graph: Graph):
|
||||
return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges])
|
||||
|
||||
return _buildNodesSet(graphA) == _buildNodesSet(graphB) and _buildEdgesSet(graphA) == _buildEdgesSet(
|
||||
graphB
|
||||
)
|
||||
|
||||
|
||||
class TestImportGraphContent:
|
||||
def test_importEmptyGraph(self):
|
||||
graph = Graph("")
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContent(graph)
|
||||
|
||||
assert len(nodes) == 0
|
||||
assert len(graph.nodes) == 0
|
||||
|
||||
def test_importGraphWithSingleNode(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importGraphWithSeveralNodes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importingGraphWithNodesAndEdges(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_edgeRemappingOnImportingGraphSeveralTimes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
importedNode = otherGraph.importGraphContent(graph)
|
||||
|
||||
assert len(importedNode) == 1
|
||||
assert importedNode[0].isCompatibilityNode
|
||||
|
||||
def test_importGraphContentInPlace(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
graph.importGraphContent(graph)
|
||||
|
||||
assert len(graph.nodes) == 4
|
||||
|
||||
def test_importGraphContentFromFile(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
graph.save()
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
|
||||
|
||||
assert len(nodes) == 2
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
graph.save()
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert len(otherGraph.compatibilityNodes) == 2
|
||||
assert not compareGraphsContent(graph, otherGraph)
|
||||
|
||||
|
|
@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
|
|||
assert nodeB.affectedInput.value == 0
|
||||
|
||||
|
||||
class TestAttributeCallbackBehaviorOnGraphImport:
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
registerNodeType(NodeWithAttributeChangedCallback)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
unregisterNodeType(NodeWithAttributeChangedCallback)
|
||||
|
||||
def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self):
|
||||
graph = Graph("")
|
||||
|
||||
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
|
||||
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
|
||||
|
||||
graph.addEdge(nodeA.affectedInput, nodeB.input)
|
||||
|
||||
nodeA.input.value = 5
|
||||
nodeB.affectedInput.value = 2
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert otherGraph.node(nodeB.name).affectedInput.value == 2
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue