mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-06-06 12:51:57 +02:00
[core] Graph: add replaceNode
method
Factorize the logic of replacing a node with another one and re-creating output edges into `Graph.replaceNode` and `Graph._restoreOutEdges`.
This commit is contained in:
parent
1cf0fc95ba
commit
45ef4b592d
3 changed files with 50 additions and 57 deletions
|
@ -649,7 +649,7 @@ class Graph(BaseObject):
|
|||
def node(self, nodeName):
|
||||
return self._nodes.get(nodeName)
|
||||
|
||||
def upgradeNode(self, nodeName):
|
||||
def upgradeNode(self, nodeName) -> Node:
|
||||
"""
|
||||
Upgrade the CompatibilityNode identified as 'nodeName'
|
||||
Args:
|
||||
|
@ -669,25 +669,49 @@ class Graph(BaseObject):
|
|||
if not isinstance(node, CompatibilityNode):
|
||||
raise ValueError("Upgrade is only available on CompatibilityNode instances.")
|
||||
upgradedNode = node.upgrade()
|
||||
with GraphModification(self):
|
||||
inEdges, outEdges, outListAttributes = self.removeNode(nodeName)
|
||||
self.addNode(upgradedNode, nodeName)
|
||||
for dst, src in outEdges.items():
|
||||
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
|
||||
# If they are not re-created first, adding their edges will lead to errors
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dst in outListAttributes.keys():
|
||||
listAttr = self.attribute(outListAttributes[dst][0])
|
||||
if isinstance(outListAttributes[dst][2], list):
|
||||
listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2]
|
||||
else:
|
||||
listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2])
|
||||
try:
|
||||
self.addEdge(self.attribute(src), self.attribute(dst))
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e)))
|
||||
self.replaceNode(nodeName, upgradedNode)
|
||||
return upgradedNode
|
||||
|
||||
return upgradedNode, inEdges, outEdges, outListAttributes
|
||||
@changeTopology
|
||||
def replaceNode(self, nodeName: str, newNode: BaseNode):
|
||||
"""Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges.
|
||||
|
||||
Args:
|
||||
nodeName: The name of the Node to replace.
|
||||
newNode: The Node instance to replace it with.
|
||||
"""
|
||||
with GraphModification(self):
|
||||
_, outEdges, outListAttributes = self.removeNode(nodeName)
|
||||
self.addNode(newNode, nodeName)
|
||||
self._restoreOutEdges(outEdges, outListAttributes)
|
||||
|
||||
def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes):
|
||||
"""Restore output edges that were removed during a call to "removeNode".
|
||||
|
||||
Args:
|
||||
outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode".
|
||||
{dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()}
|
||||
outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected
|
||||
to a ListAttribute prior to the removal of all edges.
|
||||
{dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)}
|
||||
"""
|
||||
def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any):
|
||||
listAttr = self.attribute(listAttrName)
|
||||
if not isinstance(listAttr, ListAttribute):
|
||||
return
|
||||
if isinstance(value, list):
|
||||
listAttr[index:index] = value
|
||||
else:
|
||||
listAttr.insert(index, value)
|
||||
|
||||
for dstName, srcName in outEdges.items():
|
||||
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
|
||||
if dstName in outListAttributes:
|
||||
_recreateTargetListAttributeChildren(*outListAttributes[dstName])
|
||||
try:
|
||||
self.addEdge(self.attribute(srcName), self.attribute(dstName))
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}")
|
||||
|
||||
def upgradeAllNodes(self):
|
||||
""" Upgrade all upgradable CompatibilityNode instances in the graph. """
|
||||
|
|
|
@ -170,19 +170,7 @@ class RemoveNodeCommand(GraphCommand):
|
|||
node = nodeFactory(self.nodeDict, self.nodeName)
|
||||
self.graph.addNode(node, self.nodeName)
|
||||
assert (node.getName() == self.nodeName)
|
||||
# recreate out edges deleted on node removal
|
||||
for dstAttr, srcAttr in self.outEdges.items():
|
||||
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dstAttr in self.outListAttributes.keys():
|
||||
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
|
||||
if isinstance(self.outListAttributes[dstAttr][2], list):
|
||||
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
|
||||
else:
|
||||
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
|
||||
|
||||
self.graph.addEdge(self.graph.attribute(srcAttr),
|
||||
self.graph.attribute(dstAttr))
|
||||
self.graph._restoreOutEdges(self.outEdges, self.outListAttributes)
|
||||
|
||||
|
||||
class DuplicateNodesCommand(GraphCommand):
|
||||
|
@ -451,38 +439,19 @@ class UpgradeNodeCommand(GraphCommand):
|
|||
super(UpgradeNodeCommand, self).__init__(graph, parent)
|
||||
self.nodeDict = node.toDict()
|
||||
self.nodeName = node.getName()
|
||||
self.outEdges = {}
|
||||
self.outListAttributes = {}
|
||||
self.setText("Upgrade Node {}".format(self.nodeName))
|
||||
|
||||
def redoImpl(self):
|
||||
if not self.graph.node(self.nodeName).canUpgrade:
|
||||
return False
|
||||
upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName)
|
||||
return upgradedNode
|
||||
return self.graph.upgradeNode(self.nodeName)
|
||||
|
||||
def undoImpl(self):
|
||||
# delete upgraded node
|
||||
expectedUid = self.graph.node(self.nodeName)._uid
|
||||
self.graph.removeNode(self.nodeName)
|
||||
# recreate compatibility node
|
||||
with GraphModification(self.graph):
|
||||
# We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before
|
||||
node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid)
|
||||
self.graph.addNode(node, self.nodeName)
|
||||
# recreate out edges
|
||||
for dstAttr, srcAttr in self.outEdges.items():
|
||||
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dstAttr in self.outListAttributes.keys():
|
||||
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
|
||||
if isinstance(self.outListAttributes[dstAttr][2], list):
|
||||
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
|
||||
else:
|
||||
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
|
||||
|
||||
self.graph.addEdge(self.graph.attribute(srcAttr),
|
||||
self.graph.attribute(dstAttr))
|
||||
self.graph.replaceNode(self.nodeName, node)
|
||||
|
||||
|
||||
class EnableGraphUpdateCommand(GraphCommand):
|
||||
|
|
|
@ -255,7 +255,7 @@ def test_description_conflict():
|
|||
assert not hasattr(compatNode, "in")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2)
|
||||
|
||||
assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"]
|
||||
|
@ -270,7 +270,7 @@ def test_description_conflict():
|
|||
assert hasattr(compatNode, "paramA")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3)
|
||||
|
||||
assert not hasattr(upgradedNode, "paramA")
|
||||
|
@ -283,7 +283,7 @@ def test_description_conflict():
|
|||
assert not hasattr(compatNode, "paramA")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4)
|
||||
|
||||
assert hasattr(upgradedNode, "paramA")
|
||||
|
@ -303,7 +303,7 @@ def test_description_conflict():
|
|||
assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__)
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5)
|
||||
|
||||
assert hasattr(upgradedNode, "paramA")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue