[core] New dynamic output attributes

This commit is contained in:
Fabien Castan 2024-06-10 19:56:18 +02:00
parent 9a09310f07
commit d5e356c0aa
6 changed files with 139 additions and 14 deletions

View file

@ -128,6 +128,8 @@ def validateNodeDesc(nodeDesc):
errors.append(err) errors.append(err)
for param in nodeDesc.outputs: for param in nodeDesc.outputs:
if param.value is None:
continue
err = param.checkValueTypes() err = param.checkValueTypes()
if err: if err:
errors.append(err) errors.append(err)

View file

@ -32,6 +32,7 @@ class Attribute(BaseObject):
self._errorMessage = errorMessage self._errorMessage = errorMessage
self._visible = visible self._visible = visible
self._isExpression = (isinstance(self._value, str) and "{" in self._value) or isinstance(self._value, types.FunctionType) self._isExpression = (isinstance(self._value, str) and "{" in self._value) or isinstance(self._value, types.FunctionType)
self._isDynamicValue = (self._value is None)
self._valueType = None self._valueType = None
name = Property(str, lambda self: self._name, constant=True) name = Property(str, lambda self: self._name, constant=True)
@ -39,9 +40,13 @@ class Attribute(BaseObject):
description = Property(str, lambda self: self._description, constant=True) description = Property(str, lambda self: self._description, constant=True)
value = Property(Variant, lambda self: self._value, constant=True) value = Property(Variant, lambda self: self._value, constant=True)
# isExpression: # isExpression:
# The value of the attribute's descriptor is a static string expression that should be evaluated at runtime. # The default value of the attribute's descriptor is a static string expression that should be evaluated at runtime.
# This property only makes sense for output attributes. # This property only makes sense for output attributes.
isExpression = Property(bool, lambda self: self._isExpression, constant=True) isExpression = Property(bool, lambda self: self._isExpression, constant=True)
# isDynamicValue
# The default value of the attribute's descriptor is None, so it's not an input value,
# but an output value that is computed during the Node's process execution.
isDynamicValue = Property(bool, lambda self: self._isDynamicValue, constant=True)
uid = Property(Variant, lambda self: self._uid, constant=True) uid = Property(Variant, lambda self: self._uid, constant=True)
group = Property(str, lambda self: self._group, constant=True) group = Property(str, lambda self: self._group, constant=True)
advanced = Property(bool, lambda self: self._advanced, constant=True) advanced = Property(bool, lambda self: self._advanced, constant=True)
@ -99,6 +104,8 @@ class ListAttribute(Attribute):
joinChar = Property(str, lambda self: self._joinChar, constant=True) joinChar = Property(str, lambda self: self._joinChar, constant=True)
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
if JSValue is not None and isinstance(value, JSValue): if JSValue is not None and isinstance(value, JSValue):
# Note: we could use isArray(), property("length").toInt() to retrieve all values # Note: we could use isArray(), property("length").toInt() to retrieve all values
raise ValueError("ListAttribute.validateValue: cannot recognize QJSValue. Please, use JSON.stringify(value) in QML.") raise ValueError("ListAttribute.validateValue: cannot recognize QJSValue. Please, use JSON.stringify(value) in QML.")
@ -138,6 +145,8 @@ class GroupAttribute(Attribute):
groupDesc = Property(Variant, lambda self: self._groupDesc, constant=True) groupDesc = Property(Variant, lambda self: self._groupDesc, constant=True)
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
""" Ensure value is compatible with the group description and convert value if needed. """ """ Ensure value is compatible with the group description and convert value if needed. """
if JSValue is not None and isinstance(value, JSValue): if JSValue is not None and isinstance(value, JSValue):
# Note: we could use isArray(), property("length").toInt() to retrieve all values # Note: we could use isArray(), property("length").toInt() to retrieve all values
@ -232,6 +241,8 @@ class File(Attribute):
self._valueType = str self._valueType = str
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str): if not isinstance(value, str):
raise ValueError('File only supports string input (param:{}, value:{}, type:{})'.format(self.name, value, type(value))) raise ValueError('File only supports string input (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
return os.path.normpath(value).replace('\\', '/') if value else '' return os.path.normpath(value).replace('\\', '/') if value else ''
@ -252,6 +263,8 @@ class BoolParam(Param):
self._valueType = bool self._valueType = bool
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
try: try:
if isinstance(value, str): if isinstance(value, str):
# use distutils.util.strtobool to handle (1/0, true/false, on/off, y/n) # use distutils.util.strtobool to handle (1/0, true/false, on/off, y/n)
@ -276,6 +289,8 @@ class IntParam(Param):
self._valueType = int self._valueType = int
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
# handle unsigned int values that are translated to int by shiboken and may overflow # handle unsigned int values that are translated to int by shiboken and may overflow
try: try:
return int(value) return int(value)
@ -300,6 +315,8 @@ class FloatParam(Param):
self._valueType = float self._valueType = float
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
try: try:
return float(value) return float(value)
except: except:
@ -320,7 +337,7 @@ class PushButtonParam(Param):
self._valueType = None self._valueType = None
def validateValue(self, value): def validateValue(self, value):
pass return value
def checkValueTypes(self): def checkValueTypes(self):
pass pass
@ -354,6 +371,8 @@ class ChoiceParam(Param):
return self._valueType(value) return self._valueType(value)
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
if self.exclusive: if self.exclusive:
return self.conformValue(value) return self.conformValue(value)
@ -383,6 +402,8 @@ class StringParam(Param):
self._valueType = str self._valueType = str
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str): if not isinstance(value, str):
raise ValueError('StringParam value should be a string (param:{}, value:{}, type:{})'.format(self.name, value, type(value))) raise ValueError('StringParam value should be a string (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
return value return value
@ -401,6 +422,8 @@ class ColorParam(Param):
self._valueType = str self._valueType = str
def validateValue(self, value): def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str) or len(value.split(" ")) > 1: if not isinstance(value, str) or len(value.split(" ")) > 1:
raise ValueError('ColorParam value should be a string containing either an SVG name or an hexadecimal ' raise ValueError('ColorParam value should be a string containing either an SVG name or an hexadecimal '
'color code (param: {}, value: {}, type: {})'.format(self.name, value, type(value))) 'color code (param: {}, value: {}, type: {})'.format(self.name, value, type(value)))
@ -594,7 +617,8 @@ class Node(object):
category = 'Other' category = 'Other'
def __init__(self): def __init__(self):
pass super(Node, self).__init__()
self.hasDynamicOutputAttribute = any(output.isDynamicValue for output in self.outputs)
def upgradeAttributeValues(self, attrValues, fromVersion): def upgradeAttributeValues(self, attrValues, fromVersion):
return attrValues return attrValues
@ -630,6 +654,9 @@ class InputNode(Node):
""" """
Node that does not need to be processed, it is just a placeholder for inputs. Node that does not need to be processed, it is just a placeholder for inputs.
""" """
def __init__(self):
super(InputNode, self).__init__()
def processChunk(self, chunk): def processChunk(self, chunk):
pass pass
@ -641,6 +668,9 @@ class CommandLineNode(Node):
parallelization = None parallelization = None
commandLineRange = '' commandLineRange = ''
def __init__(self):
super(CommandLineNode, self).__init__()
def buildCommandLine(self, chunk): def buildCommandLine(self, chunk):
cmdPrefix = '' cmdPrefix = ''
@ -708,6 +738,7 @@ class AVCommandLineNode(CommandLineNode):
cmdCore = '' cmdCore = ''
def __init__(self): def __init__(self):
super(AVCommandLineNode, self).__init__()
if AVCommandLineNode.cgroupParsed is False: if AVCommandLineNode.cgroupParsed is False:
@ -730,9 +761,9 @@ class AVCommandLineNode(CommandLineNode):
return commandLineString + AVCommandLineNode.cmdMem + AVCommandLineNode.cmdCore return commandLineString + AVCommandLineNode.cmdMem + AVCommandLineNode.cmdCore
# Test abstract node # Test abstract node
class InitNode: class InitNode(object):
def __init__(self): def __init__(self):
pass super(InitNode, self).__init__()
def initialize(self, node, inputs, recursiveInputs): def initialize(self, node, inputs, recursiveInputs):
""" """

