mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-06-06 21:01:59 +02:00
[core] New dynamic output attributes
This commit is contained in:
parent
9a09310f07
commit
d5e356c0aa
6 changed files with 139 additions and 14 deletions
|
@ -128,6 +128,8 @@ def validateNodeDesc(nodeDesc):
|
|||
errors.append(err)
|
||||
|
||||
for param in nodeDesc.outputs:
|
||||
if param.value is None:
|
||||
continue
|
||||
err = param.checkValueTypes()
|
||||
if err:
|
||||
errors.append(err)
|
||||
|
|
|
@ -32,6 +32,7 @@ class Attribute(BaseObject):
|
|||
self._errorMessage = errorMessage
|
||||
self._visible = visible
|
||||
self._isExpression = (isinstance(self._value, str) and "{" in self._value) or isinstance(self._value, types.FunctionType)
|
||||
self._isDynamicValue = (self._value is None)
|
||||
self._valueType = None
|
||||
|
||||
name = Property(str, lambda self: self._name, constant=True)
|
||||
|
@ -39,9 +40,13 @@ class Attribute(BaseObject):
|
|||
description = Property(str, lambda self: self._description, constant=True)
|
||||
value = Property(Variant, lambda self: self._value, constant=True)
|
||||
# isExpression:
|
||||
# The value of the attribute's descriptor is a static string expression that should be evaluated at runtime.
|
||||
# The default value of the attribute's descriptor is a static string expression that should be evaluated at runtime.
|
||||
# This property only makes sense for output attributes.
|
||||
isExpression = Property(bool, lambda self: self._isExpression, constant=True)
|
||||
# isDynamicValue
|
||||
# The default value of the attribute's descriptor is None, so it's not an input value,
|
||||
# but an output value that is computed during the Node's process execution.
|
||||
isDynamicValue = Property(bool, lambda self: self._isDynamicValue, constant=True)
|
||||
uid = Property(Variant, lambda self: self._uid, constant=True)
|
||||
group = Property(str, lambda self: self._group, constant=True)
|
||||
advanced = Property(bool, lambda self: self._advanced, constant=True)
|
||||
|
@ -99,6 +104,8 @@ class ListAttribute(Attribute):
|
|||
joinChar = Property(str, lambda self: self._joinChar, constant=True)
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if JSValue is not None and isinstance(value, JSValue):
|
||||
# Note: we could use isArray(), property("length").toInt() to retrieve all values
|
||||
raise ValueError("ListAttribute.validateValue: cannot recognize QJSValue. Please, use JSON.stringify(value) in QML.")
|
||||
|
@ -138,6 +145,8 @@ class GroupAttribute(Attribute):
|
|||
groupDesc = Property(Variant, lambda self: self._groupDesc, constant=True)
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
""" Ensure value is compatible with the group description and convert value if needed. """
|
||||
if JSValue is not None and isinstance(value, JSValue):
|
||||
# Note: we could use isArray(), property("length").toInt() to retrieve all values
|
||||
|
@ -232,6 +241,8 @@ class File(Attribute):
|
|||
self._valueType = str
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise ValueError('File only supports string input (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
|
||||
return os.path.normpath(value).replace('\\', '/') if value else ''
|
||||
|
@ -252,6 +263,8 @@ class BoolParam(Param):
|
|||
self._valueType = bool
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
if isinstance(value, str):
|
||||
# use distutils.util.strtobool to handle (1/0, true/false, on/off, y/n)
|
||||
|
@ -276,6 +289,8 @@ class IntParam(Param):
|
|||
self._valueType = int
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
# handle unsigned int values that are translated to int by shiboken and may overflow
|
||||
try:
|
||||
return int(value)
|
||||
|
@ -300,6 +315,8 @@ class FloatParam(Param):
|
|||
self._valueType = float
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
try:
|
||||
return float(value)
|
||||
except:
|
||||
|
@ -320,7 +337,7 @@ class PushButtonParam(Param):
|
|||
self._valueType = None
|
||||
|
||||
def validateValue(self, value):
|
||||
pass
|
||||
return value
|
||||
def checkValueTypes(self):
|
||||
pass
|
||||
|
||||
|
@ -354,6 +371,8 @@ class ChoiceParam(Param):
|
|||
return self._valueType(value)
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if self.exclusive:
|
||||
return self.conformValue(value)
|
||||
|
||||
|
@ -383,6 +402,8 @@ class StringParam(Param):
|
|||
self._valueType = str
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str):
|
||||
raise ValueError('StringParam value should be a string (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
|
||||
return value
|
||||
|
@ -401,6 +422,8 @@ class ColorParam(Param):
|
|||
self._valueType = str
|
||||
|
||||
def validateValue(self, value):
|
||||
if value is None:
|
||||
return value
|
||||
if not isinstance(value, str) or len(value.split(" ")) > 1:
|
||||
raise ValueError('ColorParam value should be a string containing either an SVG name or an hexadecimal '
|
||||
'color code (param: {}, value: {}, type: {})'.format(self.name, value, type(value)))
|
||||
|
@ -594,7 +617,8 @@ class Node(object):
|
|||
category = 'Other'
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
super(Node, self).__init__()
|
||||
self.hasDynamicOutputAttribute = any(output.isDynamicValue for output in self.outputs)
|
||||
|
||||
def upgradeAttributeValues(self, attrValues, fromVersion):
|
||||
return attrValues
|
||||
|
@ -630,6 +654,9 @@ class InputNode(Node):
|
|||
"""
|
||||
Node that does not need to be processed, it is just a placeholder for inputs.
|
||||
"""
|
||||
def __init__(self):
|
||||
super(InputNode, self).__init__()
|
||||
|
||||
def processChunk(self, chunk):
|
||||
pass
|
||||
|
||||
|
@ -641,6 +668,9 @@ class CommandLineNode(Node):
|
|||
parallelization = None
|
||||
commandLineRange = ''
|
||||
|
||||
def __init__(self):
|
||||
super(CommandLineNode, self).__init__()
|
||||
|
||||
def buildCommandLine(self, chunk):
|
||||
|
||||
cmdPrefix = ''
|
||||
|
@ -708,6 +738,7 @@ class AVCommandLineNode(CommandLineNode):
|
|||
cmdCore = ''
|
||||
|
||||
def __init__(self):
|
||||
super(AVCommandLineNode, self).__init__()
|
||||
|
||||
if AVCommandLineNode.cgroupParsed is False:
|
||||
|
||||
|
@ -730,9 +761,9 @@ class AVCommandLineNode(CommandLineNode):
|
|||
return commandLineString + AVCommandLineNode.cmdMem + AVCommandLineNode.cmdCore
|
||||
|
||||
# Test abstract node
|
||||
class InitNode:
|
||||
class InitNode(object):
|
||||
def __init__(self):
|
||||
pass
|
||||
super(InitNode, self).__init__()
|
||||
|
||||
def initialize(self, node, inputs, recursiveInputs):
|
||||
"""
|
||||
|
|
|
@ -14,7 +14,7 @@ import meshroom
|
|||
import meshroom.core
|
||||
from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
|
||||
from meshroom.core import Version
|
||||
from meshroom.core.attribute import Attribute, ListAttribute
|
||||
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
|
||||
from meshroom.core.exception import StopGraphVisit, StopBranchVisit
|
||||
from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode
|
||||
|
||||
|
@ -320,7 +320,10 @@ class Graph(BaseObject):
|
|||
# that were computed.
|
||||
if not isTemplate: # UIDs are not stored in templates
|
||||
self._evaluateUidConflicts(graphData)
|
||||
try:
|
||||
self._applyExpr()
|
||||
except Exception as e:
|
||||
logging.warning(e)
|
||||
|
||||
return True
|
||||
|
||||
|
@ -548,12 +551,16 @@ class Graph(BaseObject):
|
|||
skippedEdges = {}
|
||||
if not withEdges:
|
||||
for n, attr in node.attributes.items():
|
||||
if attr.isOutput:
|
||||
# edges are declared in input with an expression linking
|
||||
# to another param (which could be an output)
|
||||
continue
|
||||
# find top-level links
|
||||
if Attribute.isLinkExpression(attr.value):
|
||||
skippedEdges[attr] = attr.value
|
||||
attr.resetToDefaultValue()
|
||||
# find links in ListAttribute children
|
||||
elif isinstance(attr, ListAttribute):
|
||||
elif isinstance(attr, (ListAttribute, GroupAttribute)):
|
||||
for child in attr.value:
|
||||
if Attribute.isLinkExpression(child.value):
|
||||
skippedEdges[child] = child.value
|
||||
|
@ -584,6 +591,7 @@ class Graph(BaseObject):
|
|||
|
||||
# re-create edges taking into account what has been duplicated
|
||||
for attr, linkExpression in duplicateEdges.items():
|
||||
logging.warning("attr={} linkExpression={}".format(attr.fullName, linkExpression))
|
||||
link = linkExpression[1:-1] # remove starting '{' and trailing '}'
|
||||
# get source node and attribute name
|
||||
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
|
||||
|
@ -1625,6 +1633,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
|
|||
|
||||
for n, node in enumerate(nodes):
|
||||
try:
|
||||
node.preprocess()
|
||||
multiChunks = len(node.chunks) > 1
|
||||
for c, chunk in enumerate(node.chunks):
|
||||
if multiChunks:
|
||||
|
@ -1635,6 +1644,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
|
|||
print('\n[{node}/{nbNodes}] {nodeName}'.format(
|
||||
node=n + 1, nbNodes=len(nodes), nodeName=node.nodeType))
|
||||
chunk.process(forceCompute)
|
||||
node.postprocess()
|
||||
except Exception as e:
|
||||
logging.error("Error on node computation: {}".format(e))
|
||||
graph.clearSubmittedNodes()
|
||||
|
|
|
@ -686,6 +686,10 @@ class BaseNode(BaseObject):
|
|||
def minDepth(self):
|
||||
return self.graph.getDepth(self, minimal=True)
|
||||
|
||||
@property
|
||||
def valuesFile(self):
|
||||
return os.path.join(self.graph.cacheDir, self.internalFolder, 'values')
|
||||
|
||||
def getInputNodes(self, recursive, dependenciesOnly):
|
||||
return self.graph.getInputNodes(self, recursive=recursive, dependenciesOnly=dependenciesOnly)
|
||||
|
||||
|
@ -955,6 +959,7 @@ class BaseNode(BaseObject):
|
|||
}
|
||||
self._computeUids()
|
||||
self._buildCmdVars()
|
||||
self.updateOutputAttr()
|
||||
if self.nodeDesc:
|
||||
self.nodeDesc.postUpdate(self)
|
||||
# Notify internal folder change if needed
|
||||
|
@ -972,8 +977,11 @@ class BaseNode(BaseObject):
|
|||
"""
|
||||
Update node status based on status file content/existence.
|
||||
"""
|
||||
s = self.globalStatus
|
||||
for chunk in self._chunks:
|
||||
chunk.updateStatusFromCache()
|
||||
# logging.warning("updateStatusFromCache: {}, status: {} => {}".format(self.name, s, self.globalStatus))
|
||||
self.updateOutputAttr()
|
||||
|
||||
def submit(self, forceCompute=False):
|
||||
for chunk in self._chunks:
|
||||
|
@ -988,10 +996,71 @@ class BaseNode(BaseObject):
|
|||
def processIteration(self, iteration):
|
||||
self._chunks[iteration].process()
|
||||
|
||||
def preprocess(self):
|
||||
pass
|
||||
|
||||
def process(self, forceCompute=False):
|
||||
for chunk in self._chunks:
|
||||
chunk.process(forceCompute)
|
||||
|
||||
def postprocess(self):
|
||||
self.saveOutputAttr()
|
||||
|
||||
def updateOutputAttr(self):
|
||||
if not self.nodeDesc:
|
||||
return
|
||||
if not self.nodeDesc.hasDynamicOutputAttribute:
|
||||
return
|
||||
# logging.warning("updateOutputAttr: {}, status: {}".format(self.name, self.globalStatus))
|
||||
if self.getGlobalStatus() == Status.SUCCESS:
|
||||
self.loadOutputAttr()
|
||||
else:
|
||||
self.resetOutputAttr()
|
||||
|
||||
def resetOutputAttr(self):
|
||||
if not self.nodeDesc.hasDynamicOutputAttribute:
|
||||
return
|
||||
# logging.warning("resetOutputAttr: {}".format(self.name))
|
||||
for output in self.nodeDesc.outputs:
|
||||
if output.isDynamicValue:
|
||||
self.attribute(output.name).value = None
|
||||
|
||||
def loadOutputAttr(self):
|
||||
""" Load output attributes with dynamic values from a values.json file.
|
||||
"""
|
||||
if not self.nodeDesc.hasDynamicOutputAttribute:
|
||||
return
|
||||
valuesFile = self.valuesFile
|
||||
if not os.path.exists(valuesFile):
|
||||
logging.warning("No output attr file: {}".format(valuesFile))
|
||||
return
|
||||
|
||||
# logging.warning("load output attr: {}, value: {}".format(self.name, valuesFile))
|
||||
with open(valuesFile, 'r') as jsonFile:
|
||||
data = json.load(jsonFile)
|
||||
|
||||
# logging.warning(data)
|
||||
for output in self.nodeDesc.outputs:
|
||||
if output.isDynamicValue:
|
||||
self.attribute(output.name).value = data[output.name]
|
||||
|
||||
def saveOutputAttr(self):
|
||||
""" Save output attributes with dynamic values into a values.json file.
|
||||
"""
|
||||
if not self.nodeDesc.hasDynamicOutputAttribute:
|
||||
return
|
||||
data = {}
|
||||
for output in self.nodeDesc.outputs:
|
||||
if output.isDynamicValue:
|
||||
data[output.name] = self.attribute(output.name).value
|
||||
|
||||
valuesFile = self.valuesFile
|
||||
# logging.warning("save output attr: {}, value: {}".format(self.name, valuesFile))
|
||||
valuesFilepathWriting = getWritingFilepath(valuesFile)
|
||||
with open(valuesFilepathWriting, 'w') as jsonFile:
|
||||
json.dump(data, jsonFile, indent=4)
|
||||
renameWritingToFinalPath(valuesFilepathWriting, valuesFile)
|
||||
|
||||
def endSequence(self):
|
||||
pass
|
||||
|
||||
|
@ -1227,6 +1296,7 @@ class BaseNode(BaseObject):
|
|||
comment = Property(str, getComment, notify=internalAttributesChanged)
|
||||
internalFolderChanged = Signal()
|
||||
internalFolder = Property(str, internalFolder.fget, notify=internalFolderChanged)
|
||||
valuesFile = Property(str, valuesFile.fget, notify=internalFolderChanged)
|
||||
depthChanged = Signal()
|
||||
depth = Property(int, depth.fget, notify=depthChanged)
|
||||
minDepth = Property(int, minDepth.fget, notify=depthChanged)
|
||||
|
@ -1281,12 +1351,19 @@ class Node(BaseNode):
|
|||
for attrDesc in self.nodeDesc.internalInputs:
|
||||
self._internalAttributes.add(attributeFactory(attrDesc, kwargs.get(attrDesc.name, None), isOutput=False, node=self))
|
||||
|
||||
# List attributes per uid
|
||||
# Declare events for specific output attributes
|
||||
for attr in self._attributes:
|
||||
if attr.isOutput and attr.desc.semantic == "image":
|
||||
attr.enabledChanged.connect(self.outputAttrEnabledChanged)
|
||||
|
||||
# List attributes per uid
|
||||
for attr in self._attributes:
|
||||
if attr.isInput:
|
||||
for uidIndex in attr.attributeDesc.uid:
|
||||
self.attributesPerUid[uidIndex].add(attr)
|
||||
else:
|
||||
if attr.attributeDesc.uid:
|
||||
logging.error(f"Output Attribute should not contain a UID: '{nodeType}.{attr.name}'")
|
||||
|
||||
# Add internal attributes with a UID to the list
|
||||
for attr in self._internalAttributes:
|
||||
|
@ -1347,7 +1424,7 @@ class Node(BaseNode):
|
|||
def toDict(self):
|
||||
inputs = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
|
||||
internalInputs = {k: v.getExportValue() for k, v in self._internalAttributes.objects.items()}
|
||||
outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput})
|
||||
outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput and not v.desc.isDynamicValue})
|
||||
|
||||
return {
|
||||
'nodeType': self.nodeType,
|
||||
|
@ -1714,7 +1791,7 @@ def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
|
|||
# raising compatibility issues if their number differs: in that case, it is only useful
|
||||
# if some internal attributes do not exist or are invalid
|
||||
if not template and (sorted([attr.name for attr in nodeDesc.inputs if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or \
|
||||
sorted([attr.name for attr in nodeDesc.outputs]) != sorted(outputs.keys())):
|
||||
sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != sorted(outputs.keys())):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
|
||||
# check whether there are any internal attributes that are invalidating in the node description: if there
|
||||
|
|
|
@ -50,6 +50,7 @@ class TaskThread(Thread):
|
|||
except TypeError:
|
||||
continue
|
||||
|
||||
node.preprocess()
|
||||
for cId, chunk in enumerate(node.chunks):
|
||||
if chunk.isFinishedOrRunning() or not self.isRunning():
|
||||
continue
|
||||
|
@ -78,6 +79,7 @@ class TaskThread(Thread):
|
|||
# Node already removed (for instance a global clear of _nodesToProcess)
|
||||
pass
|
||||
n.clearSubmittedChunks()
|
||||
node.postprocess()
|
||||
|
||||
if stopAndRestart:
|
||||
break
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# coding:utf-8
|
||||
from collections.abc import Iterable
|
||||
import logging
|
||||
import os
|
||||
import json
|
||||
|
@ -625,6 +626,8 @@ class UIGraph(QObject):
|
|||
|
||||
def filterNodes(self, nodes):
|
||||
"""Filter out the nodes that do not exist on the graph."""
|
||||
if not isinstance(nodes, Iterable):
|
||||
nodes = [nodes]
|
||||
return [ n for n in nodes if n in self._graph.nodes.values() ]
|
||||
|
||||
@Slot(Node, QPoint, QObject)
|
||||
|
@ -698,7 +701,7 @@ class UIGraph(QObject):
|
|||
with self.groupedGraphModification("Node duplication", disableUpdates=True):
|
||||
duplicates = self.push(commands.DuplicateNodesCommand(self._graph, nodes))
|
||||
# move nodes below the bounding box formed by the duplicated node(s)
|
||||
bbox = self._layout.boundingBox(duplicates)
|
||||
bbox = self._layout.boundingBox(nodes)
|
||||
|
||||
for n in duplicates:
|
||||
idx = duplicates.index(n)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue