mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-04-28 09:47:20 +02:00
[core] Refactor nodeFactory
function
Rewrite `nodeFactory` to reduce cognitive complexity, while preserving the current behavior.
This commit is contained in:
parent
75db9dc16c
commit
c883c53397
3 changed files with 175 additions and 93 deletions
|
@ -339,7 +339,7 @@ class Graph(BaseObject):
|
|||
if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish":
|
||||
continue
|
||||
|
||||
n = nodeFactory(nodeData, nodeName, template=isTemplate)
|
||||
n = nodeFactory(nodeData, nodeName, inTemplate=isTemplate)
|
||||
|
||||
# Add node to the graph with raw attributes values
|
||||
self._addNode(n, nodeName)
|
||||
|
@ -392,14 +392,14 @@ class Graph(BaseObject):
|
|||
# Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode
|
||||
logging.debug("UID conflict detected for {}".format(nodeName))
|
||||
self.removeNode(nodeName)
|
||||
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=True)
|
||||
n = nodeFactory(nodeData, nodeName, expectedUid=graphUid)
|
||||
self._addNode(n, nodeName)
|
||||
else:
|
||||
# f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost:
|
||||
# the links will be erroneously updated, and any further resolution will fail.
|
||||
# Recreating the entire graph as it was ensures that all edges will be correctly preserved.
|
||||
self.removeNode(nodeName)
|
||||
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=False)
|
||||
n = nodeFactory(nodeData, nodeName)
|
||||
self._addNode(n, nodeName)
|
||||
|
||||
def updateImportedProject(self, data):
|
||||
|
|
|
@ -1,116 +1,197 @@
|
|||
import logging
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import meshroom.core
|
||||
from meshroom.core import Version, desc
|
||||
from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position
|
||||
|
||||
|
||||
def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
|
||||
def nodeFactory(
|
||||
nodeData: dict,
|
||||
name: Optional[str] = None,
|
||||
inTemplate: bool = False,
|
||||
expectedUid: Optional[str] = None,
|
||||
) -> Union[Node, CompatibilityNode]:
|
||||
"""
|
||||
Create a node instance by deserializing the given node data.
|
||||
If the serialized data matches the corresponding node type description, a Node instance is created.
|
||||
If any compatibility issue occurs, a NodeCompatibility instance is created instead.
|
||||
|
||||
Args:
|
||||
nodeDict (dict): the serialization of the node
|
||||
name (str): (optional) the node's name
|
||||
template (bool): (optional) true if the node is part of a template, false otherwise
|
||||
uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node
|
||||
nodeDict: The serialized Node data.
|
||||
name: (optional) The node's name.
|
||||
inTemplate: (optional) True if the node is created as part of a graph template.
|
||||
expectedUid: (optional) The expected UID of the node within the context of a Graph.
|
||||
|
||||
Returns:
|
||||
BaseNode: the created node
|
||||
The created Node instance.
|
||||
"""
|
||||
nodeType = nodeDict["nodeType"]
|
||||
return _NodeCreator(nodeData, name, inTemplate, expectedUid).create()
|
||||
|
||||
# Retro-compatibility: inputs were previously saved as "attributes"
|
||||
if "inputs" not in nodeDict and "attributes" in nodeDict:
|
||||
nodeDict["inputs"] = nodeDict["attributes"]
|
||||
del nodeDict["attributes"]
|
||||
|
||||
# Get node inputs/outputs
|
||||
inputs = nodeDict.get("inputs", {})
|
||||
internalInputs = nodeDict.get("internalInputs", {})
|
||||
outputs = nodeDict.get("outputs", {})
|
||||
version = nodeDict.get("version", None)
|
||||
internalFolder = nodeDict.get("internalFolder", None)
|
||||
position = Position(*nodeDict.get("position", []))
|
||||
uid = nodeDict.get("uid", None)
|
||||
class _NodeCreator:
|
||||
|
||||
compatibilityIssue = None
|
||||
def __init__(
|
||||
self,
|
||||
nodeData: dict,
|
||||
name: Optional[str] = None,
|
||||
inTemplate: bool = False,
|
||||
expectedUid: Optional[str] = None,
|
||||
):
|
||||
self.nodeData = nodeData
|
||||
self.name = name
|
||||
self.inTemplate = inTemplate
|
||||
self.expectedUid = expectedUid
|
||||
|
||||
nodeDesc = None
|
||||
try:
|
||||
nodeDesc = meshroom.core.nodesDesc[nodeType]
|
||||
except KeyError:
|
||||
# Unknown node type
|
||||
compatibilityIssue = CompatibilityIssue.UnknownNodeType
|
||||
self._normalizeNodeData()
|
||||
|
||||
# Unknown node type should take precedence over UID conflict, as it cannot be resolved
|
||||
if uidConflict and nodeDesc:
|
||||
compatibilityIssue = CompatibilityIssue.UidConflict
|
||||
self.nodeType = self.nodeData["nodeType"]
|
||||
self.inputs = self.nodeData.get("inputs", {})
|
||||
self.internalInputs = self.nodeData.get("internalInputs", {})
|
||||
self.outputs = self.nodeData.get("outputs", {})
|
||||
self.version = self.nodeData.get("version", None)
|
||||
self.internalFolder = self.nodeData.get("internalFolder")
|
||||
self.position = Position(*self.nodeData.get("position", []))
|
||||
self.uid = self.nodeData.get("uid", None)
|
||||
self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None)
|
||||
|
||||
if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue
|
||||
# Compare serialized node version with current node version
|
||||
currentNodeVersion = meshroom.core.nodeVersion(nodeDesc)
|
||||
# If both versions are available, check for incompatibility in major version
|
||||
if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major:
|
||||
compatibilityIssue = CompatibilityIssue.VersionConflict
|
||||
# In other cases, check attributes compatibility between serialized node and its description
|
||||
def create(self) -> Union[Node, CompatibilityNode]:
|
||||
compatibilityIssue = self._checkCompatibilityIssues()
|
||||
if compatibilityIssue:
|
||||
node = self._createCompatibilityNode(compatibilityIssue)
|
||||
node = self._tryUpgradeCompatibilityNode(node)
|
||||
else:
|
||||
# Check that the node has the exact same set of inputs/outputs as its description, except
|
||||
# if the node is described in a template file, in which only non-default parameters are saved;
|
||||
# do not perform that check for internal attributes because there is no point in
|
||||
# raising compatibility issues if their number differs: in that case, it is only useful
|
||||
# if some internal attributes do not exist or are invalid
|
||||
if not template and (sorted([attr.name for attr in nodeDesc.inputs
|
||||
if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or
|
||||
sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) !=
|
||||
sorted(outputs.keys())):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
node = self._createNode()
|
||||
return node
|
||||
|
||||
# Check whether there are any internal attributes that are invalidating in the node description: if there
|
||||
# are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility
|
||||
# issue must be raised to warn the user, as this will automatically change the node's UID
|
||||
if not template:
|
||||
invalidatingIntInputs = []
|
||||
for attr in nodeDesc.internalInputs:
|
||||
if attr.invalidate:
|
||||
invalidatingIntInputs.append(attr.name)
|
||||
for attr in invalidatingIntInputs:
|
||||
if attr not in internalInputs.keys():
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
def _normalizeNodeData(self):
|
||||
"""Consistency fixes for backward compatibility with older serialized data."""
|
||||
# Inputs were previously saved as "attributes".
|
||||
if "inputs" not in self.nodeData and "attributes" in self.nodeData:
|
||||
self.nodeData["inputs"] = self.nodeData["attributes"]
|
||||
del self.nodeData["attributes"]
|
||||
|
||||
# Verify that all inputs match their descriptions
|
||||
for attrName, value in inputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
# Verify that all internal inputs match their description
|
||||
for attrName, value in internalInputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
# Verify that all outputs match their descriptions
|
||||
for attrName, value in outputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]:
|
||||
if self.nodeDesc is None:
|
||||
return CompatibilityIssue.UnknownNodeType
|
||||
|
||||
if compatibilityIssue is None:
|
||||
node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs)
|
||||
else:
|
||||
logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name))
|
||||
node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue)
|
||||
# Retro-compatibility: no internal folder saved
|
||||
# can't spawn meaningful CompatibilityNode with precomputed outputs
|
||||
# => automatically try to perform node upgrade
|
||||
if not internalFolder and nodeDesc:
|
||||
logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name))
|
||||
node = node.upgrade()
|
||||
# If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is
|
||||
# an "unknown node type" conflict (in which case the upgrade would fail)
|
||||
elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType:
|
||||
node = node.upgrade()
|
||||
if not self._checkUidCompatibility():
|
||||
return CompatibilityIssue.UidConflict
|
||||
|
||||
return node
|
||||
if not self._checkVersionCompatibility():
|
||||
return CompatibilityIssue.VersionConflict
|
||||
|
||||
if not self._checkDescriptionCompatibility():
|
||||
return CompatibilityIssue.DescriptionConflict
|
||||
|
||||
return None
|
||||
|
||||
def _checkUidCompatibility(self) -> bool:
|
||||
return self.expectedUid is None or self.expectedUid == self.uid
|
||||
|
||||
def _checkVersionCompatibility(self) -> bool:
|
||||
# Special case: a node with a version set to None indicates
|
||||
# that it has been created from the current version of the node type.
|
||||
nodeCreatedFromCurrentVersion = self.version is None
|
||||
if nodeCreatedFromCurrentVersion:
|
||||
return True
|
||||
nodeTypeCurrentVersion = meshroom.core.nodeVersion(self.nodeDesc, "0.0")
|
||||
return Version(self.version).major == Version(nodeTypeCurrentVersion).major
|
||||
|
||||
def _checkDescriptionCompatibility(self) -> bool:
|
||||
# Only perform strict attribute name matching for non-template graphs,
|
||||
# since only non-default-value input attributes are serialized in templates.
|
||||
if not self.inTemplate:
|
||||
if not self._checkAttributesNamesMatchDescription():
|
||||
return False
|
||||
|
||||
return self._checkAttributesAreCompatibleWithDescription()
|
||||
|
||||
def _checkAttributesNamesMatchDescription(self) -> bool:
|
||||
return (
|
||||
self._checkInputAttributesNames()
|
||||
and self._checkOutputAttributesNames()
|
||||
and self._checkInternalAttributesNames()
|
||||
)
|
||||
|
||||
def _checkAttributesAreCompatibleWithDescription(self) -> bool:
|
||||
return (
|
||||
self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs)
|
||||
and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs)
|
||||
and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs)
|
||||
)
|
||||
|
||||
def _checkInputAttributesNames(self) -> bool:
|
||||
def serializedInput(attr: desc.Attribute) -> bool:
|
||||
"""Filter that excludes not-serialized desc input attributes."""
|
||||
if isinstance(attr, desc.PushButtonParam):
|
||||
# PushButtonParam are not serialized has they do not hold a value.
|
||||
return False
|
||||
return True
|
||||
|
||||
refAttributes = filter(serializedInput, self.nodeDesc.inputs)
|
||||
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.inputs)
|
||||
|
||||
def _checkOutputAttributesNames(self) -> bool:
|
||||
def serializedOutput(attr: desc.Attribute) -> bool:
|
||||
"""Filter that excludes not-serialized desc output attributes."""
|
||||
if attr.isDynamicValue:
|
||||
# Dynamic outputs values are not serialized with the node,
|
||||
# as their value is written in the computed output data.
|
||||
return False
|
||||
return True
|
||||
|
||||
refAttributes = filter(serializedOutput, self.nodeDesc.outputs)
|
||||
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.outputs)
|
||||
|
||||
def _checkInternalAttributesNames(self) -> bool:
|
||||
invalidatingDescAttributes = [attr.name for attr in self.nodeDesc.internalInputs if attr.invalidate]
|
||||
return all(attr in self.internalInputs.keys() for attr in invalidatingDescAttributes)
|
||||
|
||||
def _checkAttributesNamesStrictlyMatch(
|
||||
self, descAttributes: Iterable[desc.Attribute], attributesDict: dict[str, Any]
|
||||
) -> bool:
|
||||
refNames = sorted([attr.name for attr in descAttributes])
|
||||
attrNames = sorted(attributesDict.keys())
|
||||
return refNames == attrNames
|
||||
|
||||
def _checkAttributesCompatibility(
|
||||
self, descAttributes: list[desc.Attribute], attributesDict: dict[str, Any]
|
||||
) -> bool:
|
||||
return all(
|
||||
CompatibilityNode.attributeDescFromName(descAttributes, attrName, value) is not None
|
||||
for attrName, value in attributesDict.items()
|
||||
)
|
||||
|
||||
def _createNode(self) -> Node:
|
||||
logging.info(f"Creating node '{self.name}'")
|
||||
return Node(
|
||||
self.nodeType,
|
||||
position=self.position,
|
||||
uid=self.uid,
|
||||
**self.inputs,
|
||||
**self.internalInputs,
|
||||
**self.outputs,
|
||||
)
|
||||
|
||||
def _createCompatibilityNode(self, compatibilityIssue) -> CompatibilityNode:
|
||||
logging.warning(f"Compatibility issue detected for node '{self.name}': {compatibilityIssue.name}")
|
||||
return CompatibilityNode(
|
||||
self.nodeType, self.nodeData, position=self.position, issue=compatibilityIssue
|
||||
)
|
||||
|
||||
def _tryUpgradeCompatibilityNode(self, node: CompatibilityNode) -> Union[Node, CompatibilityNode]:
|
||||
"""Handle possible upgrades of CompatibilityNodes, when no computed data is associated to the Node."""
|
||||
if node.issue == CompatibilityIssue.UnknownNodeType:
|
||||
return node
|
||||
|
||||
# Nodes in templates are not meant to hold computation data.
|
||||
if self.inTemplate:
|
||||
logging.warning(f"Compatibility issue in template: performing automatic upgrade on '{self.name}'")
|
||||
return node.upgrade()
|
||||
|
||||
# Backward compatibility: "internalFolder" was not serialized.
|
||||
if not self.internalFolder:
|
||||
logging.warning(f"No serialized output data: performing automatic upgrade on '{self.name}'")
|
||||
|
||||
return node
|
||||
|
|
|
@ -432,11 +432,12 @@ class UpgradeNodeCommand(GraphCommand):
|
|||
|
||||
def undoImpl(self):
|
||||
# delete upgraded node
|
||||
expectedUid = self.graph.node(self.nodeName)._uid
|
||||
self.graph.removeNode(self.nodeName)
|
||||
# recreate compatibility node
|
||||
with GraphModification(self.graph):
|
||||
# We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before
|
||||
node = nodeFactory(self.nodeDict, name=self.nodeName, uidConflict=True)
|
||||
node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid)
|
||||
self.graph.addNode(node, self.nodeName)
|
||||
# recreate out edges
|
||||
for dstAttr, srcAttr in self.outEdges.items():
|
||||
|
|
Loading…
Add table
Reference in a new issue