[core] New dynamic output attributes

This commit is contained in:
Fabien Castan 2024-06-10 19:56:18 +02:00
parent 9a09310f07
commit d5e356c0aa
6 changed files with 139 additions and 14 deletions

View file

@ -128,6 +128,8 @@ def validateNodeDesc(nodeDesc):
errors.append(err)
for param in nodeDesc.outputs:
if param.value is None:
continue
err = param.checkValueTypes()
if err:
errors.append(err)

View file

@ -32,6 +32,7 @@ class Attribute(BaseObject):
self._errorMessage = errorMessage
self._visible = visible
self._isExpression = (isinstance(self._value, str) and "{" in self._value) or isinstance(self._value, types.FunctionType)
self._isDynamicValue = (self._value is None)
self._valueType = None
name = Property(str, lambda self: self._name, constant=True)
@ -39,9 +40,13 @@ class Attribute(BaseObject):
description = Property(str, lambda self: self._description, constant=True)
value = Property(Variant, lambda self: self._value, constant=True)
# isExpression:
# The value of the attribute's descriptor is a static string expression that should be evaluated at runtime.
# The default value of the attribute's descriptor is a static string expression that should be evaluated at runtime.
# This property only makes sense for output attributes.
isExpression = Property(bool, lambda self: self._isExpression, constant=True)
# isDynamicValue
# The default value of the attribute's descriptor is None, so it's not an input value,
# but an output value that is computed during the Node's process execution.
isDynamicValue = Property(bool, lambda self: self._isDynamicValue, constant=True)
uid = Property(Variant, lambda self: self._uid, constant=True)
group = Property(str, lambda self: self._group, constant=True)
advanced = Property(bool, lambda self: self._advanced, constant=True)
@ -99,6 +104,8 @@ class ListAttribute(Attribute):
joinChar = Property(str, lambda self: self._joinChar, constant=True)
def validateValue(self, value):
if value is None:
return value
if JSValue is not None and isinstance(value, JSValue):
# Note: we could use isArray(), property("length").toInt() to retrieve all values
raise ValueError("ListAttribute.validateValue: cannot recognize QJSValue. Please, use JSON.stringify(value) in QML.")
@ -138,6 +145,8 @@ class GroupAttribute(Attribute):
groupDesc = Property(Variant, lambda self: self._groupDesc, constant=True)
def validateValue(self, value):
if value is None:
return value
""" Ensure value is compatible with the group description and convert value if needed. """
if JSValue is not None and isinstance(value, JSValue):
# Note: we could use isArray(), property("length").toInt() to retrieve all values
@ -232,6 +241,8 @@ class File(Attribute):
self._valueType = str
def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str):
raise ValueError('File only supports string input (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
return os.path.normpath(value).replace('\\', '/') if value else ''
@ -252,6 +263,8 @@ class BoolParam(Param):
self._valueType = bool
def validateValue(self, value):
if value is None:
return value
try:
if isinstance(value, str):
# use distutils.util.strtobool to handle (1/0, true/false, on/off, y/n)
@ -276,6 +289,8 @@ class IntParam(Param):
self._valueType = int
def validateValue(self, value):
if value is None:
return value
# handle unsigned int values that are translated to int by shiboken and may overflow
try:
return int(value)
@ -300,6 +315,8 @@ class FloatParam(Param):
self._valueType = float
def validateValue(self, value):
if value is None:
return value
try:
return float(value)
except:
@ -320,7 +337,7 @@ class PushButtonParam(Param):
self._valueType = None
def validateValue(self, value):
pass
return value
def checkValueTypes(self):
pass
@ -354,6 +371,8 @@ class ChoiceParam(Param):
return self._valueType(value)
def validateValue(self, value):
if value is None:
return value
if self.exclusive:
return self.conformValue(value)
@ -383,6 +402,8 @@ class StringParam(Param):
self._valueType = str
def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str):
raise ValueError('StringParam value should be a string (param:{}, value:{}, type:{})'.format(self.name, value, type(value)))
return value
@ -401,6 +422,8 @@ class ColorParam(Param):
self._valueType = str
def validateValue(self, value):
if value is None:
return value
if not isinstance(value, str) or len(value.split(" ")) > 1:
raise ValueError('ColorParam value should be a string containing either an SVG name or an hexadecimal '
'color code (param: {}, value: {}, type: {})'.format(self.name, value, type(value)))
@ -594,7 +617,8 @@ class Node(object):
category = 'Other'
def __init__(self):
pass
super(Node, self).__init__()
self.hasDynamicOutputAttribute = any(output.isDynamicValue for output in self.outputs)
def upgradeAttributeValues(self, attrValues, fromVersion):
return attrValues
@ -630,6 +654,9 @@ class InputNode(Node):
"""
Node that does not need to be processed, it is just a placeholder for inputs.
"""
def __init__(self):
super(InputNode, self).__init__()
def processChunk(self, chunk):
pass
@ -641,6 +668,9 @@ class CommandLineNode(Node):
parallelization = None
commandLineRange = ''
def __init__(self):
super(CommandLineNode, self).__init__()
def buildCommandLine(self, chunk):
cmdPrefix = ''
@ -708,6 +738,7 @@ class AVCommandLineNode(CommandLineNode):
cmdCore = ''
def __init__(self):
super(AVCommandLineNode, self).__init__()
if AVCommandLineNode.cgroupParsed is False:
@ -730,9 +761,9 @@ class AVCommandLineNode(CommandLineNode):
return commandLineString + AVCommandLineNode.cmdMem + AVCommandLineNode.cmdCore
# Test abstract node
class InitNode:
class InitNode(object):
def __init__(self):
pass
super(InitNode, self).__init__()
def initialize(self, node, inputs, recursiveInputs):
"""

View file

@ -14,7 +14,7 @@ import meshroom
import meshroom.core
from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import StopGraphVisit, StopBranchVisit
from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode
@ -320,7 +320,10 @@ class Graph(BaseObject):
# that were computed.
if not isTemplate: # UIDs are not stored in templates
self._evaluateUidConflicts(graphData)
try:
self._applyExpr()
except Exception as e:
logging.warning(e)
return True
@ -548,12 +551,16 @@ class Graph(BaseObject):
skippedEdges = {}
if not withEdges:
for n, attr in node.attributes.items():
if attr.isOutput:
# edges are declared in input with an expression linking
# to another param (which could be an output)
continue
# find top-level links
if Attribute.isLinkExpression(attr.value):
skippedEdges[attr] = attr.value
attr.resetToDefaultValue()
# find links in ListAttribute children
elif isinstance(attr, ListAttribute):
elif isinstance(attr, (ListAttribute, GroupAttribute)):
for child in attr.value:
if Attribute.isLinkExpression(child.value):
skippedEdges[child] = child.value
@ -584,6 +591,7 @@ class Graph(BaseObject):
# re-create edges taking into account what has been duplicated
for attr, linkExpression in duplicateEdges.items():
logging.warning("attr={} linkExpression={}".format(attr.fullName, linkExpression))
link = linkExpression[1:-1] # remove starting '{' and trailing '}'
# get source node and attribute name
edgeSrcNodeName, edgeSrcAttrName = link.split(".", 1)
@ -1625,6 +1633,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
for n, node in enumerate(nodes):
try:
node.preprocess()
multiChunks = len(node.chunks) > 1
for c, chunk in enumerate(node.chunks):
if multiChunks:
@ -1635,6 +1644,7 @@ def executeGraph(graph, toNodes=None, forceCompute=False, forceStatus=False):
print('\n[{node}/{nbNodes}] {nodeName}'.format(
node=n + 1, nbNodes=len(nodes), nodeName=node.nodeType))
chunk.process(forceCompute)
node.postprocess()
except Exception as e:
logging.error("Error on node computation: {}".format(e))
graph.clearSubmittedNodes()

View file

@ -686,6 +686,10 @@ class BaseNode(BaseObject):
def minDepth(self):
return self.graph.getDepth(self, minimal=True)
@property
def valuesFile(self):
return os.path.join(self.graph.cacheDir, self.internalFolder, 'values')
def getInputNodes(self, recursive, dependenciesOnly):
return self.graph.getInputNodes(self, recursive=recursive, dependenciesOnly=dependenciesOnly)
@ -955,6 +959,7 @@ class BaseNode(BaseObject):
}
self._computeUids()
self._buildCmdVars()
self.updateOutputAttr()
if self.nodeDesc:
self.nodeDesc.postUpdate(self)
# Notify internal folder change if needed
@ -972,8 +977,11 @@ class BaseNode(BaseObject):
"""
Update node status based on status file content/existence.
"""
s = self.globalStatus
for chunk in self._chunks:
chunk.updateStatusFromCache()
# logging.warning("updateStatusFromCache: {}, status: {} => {}".format(self.name, s, self.globalStatus))
self.updateOutputAttr()
def submit(self, forceCompute=False):
for chunk in self._chunks:
@ -988,10 +996,71 @@ class BaseNode(BaseObject):
def processIteration(self, iteration):
self._chunks[iteration].process()
def preprocess(self):
pass
def process(self, forceCompute=False):
for chunk in self._chunks:
chunk.process(forceCompute)
def postprocess(self):
self.saveOutputAttr()
def updateOutputAttr(self):
if not self.nodeDesc:
return
if not self.nodeDesc.hasDynamicOutputAttribute:
return
# logging.warning("updateOutputAttr: {}, status: {}".format(self.name, self.globalStatus))
if self.getGlobalStatus() == Status.SUCCESS:
self.loadOutputAttr()
else:
self.resetOutputAttr()
def resetOutputAttr(self):
if not self.nodeDesc.hasDynamicOutputAttribute:
return
# logging.warning("resetOutputAttr: {}".format(self.name))
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
self.attribute(output.name).value = None
def loadOutputAttr(self):
""" Load output attributes with dynamic values from a values.json file.
"""
if not self.nodeDesc.hasDynamicOutputAttribute:
return
valuesFile = self.valuesFile
if not os.path.exists(valuesFile):
logging.warning("No output attr file: {}".format(valuesFile))
return
# logging.warning("load output attr: {}, value: {}".format(self.name, valuesFile))
with open(valuesFile, 'r') as jsonFile:
data = json.load(jsonFile)
# logging.warning(data)
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
self.attribute(output.name).value = data[output.name]
def saveOutputAttr(self):
""" Save output attributes with dynamic values into a values.json file.
"""
if not self.nodeDesc.hasDynamicOutputAttribute:
return
data = {}
for output in self.nodeDesc.outputs:
if output.isDynamicValue:
data[output.name] = self.attribute(output.name).value
valuesFile = self.valuesFile
# logging.warning("save output attr: {}, value: {}".format(self.name, valuesFile))
valuesFilepathWriting = getWritingFilepath(valuesFile)
with open(valuesFilepathWriting, 'w') as jsonFile:
json.dump(data, jsonFile, indent=4)
renameWritingToFinalPath(valuesFilepathWriting, valuesFile)
def endSequence(self):
pass
@ -1227,6 +1296,7 @@ class BaseNode(BaseObject):
comment = Property(str, getComment, notify=internalAttributesChanged)
internalFolderChanged = Signal()
internalFolder = Property(str, internalFolder.fget, notify=internalFolderChanged)
valuesFile = Property(str, valuesFile.fget, notify=internalFolderChanged)
depthChanged = Signal()
depth = Property(int, depth.fget, notify=depthChanged)
minDepth = Property(int, minDepth.fget, notify=depthChanged)
@ -1281,12 +1351,19 @@ class Node(BaseNode):
for attrDesc in self.nodeDesc.internalInputs:
self._internalAttributes.add(attributeFactory(attrDesc, kwargs.get(attrDesc.name, None), isOutput=False, node=self))
# List attributes per uid
# Declare events for specific output attributes
for attr in self._attributes:
if attr.isOutput and attr.desc.semantic == "image":
attr.enabledChanged.connect(self.outputAttrEnabledChanged)
# List attributes per uid
for attr in self._attributes:
if attr.isInput:
for uidIndex in attr.attributeDesc.uid:
self.attributesPerUid[uidIndex].add(attr)
else:
if attr.attributeDesc.uid:
logging.error(f"Output Attribute should not contain a UID: '{nodeType}.{attr.name}'")
# Add internal attributes with a UID to the list
for attr in self._internalAttributes:
@ -1347,7 +1424,7 @@ class Node(BaseNode):
def toDict(self):
inputs = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
internalInputs = {k: v.getExportValue() for k, v in self._internalAttributes.objects.items()}
outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput})
outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput and not v.desc.isDynamicValue})
return {
'nodeType': self.nodeType,
@ -1714,7 +1791,7 @@ def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
# raising compatibility issues if their number differs: in that case, it is only useful
# if some internal attributes do not exist or are invalid
if not template and (sorted([attr.name for attr in nodeDesc.inputs if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or \
sorted([attr.name for attr in nodeDesc.outputs]) != sorted(outputs.keys())):
sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) != sorted(outputs.keys())):
compatibilityIssue = CompatibilityIssue.DescriptionConflict
# check whether there are any internal attributes that are invalidating in the node description: if there

View file

@ -50,6 +50,7 @@ class TaskThread(Thread):
except TypeError:
continue
node.preprocess()
for cId, chunk in enumerate(node.chunks):
if chunk.isFinishedOrRunning() or not self.isRunning():
continue
@ -78,6 +79,7 @@ class TaskThread(Thread):
# Node already removed (for instance a global clear of _nodesToProcess)
pass
n.clearSubmittedChunks()
node.postprocess()
if stopAndRestart:
break

View file

@ -1,5 +1,6 @@
#!/usr/bin/env python
# coding:utf-8
from collections.abc import Iterable
import logging
import os
import json
@ -625,6 +626,8 @@ class UIGraph(QObject):
def filterNodes(self, nodes):
"""Filter out the nodes that do not exist on the graph."""
if not isinstance(nodes, Iterable):
nodes = [nodes]
return [ n for n in nodes if n in self._graph.nodes.values() ]
@Slot(Node, QPoint, QObject)
@ -698,7 +701,7 @@ class UIGraph(QObject):
with self.groupedGraphModification("Node duplication", disableUpdates=True):
duplicates = self.push(commands.DuplicateNodesCommand(self._graph, nodes))
# move nodes below the bounding box formed by the duplicated node(s)
bbox = self._layout.boundingBox(duplicates)
bbox = self._layout.boundingBox(nodes)
for n in duplicates:
idx = duplicates.index(n)