View file

@ -14,7 +14,7 @@ import meshroom
import meshroom.core import meshroom.core
from meshroom.common import BaseObject, DictModel, Slot, Signal, Property from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core import Version from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import StopGraphVisit, StopBranchVisit from meshroom.core.exception import StopGraphVisit, StopBranchVisit
from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode
@ -320,7 +320,10 @@ class Graph(BaseObject):
# that were computed. # that were computed.
if not isTemplate: # UIDs are not stored in templates if not isTemplate: # UIDs are not stored in templates
self._evaluateUidConflicts(graphData) self._evaluateUidConflicts(graphData)
try:
self._applyExpr() self._applyExpr()
except Exception as e:
logging.warning(e)
return True return True
@ -548,12 +551,16 @@ class Graph(BaseObject):
skippedEdges = {} skippedEdges = {}
if not withEdges: if not withEdges:
for n, attr in node.attributes.items(): for n, attr in node.attributes.items():
if attr.isOutput:
# edges are declared in input with an expression linking
# to another param (which could be an output)
continue
# find top-level links # find top-level links
if Attribute.isLinkExpression(attr.value): if Attribute.isLinkExpression(attr.value):
skippedEdges[attr] = attr.value skippedEdges[attr] = attr.value
attr.resetToDefaultValue() attr.resetToDefaultValue()
# find links in ListAttribute children # find links in ListAttribute children
elif isinstance(attr, ListAttribute): elif isinstance(attr, (ListAttribute, GroupAttribute)):
for child in attr.value: for child in attr.value:
if Attribute.isLinkExpression(child.value): if Attribute.isLinkExpression(child.value):
skippedEdges[child] = child.value skippedEdges[child] = child.value
@ -584,6 +591,7 @@ class Graph(BaseObject):
# re-create edges taking into account what has been duplicated # re-create edges taking into account what has been duplicated
for attr, linkExpression in duplicateEdges.items(): for attr, linkExpression in duplicateEdges.items():
logging.warning("attr={} linkExpression={}".format(attr.fullName, linkExpression))
link = linkExpression[1:-1] # remove starting '{' and trailing '}' link = linkExpression[1:-1] # remove starting '{' and trailing '}'
# get source node and attribute name # get source node and attribute name
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1) edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
@ -1625,6 +1633,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
for n, node in enumerate(nodes): for n, node in enumerate(nodes):
try: try:
node.preprocess()
multiChunks = len(node.chunks) > 1 multiChunks = len(node.chunks) > 1
for c, chunk in enumerate(node.chunks): for c, chunk in enumerate(node.chunks):
if multiChunks: if multiChunks:
@ -1635,6 +1644,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
print('\n[{node}/{nbNodes}] {nodeName}'.format( print('\n[{node}/{nbNodes}] {nodeName}'.format(
node=n + 1, nbNodes=len(nodes), nodeName=node.nodeType)) node=n + 1, nbNodes=len(nodes), nodeName=node.nodeType))
chunk.process(forceCompute) chunk.process(forceCompute)
node.postprocess()
except Exception as e: except Exception as e:
logging.error("Error on node computation: {}".format(e)) logging.error("Error on node computation: {}".format(e))
graph.clearSubmittedNodes() graph.clearSubmittedNodes()

