diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index 5cbc1a5b..0d12b597 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -13,7 +13,7 @@ from enum import Enum import meshroom import meshroom.core from meshroom.common import BaseObject, DictModel, Slot, Signal, Property -from meshroom.core.attribute import Attribute +from meshroom.core.attribute import Attribute, ListAttribute from meshroom.core.exception import UnknownNodeTypeError from meshroom.core.node import node_factory, Status, Node, CompatibilityNode @@ -264,6 +264,86 @@ class Graph(BaseObject): node._applyExpr() return node + def copyNode(self, srcNode, withEdges=False): + """ + Get a copy instance of a node outside the graph. + + Args: + srcNode (Node): the node to copy + withEdges (bool): whether to copy edges + + Returns: + Node, dict: the created node instance, + a dictionary of linked attributes with their original value (empty if withEdges is True) + """ + with GraphModification(self): + # create a new node of the same type and with the same attributes values + # keep links as-is so that CompatibilityNodes attributes can be created with correct automatic description + # (File params for link expressions) + node = node_factory(srcNode.toDict()) + # skip edges: filter out attributes which are links by resetting default values + skippedEdges = {} + if not withEdges: + for n, attr in node.attributes.items(): + # find top-level links + if Attribute.isLinkExpression(attr.value): + skippedEdges[attr] = attr.value + attr.resetValue() + # find links in ListAttribute children + elif isinstance(attr, ListAttribute): + for child in attr.value: + if Attribute.isLinkExpression(child.value): + skippedEdges[child] = child.value + child.resetValue() + return node, skippedEdges + + def duplicateNode(self, srcNode): + """ Duplicate a node in the graph with its connections. + + Args: + srcNode: the node to duplicate + + Returns: + Node: the created node + """ + node, edges = self.copyNode(srcNode, withEdges=True) + return self.addNode(node) + + def duplicateNodesFromNode(self, fromNode): + """ + Duplicate 'fromNode' and all the following nodes towards graph's leaves. + + Args: + fromNode (Node): the node to start the duplication from + + Returns: + Dict[Node, Node]: the source->duplicate map + """ + srcNodes, srcEdges = self.nodesFromNode(fromNode) + duplicates = {} + + with GraphModification(self): + duplicateEdges = {} + # first, duplicate all nodes without edges and keep a 'source=>duplicate' map + # keeps tracks of non-created edges for later remap + for srcNode in srcNodes: + node, edges = self.copyNode(srcNode, withEdges=False) + duplicate = self.addNode(node) + duplicateEdges.update(edges) + duplicates[srcNode] = duplicate # original node to duplicate map + + # re-create edges taking into account what has been duplicated + for attr, linkExpression in duplicateEdges.items(): + link = linkExpression[1:-1] # remove starting '{' and trailing '}' + # get source node and attribute name + edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1) + edgeSrcNode = self.node(edgeSrcNodeName) + # if the edge's source node has been duplicated, use the duplicate; otherwise use the original node + edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode) + self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), attr) + + return duplicates + def outEdges(self, attribute): """ Return the list of edges starting from the given attribute """ # type: (Attribute,) -> [Edge] diff --git a/meshroom/core/node.py b/meshroom/core/node.py index 8a012e13..a0d12f94 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -1,6 +1,7 @@ #!/usr/bin/env python # coding:utf-8 import atexit +import copy import datetime import json import logging @@ -677,7 +678,9 @@ class CompatibilityNode(BaseNode): super(CompatibilityNode, self).__init__(nodeType, parent) self.issue = issue - self.nodeDict = nodeDict + # make a deepcopy of nodeDict to handle CompatibilityNode duplication + # and be able to change modified inputs (see CompatibilityNode.toDict) + self.nodeDict = copy.deepcopy(nodeDict) self.inputs = nodeDict.get("inputs", {}) self.outputs = nodeDict.get("outputs", {}) @@ -804,10 +807,21 @@ class CompatibilityNode(BaseNode): self._attributes.add(attribute) return matchDesc + @property + def actualInputs(self): + """ Get actual node inputs, where links could differ from original serialized node data + (i.e after node duplication) """ + return {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput} + def toDict(self): """ Return the original serialized node that generated a compatibility issue. + + Serialized inputs are updated to handle instances that have been duplicated + and might be connected to different nodes. """ + # update inputs to get up-to-date connections + self.nodeDict.update({"inputs": self.actualInputs}) return self.nodeDict @property @@ -823,7 +837,8 @@ class CompatibilityNode(BaseNode): if not self.canUpgrade: raise NodeUpgradeError(self.name, "no matching node type") # TODO: use upgrade method of node description if available - return Node(self.nodeType, **{key: value for key, value in self.inputs.items() if key in self._commonInputs}) + return Node(self.nodeType, **{key: value for key, value in self.actualInputs.items() + if key in self._commonInputs}) def node_factory(nodeDict, name=None): diff --git a/meshroom/ui/commands.py b/meshroom/ui/commands.py index be76ff8e..a0a4e453 100755 --- a/meshroom/ui/commands.py +++ b/meshroom/ui/commands.py @@ -135,6 +135,35 @@ class RemoveNodeCommand(GraphCommand): self.graph.attribute(dstAttr)) +class DuplicateNodeCommand(GraphCommand): + """ + Handle node duplication in a Graph. + """ + def __init__(self, graph, srcNode, duplicateFollowingNodes, parent=None): + super(DuplicateNodeCommand, self).__init__(graph, parent) + self.srcNodeName = srcNode.name + self.duplicateFollowingNodes = duplicateFollowingNodes + self.duplicates = [] + + def redoImpl(self): + srcNode = self.graph.node(self.srcNodeName) + + if self.duplicateFollowingNodes: + duplicates = self.graph.duplicateNodesFromNode(srcNode) + self.duplicates = [n.name for n in duplicates.values()] + self.setText("Duplicate {} nodes from {}".format(len(duplicates), self.srcNodeName)) + else: + self.duplicates = [self.graph.duplicateNode(srcNode).name] + self.setText("Duplicate {}".format(self.srcNodeName)) + + return self.duplicates + + def undoImpl(self): + # delete all the duplicated nodes + for nodeName in self.duplicates: + self.graph.removeNode(nodeName) + + class SetAttributeCommand(GraphCommand): def __init__(self, graph, attribute, value, parent=None): super(SetAttributeCommand, self).__init__(graph, parent) diff --git a/meshroom/ui/graph.py b/meshroom/ui/graph.py index f6f52bf4..998e17b2 100644 --- a/meshroom/ui/graph.py +++ b/meshroom/ui/graph.py @@ -283,75 +283,19 @@ class UIGraph(QObject): """ Reset 'attribute' to its default value """ self.push(commands.SetAttributeCommand(self._graph, attribute, attribute.defaultValue())) - @Slot(Node) - def duplicateNode(self, srcNode, createEdges=True): + @Slot(Node, bool, result="QVariantList") + def duplicateNode(self, srcNode, duplicateFollowingNodes=False): """ - Duplicate 'srcNode'. + Duplicate a node an optionally all the following nodes to graph leaves. Args: - srcNode (Node): the node to duplicate - createEdges (bool): whether to replicate 'srcNode' edges on the duplicated node - - Returns: - Node: the duplicated node - """ - serialized = srcNode.toDict() - with self.groupedGraphModification("Duplicate Node {}".format(srcNode.name)): - # skip edges: filter out attributes which are links - if not createEdges: - serialized["inputs"] = {k: v for k, v in serialized["inputs"].items() if not Attribute.isLinkExpression(v)} - # create a new node of the same type and with the same attributes values - node = self.addNewNode(serialized["nodeType"], **serialized["inputs"]) - return node - - def duplicateNodesFromNode(self, fromNode): - """ - Duplicate 'fromNode' and all the following nodes towards graph's leaves. - - Args: - fromNode (Node): the node to start the duplication from - - Returns: - {Nodes: Node}: the source->duplicate nodes map - """ - srcNodes, srcEdges = self._graph.nodesFromNode(fromNode) - duplicates = {} - - with self.groupedGraphModification("Duplicate {} Nodes".format(len(srcNodes))): - # duplicate all nodes without edges and keep a 'source=>duplicate' map - for srcNode in srcNodes: - duplicate = self.duplicateNode(srcNode, createEdges=False) - duplicates[srcNode] = duplicate # original node to duplicate map - - # re-create edges taking into account what has been duplicated - for srcNode, duplicate in duplicates.items(): - # get link attributes - links = {k: v for k, v in srcNode.toDict()["inputs"].items() if Attribute.isLinkExpression(v)} - for attr, link in links.items(): - link = link[1:-1] # remove starting '{' and trailing '}' - # get source node and attribute name - edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1) - edgeSrcNode = self._graph.node(edgeSrcNodeName) - # if the edge's source node has been duplicated, use the duplicate; otherwise use the original node - edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode) - self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), duplicate.attribute(attr)) - - return duplicates - - @Slot(Node, result="QVariantList") - def duplicateNodes(self, fromNode): - """ - Slot accessor to 'duplicateNodesFromNode'. Returns the list of created nodes, usable from QML. - - Args: - fromNode (Node): node to start the duplication from - - See Also: duplicateNodesFromNode + srcNode (Node): node to start the duplication from + duplicateFollowingNodes (bool): whether to duplicate all the following nodes to graph leaves Returns: [Nodes]: the list of duplicated nodes """ - return self.duplicateNodesFromNode(fromNode).values() + return self.push(commands.DuplicateNodeCommand(self._graph, srcNode, duplicateFollowingNodes)) @Slot(CompatibilityNode) def upgradeNode(self, node): diff --git a/meshroom/ui/qml/GraphEditor/GraphEditor.qml b/meshroom/ui/qml/GraphEditor/GraphEditor.qml index 5347dfb9..d841ba38 100755 --- a/meshroom/ui/qml/GraphEditor/GraphEditor.qml +++ b/meshroom/ui/qml/GraphEditor/GraphEditor.qml @@ -240,7 +240,7 @@ Item { } function duplicate(duplicateFollowingNodes) { - var nodes = duplicateFollowingNodes ? uigraph.duplicateNodes(node) : [uigraph.duplicateNode(node)] + var nodes = uigraph.duplicateNode(node, duplicateFollowingNodes) var delegates = [] var from = nodeRepeater.count - nodes.length var to = nodeRepeater.count - 1 diff --git a/tests/test_graph.py b/tests/test_graph.py index 62c5d2e1..8c0c5e41 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -204,3 +204,33 @@ def test_graph_nodes_sorting(): ls1 = graph.addNewNode('Ls', name='Ls_1') assert graph.nodesByType('Ls', sortedByIndex=True) == [ls0, ls1, ls2] + + +def test_duplicate_nodes(): + """ + Test nodes duplication. + """ + + # n0 -- n1 -- n2 + # \ \ + # ---------- n3 + + g = Graph('') + n0 = g.addNewNode('Ls', input='/tmp') + n1 = g.addNewNode('Ls', input=n0.output) + n2 = g.addNewNode('Ls', input=n1.output) + n3 = g.addNewNode('AppendFiles', input=n1.output, input2=n2.output) + + # duplicate from n1 + nMap = g.duplicateNodesFromNode(fromNode=n1) + for s, d in nMap.items(): + assert s.nodeType == d.nodeType + + # check number of duplicated nodes + assert len(nMap) == 3 + + # check connections + assert nMap[n1].input.getLinkParam() == n0.output + assert nMap[n2].input.getLinkParam() == nMap[n1].output + assert nMap[n3].input.getLinkParam() == nMap[n1].output + assert nMap[n3].input2.getLinkParam() == nMap[n2].output diff --git a/tests/test_ui_graph.py b/tests/test_ui_graph.py deleted file mode 100644 index 64062dc4..00000000 --- a/tests/test_ui_graph.py +++ /dev/null @@ -1,37 +0,0 @@ -#!/usr/bin/env python -# coding:utf-8 - -from meshroom.ui.graph import UIGraph - - -def test_duplicate_nodes(): - """ - Test nodes duplication. - """ - - # n0 -- n1 -- n2 - # \ \ - # ---------- n3 - - g = UIGraph() - n0 = g.addNewNode('Ls', input='/tmp') - n1 = g.addNewNode('Ls', input=n0.output) - n2 = g.addNewNode('Ls', input=n1.output) - n3 = g.addNewNode('AppendFiles', input=n1.output, input2=n2.output) - - # duplicate from n1 - nMap = g.duplicateNodesFromNode(fromNode=n1) - for s, d in nMap.items(): - assert s.nodeType == d.nodeType - - # check number of duplicated nodes - assert len(nMap) == 3 - - # check connections - assert nMap[n1].input.getLinkParam() == n0.output - assert nMap[n2].input.getLinkParam() == nMap[n1].output - assert nMap[n3].input.getLinkParam() == nMap[n1].output - assert nMap[n3].input2.getLinkParam() == nMap[n2].output - - # ensure de-allocation order for un-parented UIGraph (QObject) with no QApplication instance - g.deleteLater()