[core] Add typing on methods

This commit is contained in:
Candice Bentéjac 2025-05-06 09:55:32 +01:00
parent b56420318d
commit 6f69588f0b
2 changed files with 115 additions and 94 deletions

View file

@ -128,7 +128,7 @@ def loadClasses(folder, packageName, classType):
return classes
def validateNodeDesc(nodeDesc):
def validateNodeDesc(nodeDesc: desc.Node) -> list:
"""
Check that the node has a valid description before being loaded. For the description
to be valid, the default value of every parameter needs to correspond to the type
@ -140,10 +140,10 @@ def validateNodeDesc(nodeDesc):
"group", is invalid, then it will be added to the list as "group:x".
Args:
nodeDesc (desc.Node): description of the node
nodeDesc: description of the node
Returns:
errors (list): the list of invalid parameters if there are any, empty list otherwise
errors: the list of invalid parameters if there are any, empty list otherwise
"""
errors = []
@ -280,7 +280,7 @@ class Version:
return self.components[2]
def moduleVersion(moduleName, default=None):
def moduleVersion(moduleName: str, default=None):
""" Return the version of a module indicated with '__version__' keyword.
Args:
@ -293,7 +293,7 @@ def moduleVersion(moduleName, default=None):
return getattr(sys.modules[moduleName], "__version__", default)
def nodeVersion(nodeDesc, default=None):
def nodeVersion(nodeDesc: desc.Node, default=None):
""" Return node type version for the given node description class.
Args:
@ -306,7 +306,7 @@ def nodeVersion(nodeDesc, default=None):
return moduleVersion(nodeDesc.__module__, default)
def registerNodeType(nodeType):
def registerNodeType(nodeType: desc.Node):
""" Register a Node Type based on a Node Description class.
After registration, nodes of this type can be instantiated in a Graph.
@ -316,7 +316,7 @@ def registerNodeType(nodeType):
nodesDesc[nodeType.__name__] = nodeType
def unregisterNodeType(nodeType):
def unregisterNodeType(nodeType: desc.Node):
""" Remove 'nodeType' from the list of register node types. """
assert nodeType.__name__ in nodesDesc
del nodesDesc[nodeType.__name__]
@ -367,7 +367,7 @@ def loadPluginsFolder(folder):
loadPluginFolder(subFolder)
def registerSubmitter(s):
def registerSubmitter(s: BaseSubmitter):
if s.name in submitters:
logging.error(f"Submitter {s.name} is already registered.")
submitters[s.name] = s

View file

@ -17,15 +17,16 @@ if TYPE_CHECKING:
from meshroom.core.graph import Edge
def attributeFactory(description, value, isOutput, node, root=None, parent=None):
def attributeFactory(description: str, value, isOutput: bool, node, root=None, parent=None):
"""
Create an Attribute based on description type.
Args:
description: the Attribute description
value: value of the Attribute. Will be set if not None.
isOutput: whether is Attribute is an output attribute.
node (Node): node owning the Attribute. Note that the created Attribute is not added to Node's attributes
value: value of the Attribute. Will be set if not None.
isOutput: whether the Attribute is an output attribute.
node (Node): node owning the Attribute. Note that the created Attribute is not added to \
Node's attributes
root: (optional) parent Attribute (must be ListAttribute or GroupAttribute)
parent (BaseObject): (optional) the parent BaseObject if any
"""
@ -53,27 +54,27 @@ class Attribute(BaseObject):
VALID_IMAGE_SEMANTICS = ["image", "imageList", "sequence"]
VALID_3D_EXTENSIONS = [".obj", ".stl", ".fbx", ".gltf", ".abc", ".ply"]
def __init__(self, node, attributeDesc, isOutput, root=None, parent=None):
def __init__(self, node, attributeDesc: desc.Attribute, isOutput: bool, root=None, parent=None):
"""
Attribute constructor
Args:
node (Node): the Node hosting this Attribute
attributeDesc (desc.Attribute): the description of this Attribute
isOutput (bool): whether this Attribute is an output of the Node
attributeDesc: the description of this Attribute
isOutput: whether this Attribute is an output of the Node
root (Attribute): (optional) the root Attribute (List or Group) containing this one
parent (BaseObject): (optional) the parent BaseObject
"""
super().__init__(parent)
self._name = attributeDesc.name
self._name: str = attributeDesc.name
self._root = None if root is None else weakref.ref(root)
self._node = weakref.ref(node)
self.attributeDesc = attributeDesc
self._isOutput = isOutput
self._label = attributeDesc.label
self._enabled = True
self._validValue = True
self._description = attributeDesc.description
self.attributeDesc: desc.Attribute = attributeDesc
self._isOutput: bool = isOutput
self._label: str = attributeDesc.label
self._enabled: bool = True
self._validValue: bool = True
self._description: str = attributeDesc.description
self._invalidate = False if self._isOutput else attributeDesc.invalidate
# invalidation value for output attributes
@ -91,11 +92,11 @@ class Attribute(BaseObject):
def root(self):
return self._root() if self._root else None
def getName(self):
def getName(self) -> str:
""" Attribute name """
return self._name
def getFullName(self):
def getFullName(self) -> str:
""" Name inside the Graph: groupName.name """
if isinstance(self.root, ListAttribute):
return f'{self.root.getFullName()}[{self.root.index(self)}]'
@ -103,53 +104,55 @@ class Attribute(BaseObject):
return f'{self.root.getFullName()}.{self.getName()}'
return self.getName()
def getFullNameToNode(self):
def getFullNameToNode(self) -> str:
""" Name inside the Graph: nodeName.groupName.name """
return f'{self.node.name}.{self.getFullName()}'
def getFullNameToGraph(self):
def getFullNameToGraph(self) -> str:
""" Name inside the Graph: graphName.nodeName.groupName.name """
graphName = self.node.graph.name if self.node.graph else "UNDEFINED"
return f'{graphName}.{self.getFullNameToNode()}'
def asLinkExpr(self):
def asLinkExpr(self) -> str:
""" Return link expression for this Attribute """
return "{" + self.getFullNameToNode() + "}"
def getType(self):
def getType(self) -> str:
return self.attributeDesc.type
def _isReadOnly(self):
def _isReadOnly(self) -> bool:
return not self._isOutput and self.node.isCompatibilityNode
def getBaseType(self):
def getBaseType(self) -> str:
return self.getType()
def getLabel(self):
def getLabel(self) -> str:
return self._label
@Slot(str, result=bool)
def matchText(self, text):
def matchText(self, text: str) -> bool:
return self.fullLabel.lower().find(text.lower()) > -1
def getFullLabel(self):
""" Full Label includes the name of all parent groups, e.g. 'groupLabel subGroupLabel Label' """
def getFullLabel(self) -> str:
"""
Full Label includes the name of all parent groups, e.g. 'groupLabel subGroupLabel Label'.
"""
if isinstance(self.root, ListAttribute):
return self.root.getFullLabel()
elif isinstance(self.root, GroupAttribute):
return f'{self.root.getFullLabel()} {self.getLabel()}'
return self.getLabel()
def getFullLabelToNode(self):
def getFullLabelToNode(self) -> str:
""" Label inside the Graph: nodeLabel groupLabel Label """
return f'{self.node.label} {self.getFullLabel()}'
def getFullLabelToGraph(self):
def getFullLabelToGraph(self) -> str:
""" Label inside the Graph: graphName nodeLabel groupLabel Label """
graphName = self.node.graph.name if self.node.graph else "UNDEFINED"
return f'{graphName} {self.getFullLabelToNode()}'
def getEnabled(self):
def getEnabled(self) -> bool:
if isinstance(self.desc.enabled, types.FunctionType):
try:
return self.desc.enabled(self.node)
@ -205,8 +208,8 @@ class Attribute(BaseObject):
# evaluate the function
self._value = value(self)
else:
# if we set a new value, we use the attribute descriptor validator to check the validity of the value
# and apply some conversion if needed
# if we set a new value, we use the attribute descriptor validator to check the
# validity of the value and apply some conversion if needed
convertedValue = self.validateValue(value)
self._value = convertedValue
@ -266,26 +269,27 @@ class Attribute(BaseObject):
self.node.updateInternalAttributes()
@property
def isOutput(self):
def isOutput(self) -> bool:
return self._isOutput
@property
def isInput(self):
def isInput(self) -> bool:
return not self._isOutput
def uid(self):
def uid(self) -> str:
"""
Compute the UID for the attribute.
"""
if self.isOutput:
if self.desc.isDynamicValue:
# If the attribute is a dynamic output, the UID is derived from the node UID.
# To guarantee that each output attribute receives a unique ID, we add the attribute name to it.
# To guarantee that each output attribute receives a unique ID, we add the attribute
# name to it.
return hashValue((self.name, self.node._uid))
else:
# Only dependent on the hash of its value without the cache folder.
# "/" at the end of the link is stripped to prevent having different UIDs depending on
# whether the invalidation value finishes with it or not
# "/" at the end of the link is stripped to prevent having different UIDs depending
# on whether the invalidation value finishes with it or not
strippedInvalidationValue = self._invalidationValue.rstrip("/")
return hashValue(strippedInvalidationValue)
if self.isLink:
@ -298,13 +302,15 @@ class Attribute(BaseObject):
return hashValue(self._value)
@property
def isLink(self):
def isLink(self) -> bool:
""" Whether the input attribute is a link to another attribute. """
# note: directly use self.node.graph._edges to avoid using the property that may become invalid at some point
return self.node.graph and self.isInput and self.node.graph._edges and self in self.node.graph._edges.keys()
# note: directly use self.node.graph._edges to avoid using the property that may become
# invalid at some point
return self.node.graph and self.isInput and self.node.graph._edges and \
self in self.node.graph._edges.keys()
@staticmethod
def isLinkExpression(value):
def isLinkExpression(value) -> bool:
"""
Return whether the given argument is a link expression.
A link expression is a string matching the {nodeName.attrName} pattern.
@ -322,12 +328,13 @@ class Attribute(BaseObject):
return linkParam
@property
def hasOutputConnections(self):
""" Whether the attribute has output connections, i.e is the source of at least one edge. """
def hasOutputConnections(self) -> bool:
"""
Whether the attribute has output connections, i.e is the source of at least one edge.
"""
# safety check to avoid evaluation errors
if not self.node.graph or not self.node.graph.edges:
return False
return next((edge for edge in self.node.graph.edges.values() if edge.src == self), None) is not None
def getInputConnections(self) -> list["Edge"]:
@ -391,28 +398,29 @@ class Attribute(BaseObject):
return self.value
def getEvalValue(self):
'''
"""
Return the value. If it is a string, expressions will be evaluated.
'''
"""
if isinstance(self.value, str):
substituted = Template(self.value).safe_substitute(os.environ)
try:
varResolved = substituted.format(**self.node._cmdVars)
return varResolved
except (KeyError, IndexError):
# Catch KeyErrors and IndexErros to be able to open files created prior to the support of
# relative variables (when self.node._cmdVars was not used to evaluate expressions in the attribute)
# Catch KeyErrors and IndexErros to be able to open files created prior to the
# support of relative variables (when self.node._cmdVars was not used to evaluate
# expressions in the attribute)
return substituted
return self.value
def getValueStr(self, withQuotes=True):
'''
def getValueStr(self, withQuotes=True) -> str:
"""
Return the value formatted as a string with quotes to deal with spaces.
If it is a string, expressions will be evaluated.
If it is an empty string, it will returns 2 quotes.
If it is an empty list, it will returns a really empty string.
If it is a list with one empty string element, it will returns 2 quotes.
'''
"""
# ChoiceParam with multiple values should be combined
if isinstance(self.attributeDesc, desc.ChoiceParam) and not self.attributeDesc.exclusive:
# Ensure value is a list as expected
@ -421,8 +429,10 @@ class Attribute(BaseObject):
if withQuotes and v:
return f'"{v}"'
return v
# String, File, single value Choice are based on strings and should includes quotes to deal with spaces
if withQuotes and isinstance(self.attributeDesc, (desc.StringParam, desc.File, desc.ChoiceParam)):
# String, File, single value Choice are based on strings and should includes quotes
# to deal with spaces
if withQuotes and \
isinstance(self.attributeDesc, (desc.StringParam, desc.File, desc.ChoiceParam)):
return f'"{self.getEvalValue()}"'
return str(self.getEvalValue())
@ -436,10 +446,11 @@ class Attribute(BaseObject):
logging.warning("Failed to evaluate default value (node lambda) for attribute '{}': {}".
format(self.name, e))
return None
# Need to force a copy, for the case where the value is a list (avoid reference to the desc value)
# Need to force a copy, for the case where the value is a list
# (avoid reference to the desc value)
return copy.copy(self.desc.value)
def _isDefault(self):
def _isDefault(self) -> bool:
return self.value == self.defaultValue()
def getPrimitiveValue(self, exportDefault=True):
@ -510,7 +521,8 @@ class Attribute(BaseObject):
isDefault = Property(bool, _isDefault, notify=valueChanged)
linkParam = Property(BaseObject, getLinkParam, notify=isLinkChanged)
rootLinkParam = Property(BaseObject, lambda self: self.getLinkParam(recursive=True), notify=isLinkChanged)
rootLinkParam = Property(BaseObject, lambda self: self.getLinkParam(recursive=True),
notify=isLinkChanged)
node = Property(BaseObject, node.fget, constant=True)
enabledChanged = Signal()
enabled = Property(bool, getEnabled, setEnabled, notify=enabledChanged)
@ -522,7 +534,7 @@ class Attribute(BaseObject):
def raiseIfLink(func):
""" If Attribute instance is a link, raise a RuntimeError."""
""" If Attribute instance is a link, raise a RuntimeError. """
def wrapper(attr, *args, **kwargs):
if attr.isLink:
raise RuntimeError("Can't modify connected Attribute")
@ -531,7 +543,8 @@ def raiseIfLink(func):
class PushButtonParam(Attribute):
def __init__(self, node, attributeDesc, isOutput, root=None, parent=None):
def __init__(self, node, attributeDesc: desc.PushButtonParam, isOutput: bool,
root=None, parent=None):
super().__init__(node, attributeDesc, isOutput, root, parent)
@Slot()
@ -541,7 +554,8 @@ class PushButtonParam(Attribute):
class ChoiceParam(Attribute):
def __init__(self, node, attributeDesc: desc.ChoiceParam, isOutput, root=None, parent=None):
def __init__(self, node, attributeDesc: desc.ChoiceParam, isOutput: bool,
root=None, parent=None):
super().__init__(node, attributeDesc, isOutput, root, parent)
self._values = None
@ -568,7 +582,7 @@ class ChoiceParam(Attribute):
raise ValueError("Non exclusive ChoiceParam value should be iterable (param:{}, value:{}, type:{})".
format(self.name, value, type(value)))
return [self.conformValue(v) for v in value]
def _set_value(self, value):
# Handle alternative serialization for ChoiceParam with overriden values.
serializedValueWithValuesOverrides = isinstance(value, dict)
@ -585,7 +599,8 @@ class ChoiceParam(Attribute):
self.valuesChanged.emit()
def getExportValue(self):
useStandardSerialization = self.isLink or not self.desc._saveValuesOverride or self._values is None
useStandardSerialization = self.isLink or not self.desc._saveValuesOverride or \
self._values is None
if useStandardSerialization:
return super().getExportValue()
@ -602,7 +617,8 @@ class ChoiceParam(Attribute):
class ListAttribute(Attribute):
def __init__(self, node, attributeDesc, isOutput, root=None, parent=None):
def __init__(self, node, attributeDesc: desc.ListAttribute, isOutput: bool,
root=None, parent=None):
super().__init__(node, attributeDesc, isOutput, root, parent)
def __len__(self):
@ -617,7 +633,7 @@ class ListAttribute(Attribute):
return self.attributeDesc.elementDesc.__class__.__name__
def at(self, idx):
""" Returns child attribute at index 'idx' """
""" Returns child attribute at index 'idx'. """
# Implement 'at' rather than '__getitem__'
# since the later is called spuriously when object is used in QML
return self._value.at(idx)
@ -649,7 +665,8 @@ class ListAttribute(Attribute):
def upgradeValue(self, exportedValues):
if not isinstance(exportedValues, list):
if isinstance(exportedValues, ListAttribute) or Attribute.isLinkExpression(exportedValues):
if isinstance(exportedValues, ListAttribute) or \
Attribute.isLinkExpression(exportedValues):
self._set_value(exportedValues)
return
raise RuntimeError("ListAttribute.upgradeValue: the given value is of type " +
@ -657,7 +674,8 @@ class ListAttribute(Attribute):
attrs = []
for v in exportedValues:
a = attributeFactory(self.attributeDesc.elementDesc, None, self.isOutput, self.node, self)
a = attributeFactory(self.attributeDesc.elementDesc, None, self.isOutput,
self.node, self)
a.upgradeValue(v)
attrs.append(a)
index = len(self._value)
@ -675,7 +693,8 @@ class ListAttribute(Attribute):
if self._value is None:
self._value = ListModel(parent=self)
values = value if isinstance(value, list) else [value]
attrs = [attributeFactory(self.attributeDesc.elementDesc, v, self.isOutput, self.node, self) for v in values]
attrs = [attributeFactory(self.attributeDesc.elementDesc, v, self.isOutput, self.node, self)
for v in values]
self._value.insert(index, attrs)
self.valueChanged.emit()
self._applyExpr()
@ -725,27 +744,28 @@ class ListAttribute(Attribute):
return self.getLinkParam().asLinkExpr()
return [attr.getExportValue() for attr in self._value]
def defaultValue(self):
def defaultValue(self) -> list:
return []
def _isDefault(self):
def _isDefault(self) -> bool:
return len(self._value) == 0
def getPrimitiveValue(self, exportDefault=True):
if exportDefault:
return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value]
else:
return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value if not attr.isDefault]
return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value
if not attr.isDefault]
def getValueStr(self, withQuotes=True):
def getValueStr(self, withQuotes=True) -> str:
assert isinstance(self.value, ListModel)
if self.attributeDesc.joinChar == ' ':
return self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=withQuotes) for v in self.value])
else:
v = self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=False) for v in self.value])
if withQuotes and v:
return f'"{v}"'
return v
return self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=withQuotes)
for v in self.value])
v = self.attributeDesc.joinChar.join([v.getValueStr(withQuotes=False)
for v in self.value])
if withQuotes and v:
return f'"{v}"'
return v
def updateInternals(self):
super().updateInternals()
@ -753,9 +773,10 @@ class ListAttribute(Attribute):
attr.updateInternals()
@property
def isLinkNested(self):
def isLinkNested(self) -> bool:
""" Whether the attribute or any of its elements is a link to another attribute. """
# note: directly use self.node.graph._edges to avoid using the property that may become invalid at some point
# note: directly use self.node.graph._edges to avoid using the property that may become
# invalid at some point
return self.isLink \
or self.node.graph and self.isInput and self.node.graph._edges \
and any(v in self.node.graph._edges.keys() for v in self._value)
@ -799,7 +820,8 @@ class ListAttribute(Attribute):
class GroupAttribute(Attribute):
def __init__(self, node, attributeDesc, isOutput, root=None, parent=None):
def __init__(self, node, attributeDesc: desc.GroupAttribute, isOutput: bool,
root=None, parent=None):
super().__init__(node, attributeDesc, isOutput, root, parent)
def __getattr__(self, key):
@ -854,12 +876,12 @@ class GroupAttribute(Attribute):
self._value.get(attrDesc.name).resetToDefaultValue()
@Slot(str, result=Attribute)
def childAttribute(self, key):
def childAttribute(self, key: str) -> Attribute:
"""
Get child attribute by name or None if none was found.
Args:
key (str): the name of the child attribute
key: the name of the child attribute
Returns:
Attribute: the child attribute or None
@ -892,9 +914,8 @@ class GroupAttribute(Attribute):
def getPrimitiveValue(self, exportDefault=True):
if exportDefault:
return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items()}
else:
return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items()
if not attr.isDefault}
return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items()
if not attr.isDefault}
def getValueStr(self, withQuotes=True):
# add brackets if requested
@ -925,7 +946,7 @@ class GroupAttribute(Attribute):
attr.updateInternals()
@Slot(str, result=bool)
def matchText(self, text):
def matchText(self, text: str) -> bool:
return super().matchText(text) or any(c.matchText(text) for c in self._value)
# Override value property