[core] Graph: add replaceNode method

Factorize the logic of replacing a node with another one and re-creating
output edges into `Graph.replaceNode`  and `Graph._restoreOutEdges`.
This commit is contained in:
Yann Lanthony 2025-02-06 16:46:04 +01:00
parent 1cf0fc95ba
commit 45ef4b592d
3 changed files with 50 additions and 57 deletions

View file

@ -649,7 +649,7 @@ class Graph(BaseObject):
def node(self, nodeName):
return self._nodes.get(nodeName)
def upgradeNode(self, nodeName):
def upgradeNode(self, nodeName) -> Node:
"""
Upgrade the CompatibilityNode identified as 'nodeName'
Args:
@ -669,25 +669,49 @@ class Graph(BaseObject):
if not isinstance(node, CompatibilityNode):
raise ValueError("Upgrade is only available on CompatibilityNode instances.")
upgradedNode = node.upgrade()
with GraphModification(self):
inEdges, outEdges, outListAttributes = self.removeNode(nodeName)
self.addNode(upgradedNode, nodeName)
for dst, src in outEdges.items():
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
# If they are not re-created first, adding their edges will lead to errors
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dst in outListAttributes.keys():
listAttr = self.attribute(outListAttributes[dst][0])
if isinstance(outListAttributes[dst][2], list):
listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2]
else:
listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2])
try:
self.addEdge(self.attribute(src), self.attribute(dst))
except (KeyError, ValueError) as e:
logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e)))
self.replaceNode(nodeName, upgradedNode)
return upgradedNode
return upgradedNode, inEdges, outEdges, outListAttributes
@changeTopology
def replaceNode(self, nodeName: str, newNode: BaseNode):
"""Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges.
Args:
nodeName: The name of the Node to replace.
newNode: The Node instance to replace it with.
"""
with GraphModification(self):
_, outEdges, outListAttributes = self.removeNode(nodeName)
self.addNode(newNode, nodeName)
self._restoreOutEdges(outEdges, outListAttributes)
def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes):
"""Restore output edges that were removed during a call to "removeNode".
Args:
outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode".
{dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()}
outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected
to a ListAttribute prior to the removal of all edges.
{dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)}
"""
def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any):
listAttr = self.attribute(listAttrName)
if not isinstance(listAttr, ListAttribute):
return
if isinstance(value, list):
listAttr[index:index] = value
else:
listAttr.insert(index, value)
for dstName, srcName in outEdges.items():
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
if dstName in outListAttributes:
_recreateTargetListAttributeChildren(*outListAttributes[dstName])
try:
self.addEdge(self.attribute(srcName), self.attribute(dstName))
except (KeyError, ValueError) as e:
logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}")
def upgradeAllNodes(self):
""" Upgrade all upgradable CompatibilityNode instances in the graph. """

View file

@ -170,19 +170,7 @@ class RemoveNodeCommand(GraphCommand):
node = nodeFactory(self.nodeDict, self.nodeName)
self.graph.addNode(node, self.nodeName)
assert (node.getName() == self.nodeName)
# recreate out edges deleted on node removal
for dstAttr, srcAttr in self.outEdges.items():
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dstAttr in self.outListAttributes.keys():
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
if isinstance(self.outListAttributes[dstAttr][2], list):
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
else:
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
self.graph.addEdge(self.graph.attribute(srcAttr),
self.graph.attribute(dstAttr))
self.graph._restoreOutEdges(self.outEdges, self.outListAttributes)
class DuplicateNodesCommand(GraphCommand):
@ -451,38 +439,19 @@ class UpgradeNodeCommand(GraphCommand):
super(UpgradeNodeCommand, self).__init__(graph, parent)
self.nodeDict = node.toDict()
self.nodeName = node.getName()
self.outEdges = {}
self.outListAttributes = {}
self.setText("Upgrade Node {}".format(self.nodeName))
def redoImpl(self):
if not self.graph.node(self.nodeName).canUpgrade:
return False
upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName)
return upgradedNode
return self.graph.upgradeNode(self.nodeName)
def undoImpl(self):
# delete upgraded node
expectedUid = self.graph.node(self.nodeName)._uid
self.graph.removeNode(self.nodeName)
# recreate compatibility node
with GraphModification(self.graph):
# We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before
node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid)
self.graph.addNode(node, self.nodeName)
# recreate out edges
for dstAttr, srcAttr in self.outEdges.items():
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dstAttr in self.outListAttributes.keys():
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
if isinstance(self.outListAttributes[dstAttr][2], list):
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
else:
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
self.graph.addEdge(self.graph.attribute(srcAttr),
self.graph.attribute(dstAttr))
self.graph.replaceNode(self.nodeName, node)
class EnableGraphUpdateCommand(GraphCommand):

View file

@ -255,7 +255,7 @@ def test_description_conflict():
assert not hasattr(compatNode, "in")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2)
assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"]
@ -270,7 +270,7 @@ def test_description_conflict():
assert hasattr(compatNode, "paramA")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3)
assert not hasattr(upgradedNode, "paramA")
@ -283,7 +283,7 @@ def test_description_conflict():
assert not hasattr(compatNode, "paramA")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4)
assert hasattr(upgradedNode, "paramA")
@ -303,7 +303,7 @@ def test_description_conflict():
assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__)
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5)
assert hasattr(upgradedNode, "paramA")