View file

@ -686,6 +686,10 @@ class BaseNode(BaseObject):
def minDepth(self): def minDepth(self):
return self.graph.getDepth(self, minimal=True) return self.graph.getDepth(self, minimal=True)
@property
def valuesFile(self):
return os.path.join(self.graph.cacheDir, self.internalFolder, 'values')
def getInputNodes(self, recursive, dependenciesOnly): def getInputNodes(self, recursive, dependenciesOnly):
return self.graph.getInputNodes(self, recursive=recursive, dependenciesOnly=dependenciesOnly) return self.graph.getInputNodes(self, recursive=recursive, dependenciesOnly=dependenciesOnly)
@ -955,6 +959,7 @@ class BaseNode(BaseObject):
} }
self._computeUids() self._computeUids()
self._buildCmdVars() self._buildCmdVars()
self.updateOutputAttr()
if self.nodeDesc: if self.nodeDesc:
self.nodeDesc.postUpdate(self) self.nodeDesc.postUpdate(self)
# Notify internal folder change if needed # Notify internal folder change if needed
@ -972,8 +977,11 @@ class BaseNode(BaseObject):
""" """
Update node status based on status file content/existence. Update node status based on status file content/existence.
""" """
s = self.globalStatus
for chunk in self._chunks: for chunk in self._chunks:
chunk.updateStatusFromCache() chunk.updateStatusFromCache()
# logging.warning("updateStatusFromCache: {}, status: {} => {}".format(self.name, s, self.globalStatus))
self.updateOutputAttr()
def submit(self, forceCompute=False): def submit(self, forceCompute=False):
for chunk in self._chunks: for chunk in self._chunks:
@ -988,10 +996,71 @@ class BaseNode(BaseObject):
def processIteration(self, iteration): def processIteration(self, iteration):
self._chunks[iteration].process() self._chunks[iteration].process()
def preprocess(self):
pass
def process(self, forceCompute=False): def process(self, forceCompute=False):
for chunk in self._chunks: for chunk in self._chunks:
chunk.process(forceCompute) chunk.process(forceCompute)
def postprocess(self):
self.saveOutputAttr()
def updateOutputAttr(self):
if not self.nodeDesc:
return
if not self.nodeDesc.hasDynamicOutputAttribute:
return
# logging.warning("updateOutputAttr: {}, status: {}".format(self.name, self.globalStatus))
if self.getGlobalStatus() == Status.SUCCESS:
self.loadOutputAttr()
else:
self.resetOutputAttr()
def resetOutputAttr(self):
if not self.nodeDesc.hasDynamicOutputAttribute:
return
# logging.warning("resetOutputAttr: {}".format(self.name))
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
self.attribute(output.name).value = None
def loadOutputAttr(self):
""" Load output attributes with dynamic values from a values.json file.
"""
if not self.nodeDesc.hasDynamicOutputAttribute:
return
valuesFile = self.valuesFile
if not os.path.exists(valuesFile):
logging.warning("No output attr file: {}".format(valuesFile))
return
# logging.warning("load output attr: {}, value: {}".format(self.name, valuesFile))
with open(valuesFile, 'r') as jsonFile:
data = json.load(jsonFile)
# logging.warning(data)
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
self.attribute(output.name).value = data[output.name]
def saveOutputAttr(self):
""" Save output attributes with dynamic values into a values.json file.
"""
if not self.nodeDesc.hasDynamicOutputAttribute:
return
data = {}
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
data[output.name] = self.attribute(output.name).value
valuesFile = self.valuesFile
# logging.warning("save output attr: {}, value: {}".format(self.name, valuesFile))
valuesFilepathWriting = getWritingFilepath(valuesFile)
with open(valuesFilepathWriting, 'w') as jsonFile:
json.dump(data, jsonFile, indent=4)
renameWritingToFinalPath(valuesFilepathWriting, valuesFile)
def endSequence(self): def endSequence(self):
pass pass
@ -1227,6 +1296,7 @@ class BaseNode(BaseObject):
comment = Property(str, getComment, notify=internalAttributesChanged) comment = Property(str, getComment, notify=internalAttributesChanged)
internalFolderChanged = Signal() internalFolderChanged = Signal()
internalFolder = Property(str, internalFolder.fget, notify=internalFolderChanged) internalFolder = Property(str, internalFolder.fget, notify=internalFolderChanged)
valuesFile = Property(str, valuesFile.fget, notify=internalFolderChanged)
depthChanged = Signal() depthChanged = Signal()
depth = Property(int, depth.fget, notify=depthChanged) depth = Property(int, depth.fget, notify=depthChanged)
minDepth = Property(int, minDepth.fget, notify=depthChanged) minDepth = Property(int, minDepth.fget, notify=depthChanged)
@ -1281,12 +1351,19 @@ class Node(BaseNode):
for attrDesc in self.nodeDesc.internalInputs: for attrDesc in self.nodeDesc.internalInputs:
self._internalAttributes.add(attributeFactory(attrDesc, kwargs.get(attrDesc.name, None), isOutput=False, node=self)) self._internalAttributes.add(attributeFactory(attrDesc, kwargs.get(attrDesc.name, None), isOutput=False, node=self))
# List attributes per uid # Declare events for specific output attributes
for attr in self._attributes: for attr in self._attributes:
if attr.isOutput and attr.desc.semantic == "image": if attr.isOutput and attr.desc.semantic == "image":
attr.enabledChanged.connect(self.outputAttrEnabledChanged) attr.enabledChanged.connect(self.outputAttrEnabledChanged)
# List attributes per uid
for attr in self._attributes:
if attr.isInput:
for uidIndex in attr.attributeDesc.uid: for uidIndex in attr.attributeDesc.uid:
self.attributesPerUid[uidIndex].add(attr) self.attributesPerUid[uidIndex].add(attr)
else:
if attr.attributeDesc.uid:
logging.error(f"Output Attribute should not contain a UID: '{nodeType}.{attr.name}'")
# Add internal attributes with a UID to the list # Add internal attributes with a UID to the list
for attr in self._internalAttributes: for attr in self._internalAttributes:
@ -1347,7 +1424,7 @@ class Node(BaseNode):
def toDict(self): def toDict(self):
inputs = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput} inputs = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
internalInputs = {k: v.getExportValue() for k, v in self._internalAttributes.objects.items()} internalInputs = {k: v.getExportValue() for k, v in self._internalAttributes.objects.items()}
outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput}) outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput and not v.desc.isDynamicValue})
return { return {
'nodeType': self.nodeType, 'nodeType': self.nodeType,
@ -1714,7 +1791,7 @@ def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
# raising compatibility issues if their number differs: in that case, it is only useful # raising compatibility issues if their number differs: in that case, it is only useful
# if some internal attributes do not exist or are invalid # if some internal attributes do not exist or are invalid
if not template and (sorted([attr.name for attr in nodeDesc.inputs if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or \ if not template and (sorted([attr.name for attr in nodeDesc.inputs if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or \
sorted([attr.name for attr in nodeDesc.outputs]) != sorted(outputs.keys())): sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != sorted(outputs.keys())):
compatibilityIssue = CompatibilityIssue.DescriptionConflict compatibilityIssue = CompatibilityIssue.DescriptionConflict
# check whether there are any internal attributes that are invalidating in the node description: if there # check whether there are any internal attributes that are invalidating in the node description: if there

View file

@ -50,6 +50,7 @@ class TaskThread(Thread):
except TypeError: except TypeError:
continue continue
node.preprocess()
for cId, chunk in enumerate(node.chunks): for cId, chunk in enumerate(node.chunks):
if chunk.isFinishedOrRunning() or not self.isRunning(): if chunk.isFinishedOrRunning() or not self.isRunning():
continue continue
@ -78,6 +79,7 @@ class TaskThread(Thread):
# Node already removed (for instance a global clear of _nodesToProcess) # Node already removed (for instance a global clear of _nodesToProcess)
pass pass
n.clearSubmittedChunks() n.clearSubmittedChunks()
node.postprocess()
if stopAndRestart: if stopAndRestart:
break break

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
# coding:utf-8 # coding:utf-8
from collections.abc import Iterable
import logging import logging
import os import os
import json import json
@ -625,6 +626,8 @@ class UIGraph(QObject):
def filterNodes(self, nodes): def filterNodes(self, nodes):
"""Filter out the nodes that do not exist on the graph.""" """Filter out the nodes that do not exist on the graph."""
if not isinstance(nodes, Iterable):
nodes = [nodes]
return [ n for n in nodes if n in self._graph.nodes.values() ] return [ n for n in nodes if n in self._graph.nodes.values() ]
@Slot(Node, QPoint, QObject) @Slot(Node, QPoint, QObject)
@ -698,7 +701,7 @@ class UIGraph(QObject):
with self.groupedGraphModification("Node duplication", disableUpdates=True): with self.groupedGraphModification("Node duplication", disableUpdates=True):
duplicates = self.push(commands.DuplicateNodesCommand(self._graph, nodes)) duplicates = self.push(commands.DuplicateNodesCommand(self._graph, nodes))
# move nodes below the bounding box formed by the duplicated node(s) # move nodes below the bounding box formed by the duplicated node(s)
bbox = self._layout.boundingBox(duplicates) bbox = self._layout.boundingBox(nodes)
for n in duplicates: for n in duplicates:
idx = duplicates.index(n) idx = duplicates.index(n)