Merge pull request #1738 from alicevision/fix/undoDuplicatedNodes

Fix node duplication/removal behaviour
This commit is contained in:
Fabien Castan 2022-08-01 15:56:48 +02:00 committed by GitHub
commit dd5aadd875
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 36 additions and 15 deletions

View file

@@ -380,7 +380,7 @@ class Graph(BaseObject):
node, edges = self.copyNode(srcNode, withEdges=False)
duplicate = self.addNode(node)
duplicateEdges.update(edges)
duplicates[srcNode] = duplicate # original node to duplicate map
duplicates.setdefault(srcNode, []).append(duplicate)
# re-create edges taking into account what has been duplicated
for attr, linkExpression in duplicateEdges.items():
@@ -388,8 +388,10 @@ class Graph(BaseObject):
# get source node and attribute name
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
edgeSrcNode = self.node(edgeSrcNodeName)
# if the edge's source node has been duplicated, use the duplicate; otherwise use the original node
edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode)
# if the edge's source node has been duplicated (the key exists in the dictionary),
# use the duplicate; otherwise use the original node
if edgeSrcNode in duplicates:
edgeSrcNode = duplicates.get(edgeSrcNode)[0]
self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), attr)
return duplicates

View file

@@ -184,7 +184,8 @@ class DuplicateNodesCommand(GraphCommand):
def redoImpl(self):
srcNodes = [ self.graph.node(i) for i in self.srcNodeNames ]
duplicates = list(self.graph.duplicateNodes(srcNodes).values())
# flatten the list of duplicated nodes to avoid lists within the list
duplicates = [ n for nodes in list(self.graph.duplicateNodes(srcNodes).values()) for n in nodes ]
self.duplicates = [ n.name for n in duplicates ]
return duplicates

View file

@@ -559,9 +559,11 @@ class UIGraph(QObject):
"""
with self.groupedGraphModification("Remove Nodes From Selected Nodes"):
nodesToRemove, _ = self._graph.dfsOnDiscover(startNodes=nodes, reverse=True, dependenciesOnly=True)
# filter out nodes that will be removed more than once
uniqueNodesToRemove = list(dict.fromkeys(nodesToRemove))
# Perform nodes removal from leaves to start node so that edges
# can be re-created in correct order on redo.
self.removeNodes(list(reversed(nodesToRemove)))
self.removeNodes(list(reversed(uniqueNodesToRemove)))
@Slot(QObject, result="QVariantList")
def duplicateNodes(self, nodes):
@@ -574,6 +576,7 @@ class UIGraph(QObject):
list[Node]: the list of duplicated nodes
"""
nodes = self.filterNodes(nodes)
nPositions = []
# enable updates between duplication and layout to get correct depths during layout
with self.groupedGraphModification("Duplicate Selected Nodes", disableUpdates=False):
# disable graph updates during duplication
@@ -581,8 +584,19 @@ class UIGraph(QObject):
duplicates = self.push(commands.DuplicateNodesCommand(self._graph, nodes))
# move nodes below the bounding box formed by the duplicated node(s)
bbox = self._layout.boundingBox(duplicates)
for n in duplicates:
self.moveNode(n, Position(n.x, bbox[3] + self.layout.gridSpacing + n.y))
idx = duplicates.index(n)
yPos = n.y + self.layout.gridSpacing + bbox[3]
if idx > 0 and (n.x, yPos) in nPositions:
# make sure the node will not be moved on top of another node
while (n.x, yPos) in nPositions:
yPos = yPos + self.layout.gridSpacing + self.layout.nodeHeight
self.moveNode(n, Position(n.x, yPos))
else:
self.moveNode(n, Position(n.x, bbox[3] + self.layout.gridSpacing + n.y))
nPositions.append((n.x, n.y))
return duplicates
@Slot(QObject, result="QVariantList")
@@ -597,7 +611,9 @@ class UIGraph(QObject):
"""
with self.groupedGraphModification("Duplicate Nodes From Selected Nodes"):
nodesToDuplicate, _ = self._graph.dfsOnDiscover(startNodes=nodes, reverse=True, dependenciesOnly=True)
duplicates = self.duplicateNodes(nodesToDuplicate)
# filter out nodes that will be duplicated more than once
uniqueNodesToDuplicate = list(dict.fromkeys(nodesToDuplicate))
duplicates = self.duplicateNodes(uniqueNodesToDuplicate)
return duplicates
@Slot(QObject)

View file

@@ -266,14 +266,16 @@ def test_duplicate_nodes():
# duplicate from n1
nodes_to_duplicate, _ = g.dfsOnDiscover(startNodes=[n1], reverse=True, dependenciesOnly=True)
nMap = g.duplicateNodes(srcNodes=nodes_to_duplicate)
for s, d in nMap.items():
assert s.nodeType == d.nodeType
for s, duplicated in nMap.items():
for d in duplicated:
assert s.nodeType == d.nodeType
# check number of duplicated nodes
assert len(nMap) == 3
# check number of duplicated nodes and that every parent node has been duplicated once
assert len(nMap) == 3 and all([len(nMap[i]) == 1 for i in nMap.keys()])
# check connections
assert nMap[n1].input.getLinkParam() == n0.output
assert nMap[n2].input.getLinkParam() == nMap[n1].output
assert nMap[n3].input.getLinkParam() == nMap[n1].output
assert nMap[n3].input2.getLinkParam() == nMap[n2].output
# access directly index 0 because we know there is a single duplicate for each parent node
assert nMap[n1][0].input.getLinkParam() == n0.output
assert nMap[n2][0].input.getLinkParam() == nMap[n1][0].output
assert nMap[n3][0].input.getLinkParam() == nMap[n1][0].output
assert nMap[n3][0].input2.getLinkParam() == nMap[n2][0].output