Store all duplicates of a node correctly upon their creation

Duplicates used to be stored in a dictionary with an entry being
"parent node": "duplicated node". On occasions where a single
parent node was duplicated more than once, the latest duplicated
 node erased the previous one(s), and these older ones were
"lost": after being created, there was no trace left of their
existence in the duplication operation. Undoing that duplication
operation was thus leaving these duplicated nodes out and not
removing them.

Duplicated nodes are now stored as "parent node": [list of
duplicated nodes] to keep track of all the created nodes,
effectively removing them upon an "undo".
This commit is contained in:
Candice Bentéjac 2022-07-20 16:11:37 +02:00
parent 09fc117c65
commit b77274a027
3 changed files with 17 additions and 12 deletions

View file

@ -382,7 +382,7 @@ class Graph(BaseObject):
node, edges = self.copyNode(srcNode, withEdges=False) node, edges = self.copyNode(srcNode, withEdges=False)
duplicate = self.addNode(node) duplicate = self.addNode(node)
duplicateEdges.update(edges) duplicateEdges.update(edges)
duplicates[srcNode] = duplicate # original node to duplicate map duplicates.setdefault(srcNode, []).append(duplicate)
# re-create edges taking into account what has been duplicated # re-create edges taking into account what has been duplicated
for attr, linkExpression in duplicateEdges.items(): for attr, linkExpression in duplicateEdges.items():
@ -390,8 +390,10 @@ class Graph(BaseObject):
# get source node and attribute name # get source node and attribute name
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1) edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
edgeSrcNode = self.node(edgeSrcNodeName) edgeSrcNode = self.node(edgeSrcNodeName)
# if the edge's source node has been duplicated, use the duplicate; otherwise use the original node # if the edge's source node has been duplicated (the key exists in the dictionary),
edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode) # use the duplicate; otherwise use the original node
if edgeSrcNode in duplicates:
edgeSrcNode = duplicates.get(edgeSrcNode)[0]
self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), attr) self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), attr)
return duplicates return duplicates

View file

@ -184,7 +184,8 @@ class DuplicateNodesCommand(GraphCommand):
def redoImpl(self): def redoImpl(self):
srcNodes = [ self.graph.node(i) for i in self.srcNodeNames ] srcNodes = [ self.graph.node(i) for i in self.srcNodeNames ]
duplicates = list(self.graph.duplicateNodes(srcNodes).values()) # flatten the list of duplicated nodes to avoid lists within the list
duplicates = [ n for nodes in list(self.graph.duplicateNodes(srcNodes).values()) for n in nodes ]
self.duplicates = [ n.name for n in duplicates ] self.duplicates = [ n.name for n in duplicates ]
return duplicates return duplicates

View file

@ -266,14 +266,16 @@ def test_duplicate_nodes():
# duplicate from n1 # duplicate from n1
nodes_to_duplicate, _ = g.dfsOnDiscover(startNodes=[n1], reverse=True, dependenciesOnly=True) nodes_to_duplicate, _ = g.dfsOnDiscover(startNodes=[n1], reverse=True, dependenciesOnly=True)
nMap = g.duplicateNodes(srcNodes=nodes_to_duplicate) nMap = g.duplicateNodes(srcNodes=nodes_to_duplicate)
for s, d in nMap.items(): for s, duplicated in nMap.items():
assert s.nodeType == d.nodeType for d in duplicated:
assert s.nodeType == d.nodeType
# check number of duplicated nodes # check number of duplicated nodes and that every parent node has been duplicated once
assert len(nMap) == 3 assert len(nMap) == 3 and all([len(nMap[i]) == 1 for i in nMap.keys()])
# check connections # check connections
assert nMap[n1].input.getLinkParam() == n0.output # access directly index 0 because we know there is a single duplicate for each parent node
assert nMap[n2].input.getLinkParam() == nMap[n1].output assert nMap[n1][0].input.getLinkParam() == n0.output
assert nMap[n3].input.getLinkParam() == nMap[n1].output assert nMap[n2][0].input.getLinkParam() == nMap[n1][0].output
assert nMap[n3].input2.getLinkParam() == nMap[n2].output assert nMap[n3][0].input.getLinkParam() == nMap[n1][0].output
assert nMap[n3][0].input2.getLinkParam() == nMap[n2][0].output