mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-04-29 10:17:27 +02:00
[core] move duplicateNode methods to core.graph
handle this low-level operation engine side * fix ListAttribute children links duplication * handle CompatibilityNode duplication * move corresponding unit test in test_graph.py * [ui] add DuplicateNodeCommand
This commit is contained in:
parent
1af3a16d81
commit
2952e11691
7 changed files with 164 additions and 103 deletions
|
@ -13,7 +13,7 @@ from enum import Enum
|
|||
import meshroom
|
||||
import meshroom.core
|
||||
from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
|
||||
from meshroom.core.attribute import Attribute
|
||||
from meshroom.core.attribute import Attribute, ListAttribute
|
||||
from meshroom.core.exception import UnknownNodeTypeError
|
||||
from meshroom.core.node import node_factory, Status, Node, CompatibilityNode
|
||||
|
||||
|
@ -264,6 +264,86 @@ class Graph(BaseObject):
|
|||
node._applyExpr()
|
||||
return node
|
||||
|
||||
def copyNode(self, srcNode, withEdges=False):
|
||||
"""
|
||||
Get a copy instance of a node outside the graph.
|
||||
|
||||
Args:
|
||||
srcNode (Node): the node to copy
|
||||
withEdges (bool): whether to copy edges
|
||||
|
||||
Returns:
|
||||
Node, dict: the created node instance,
|
||||
a dictionary of linked attributes with their original value (empty if withEdges is True)
|
||||
"""
|
||||
with GraphModification(self):
|
||||
# create a new node of the same type and with the same attributes values
|
||||
# keep links as-is so that CompatibilityNodes attributes can be created with correct automatic description
|
||||
# (File params for link expressions)
|
||||
node = node_factory(srcNode.toDict())
|
||||
# skip edges: filter out attributes which are links by resetting default values
|
||||
skippedEdges = {}
|
||||
if not withEdges:
|
||||
for n, attr in node.attributes.items():
|
||||
# find top-level links
|
||||
if Attribute.isLinkExpression(attr.value):
|
||||
skippedEdges[attr] = attr.value
|
||||
attr.resetValue()
|
||||
# find links in ListAttribute children
|
||||
elif isinstance(attr, ListAttribute):
|
||||
for child in attr.value:
|
||||
if Attribute.isLinkExpression(child.value):
|
||||
skippedEdges[child] = child.value
|
||||
child.resetValue()
|
||||
return node, skippedEdges
|
||||
|
||||
def duplicateNode(self, srcNode):
|
||||
""" Duplicate a node in the graph with its connections.
|
||||
|
||||
Args:
|
||||
srcNode: the node to duplicate
|
||||
|
||||
Returns:
|
||||
Node: the created node
|
||||
"""
|
||||
node, edges = self.copyNode(srcNode, withEdges=True)
|
||||
return self.addNode(node)
|
||||
|
||||
def duplicateNodesFromNode(self, fromNode):
|
||||
"""
|
||||
Duplicate 'fromNode' and all the following nodes towards graph's leaves.
|
||||
|
||||
Args:
|
||||
fromNode (Node): the node to start the duplication from
|
||||
|
||||
Returns:
|
||||
Dict[Node, Node]: the source->duplicate map
|
||||
"""
|
||||
srcNodes, srcEdges = self.nodesFromNode(fromNode)
|
||||
duplicates = {}
|
||||
|
||||
with GraphModification(self):
|
||||
duplicateEdges = {}
|
||||
# first, duplicate all nodes without edges and keep a 'source=>duplicate' map
|
||||
# keeps tracks of non-created edges for later remap
|
||||
for srcNode in srcNodes:
|
||||
node, edges = self.copyNode(srcNode, withEdges=False)
|
||||
duplicate = self.addNode(node)
|
||||
duplicateEdges.update(edges)
|
||||
duplicates[srcNode] = duplicate # original node to duplicate map
|
||||
|
||||
# re-create edges taking into account what has been duplicated
|
||||
for attr, linkExpression in duplicateEdges.items():
|
||||
link = linkExpression[1:-1] # remove starting '{' and trailing '}'
|
||||
# get source node and attribute name
|
||||
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
|
||||
edgeSrcNode = self.node(edgeSrcNodeName)
|
||||
# if the edge's source node has been duplicated, use the duplicate; otherwise use the original node
|
||||
edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode)
|
||||
self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), attr)
|
||||
|
||||
return duplicates
|
||||
|
||||
def outEdges(self, attribute):
|
||||
""" Return the list of edges starting from the given attribute """
|
||||
# type: (Attribute,) -> [Edge]
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# coding:utf-8
|
||||
import atexit
|
||||
import copy
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
|
@ -677,7 +678,9 @@ class CompatibilityNode(BaseNode):
|
|||
super(CompatibilityNode, self).__init__(nodeType, parent)
|
||||
|
||||
self.issue = issue
|
||||
self.nodeDict = nodeDict
|
||||
# make a deepcopy of nodeDict to handle CompatibilityNode duplication
|
||||
# and be able to change modified inputs (see CompatibilityNode.toDict)
|
||||
self.nodeDict = copy.deepcopy(nodeDict)
|
||||
|
||||
self.inputs = nodeDict.get("inputs", {})
|
||||
self.outputs = nodeDict.get("outputs", {})
|
||||
|
@ -804,10 +807,21 @@ class CompatibilityNode(BaseNode):
|
|||
self._attributes.add(attribute)
|
||||
return matchDesc
|
||||
|
||||
@property
|
||||
def actualInputs(self):
|
||||
""" Get actual node inputs, where links could differ from original serialized node data
|
||||
(i.e after node duplication) """
|
||||
return {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
|
||||
|
||||
def toDict(self):
|
||||
"""
|
||||
Return the original serialized node that generated a compatibility issue.
|
||||
|
||||
Serialized inputs are updated to handle instances that have been duplicated
|
||||
and might be connected to different nodes.
|
||||
"""
|
||||
# update inputs to get up-to-date connections
|
||||
self.nodeDict.update({"inputs": self.actualInputs})
|
||||
return self.nodeDict
|
||||
|
||||
@property
|
||||
|
@ -823,7 +837,8 @@ class CompatibilityNode(BaseNode):
|
|||
if not self.canUpgrade:
|
||||
raise NodeUpgradeError(self.name, "no matching node type")
|
||||
# TODO: use upgrade method of node description if available
|
||||
return Node(self.nodeType, **{key: value for key, value in self.inputs.items() if key in self._commonInputs})
|
||||
return Node(self.nodeType, **{key: value for key, value in self.actualInputs.items()
|
||||
if key in self._commonInputs})
|
||||
|
||||
|
||||
def node_factory(nodeDict, name=None):
|
||||
|
|
|
@ -135,6 +135,35 @@ class RemoveNodeCommand(GraphCommand):
|
|||
self.graph.attribute(dstAttr))
|
||||
|
||||
|
||||
class DuplicateNodeCommand(GraphCommand):
|
||||
"""
|
||||
Handle node duplication in a Graph.
|
||||
"""
|
||||
def __init__(self, graph, srcNode, duplicateFollowingNodes, parent=None):
|
||||
super(DuplicateNodeCommand, self).__init__(graph, parent)
|
||||
self.srcNodeName = srcNode.name
|
||||
self.duplicateFollowingNodes = duplicateFollowingNodes
|
||||
self.duplicates = []
|
||||
|
||||
def redoImpl(self):
|
||||
srcNode = self.graph.node(self.srcNodeName)
|
||||
|
||||
if self.duplicateFollowingNodes:
|
||||
duplicates = self.graph.duplicateNodesFromNode(srcNode)
|
||||
self.duplicates = [n.name for n in duplicates.values()]
|
||||
self.setText("Duplicate {} nodes from {}".format(len(duplicates), self.srcNodeName))
|
||||
else:
|
||||
self.duplicates = [self.graph.duplicateNode(srcNode).name]
|
||||
self.setText("Duplicate {}".format(self.srcNodeName))
|
||||
|
||||
return self.duplicates
|
||||
|
||||
def undoImpl(self):
|
||||
# delete all the duplicated nodes
|
||||
for nodeName in self.duplicates:
|
||||
self.graph.removeNode(nodeName)
|
||||
|
||||
|
||||
class SetAttributeCommand(GraphCommand):
|
||||
def __init__(self, graph, attribute, value, parent=None):
|
||||
super(SetAttributeCommand, self).__init__(graph, parent)
|
||||
|
|
|
@ -283,75 +283,19 @@ class UIGraph(QObject):
|
|||
""" Reset 'attribute' to its default value """
|
||||
self.push(commands.SetAttributeCommand(self._graph, attribute, attribute.defaultValue()))
|
||||
|
||||
@Slot(Node)
|
||||
def duplicateNode(self, srcNode, createEdges=True):
|
||||
@Slot(Node, bool, result="QVariantList")
|
||||
def duplicateNode(self, srcNode, duplicateFollowingNodes=False):
|
||||
"""
|
||||
Duplicate 'srcNode'.
|
||||
Duplicate a node an optionally all the following nodes to graph leaves.
|
||||
|
||||
Args:
|
||||
srcNode (Node): the node to duplicate
|
||||
createEdges (bool): whether to replicate 'srcNode' edges on the duplicated node
|
||||
|
||||
Returns:
|
||||
Node: the duplicated node
|
||||
"""
|
||||
serialized = srcNode.toDict()
|
||||
with self.groupedGraphModification("Duplicate Node {}".format(srcNode.name)):
|
||||
# skip edges: filter out attributes which are links
|
||||
if not createEdges:
|
||||
serialized["inputs"] = {k: v for k, v in serialized["inputs"].items() if not Attribute.isLinkExpression(v)}
|
||||
# create a new node of the same type and with the same attributes values
|
||||
node = self.addNewNode(serialized["nodeType"], **serialized["inputs"])
|
||||
return node
|
||||
|
||||
def duplicateNodesFromNode(self, fromNode):
|
||||
"""
|
||||
Duplicate 'fromNode' and all the following nodes towards graph's leaves.
|
||||
|
||||
Args:
|
||||
fromNode (Node): the node to start the duplication from
|
||||
|
||||
Returns:
|
||||
{Nodes: Node}: the source->duplicate nodes map
|
||||
"""
|
||||
srcNodes, srcEdges = self._graph.nodesFromNode(fromNode)
|
||||
duplicates = {}
|
||||
|
||||
with self.groupedGraphModification("Duplicate {} Nodes".format(len(srcNodes))):
|
||||
# duplicate all nodes without edges and keep a 'source=>duplicate' map
|
||||
for srcNode in srcNodes:
|
||||
duplicate = self.duplicateNode(srcNode, createEdges=False)
|
||||
duplicates[srcNode] = duplicate # original node to duplicate map
|
||||
|
||||
# re-create edges taking into account what has been duplicated
|
||||
for srcNode, duplicate in duplicates.items():
|
||||
# get link attributes
|
||||
links = {k: v for k, v in srcNode.toDict()["inputs"].items() if Attribute.isLinkExpression(v)}
|
||||
for attr, link in links.items():
|
||||
link = link[1:-1] # remove starting '{' and trailing '}'
|
||||
# get source node and attribute name
|
||||
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
|
||||
edgeSrcNode = self._graph.node(edgeSrcNodeName)
|
||||
# if the edge's source node has been duplicated, use the duplicate; otherwise use the original node
|
||||
edgeSrcNode = duplicates.get(edgeSrcNode, edgeSrcNode)
|
||||
self.addEdge(edgeSrcNode.attribute(edgeSrcAttrName), duplicate.attribute(attr))
|
||||
|
||||
return duplicates
|
||||
|
||||
@Slot(Node, result="QVariantList")
|
||||
def duplicateNodes(self, fromNode):
|
||||
"""
|
||||
Slot accessor to 'duplicateNodesFromNode'. Returns the list of created nodes, usable from QML.
|
||||
|
||||
Args:
|
||||
fromNode (Node): node to start the duplication from
|
||||
|
||||
See Also: duplicateNodesFromNode
|
||||
srcNode (Node): node to start the duplication from
|
||||
duplicateFollowingNodes (bool): whether to duplicate all the following nodes to graph leaves
|
||||
|
||||
Returns:
|
||||
[Nodes]: the list of duplicated nodes
|
||||
"""
|
||||
return self.duplicateNodesFromNode(fromNode).values()
|
||||
return self.push(commands.DuplicateNodeCommand(self._graph, srcNode, duplicateFollowingNodes))
|
||||
|
||||
@Slot(CompatibilityNode)
|
||||
def upgradeNode(self, node):
|
||||
|
|
|
@ -240,7 +240,7 @@ Item {
|
|||
}
|
||||
|
||||
function duplicate(duplicateFollowingNodes) {
|
||||
var nodes = duplicateFollowingNodes ? uigraph.duplicateNodes(node) : [uigraph.duplicateNode(node)]
|
||||
var nodes = uigraph.duplicateNode(node, duplicateFollowingNodes)
|
||||
var delegates = []
|
||||
var from = nodeRepeater.count - nodes.length
|
||||
var to = nodeRepeater.count - 1
|
||||
|
|
|
@ -204,3 +204,33 @@ def test_graph_nodes_sorting():
|
|||
ls1 = graph.addNewNode('Ls', name='Ls_1')
|
||||
|
||||
assert graph.nodesByType('Ls', sortedByIndex=True) == [ls0, ls1, ls2]
|
||||
|
||||
|
||||
def test_duplicate_nodes():
|
||||
"""
|
||||
Test nodes duplication.
|
||||
"""
|
||||
|
||||
# n0 -- n1 -- n2
|
||||
# \ \
|
||||
# ---------- n3
|
||||
|
||||
g = Graph('')
|
||||
n0 = g.addNewNode('Ls', input='/tmp')
|
||||
n1 = g.addNewNode('Ls', input=n0.output)
|
||||
n2 = g.addNewNode('Ls', input=n1.output)
|
||||
n3 = g.addNewNode('AppendFiles', input=n1.output, input2=n2.output)
|
||||
|
||||
# duplicate from n1
|
||||
nMap = g.duplicateNodesFromNode(fromNode=n1)
|
||||
for s, d in nMap.items():
|
||||
assert s.nodeType == d.nodeType
|
||||
|
||||
# check number of duplicated nodes
|
||||
assert len(nMap) == 3
|
||||
|
||||
# check connections
|
||||
assert nMap[n1].input.getLinkParam() == n0.output
|
||||
assert nMap[n2].input.getLinkParam() == nMap[n1].output
|
||||
assert nMap[n3].input.getLinkParam() == nMap[n1].output
|
||||
assert nMap[n3].input2.getLinkParam() == nMap[n2].output
|
||||
|
|
|
@ -1,37 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# coding:utf-8
|
||||
|
||||
from meshroom.ui.graph import UIGraph
|
||||
|
||||
|
||||
def test_duplicate_nodes():
|
||||
"""
|
||||
Test nodes duplication.
|
||||
"""
|
||||
|
||||
# n0 -- n1 -- n2
|
||||
# \ \
|
||||
# ---------- n3
|
||||
|
||||
g = UIGraph()
|
||||
n0 = g.addNewNode('Ls', input='/tmp')
|
||||
n1 = g.addNewNode('Ls', input=n0.output)
|
||||
n2 = g.addNewNode('Ls', input=n1.output)
|
||||
n3 = g.addNewNode('AppendFiles', input=n1.output, input2=n2.output)
|
||||
|
||||
# duplicate from n1
|
||||
nMap = g.duplicateNodesFromNode(fromNode=n1)
|
||||
for s, d in nMap.items():
|
||||
assert s.nodeType == d.nodeType
|
||||
|
||||
# check number of duplicated nodes
|
||||
assert len(nMap) == 3
|
||||
|
||||
# check connections
|
||||
assert nMap[n1].input.getLinkParam() == n0.output
|
||||
assert nMap[n2].input.getLinkParam() == nMap[n1].output
|
||||
assert nMap[n3].input.getLinkParam() == nMap[n1].output
|
||||
assert nMap[n3].input2.getLinkParam() == nMap[n2].output
|
||||
|
||||
# ensure de-allocation order for un-parented UIGraph (QObject) with no QApplication instance
|
||||
g.deleteLater()
|
Loading…
Add table
Reference in a new issue