Mirror of https://github.com/alicevision/Meshroom.git (synced 2025-05-28 16:36:32 +02:00)
[core] introduce CompatibilityNode for improved scene compatibilities
Improve node serialization/deserialization so that the exact same node can be recreated in the graph when loading a Meshroom project, even if the corresponding node's description has changed or no longer exists. This makes it possible to recover data already computed on disk without being impacted by changed uids. CompatibilityNode also provides an on-demand upgrade system to turn itself into a Node that meets the current node description (when possible).

* new abstract class BaseNode, base class for Node and CompatibilityNode
* Node: serialize everything needed to spawn a CompatibilityNode with precomputed outputs: inputs, uids, parallelization settings, unresolved internal folder and outputs
* node_factory: handles node deserialization and compatibility issues to create either a Node or a CompatibilityNode
* add compatibility unit tests
parent b6cbb0cc63
commit 33eb7f3a7f
6 changed files with 458 additions and 101 deletions
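For orientation, a minimal, self-contained sketch of the loading policy this commit introduces (simplified, made-up names and types, not the actual Meshroom API):

    # Simplified illustration: decide between a regular Node and a CompatibilityNode.
    KNOWN_INPUTS = {"SampleNodeV1": {"input", "paramA"}}  # node type -> input names of its current description

    def load_node(node_dict):
        node_type = node_dict["nodeType"]
        if node_type not in KNOWN_INPUTS:
            return "CompatibilityNode (UnknownNodeType)"      # keep serialized data as-is
        if set(node_dict.get("inputs", {})) != KNOWN_INPUTS[node_type]:
            return "CompatibilityNode (DescriptionConflict)"  # upgradable on demand
        return "Node"  # exact match with the current description

    print(load_node({"nodeType": "SampleNodeV1", "inputs": {"input": "/tmp/img", "paramA": "foo"}}))  # -> Node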
@@ -16,7 +16,15 @@ class UnknownNodeTypeError(GraphException):
    """
    Raised when asked to create a unknown node type.
    """
    def __init__(self, nodeType):
    def __init__(self, nodeType, msg=None):
        msg = "Unknown Node Type: " + nodeType
        super(UnknownNodeTypeError, self).__init__(msg)
        self.nodeType = nodeType


class NodeUpgradeError(GraphException):
    def __init__(self, nodeName, details=None):
        msg = "Failed to upgrade node {}".format(nodeName)
        if details:
            msg += ": {}".format(details)
        super(NodeUpgradeError, self).__init__(msg)
@@ -14,7 +14,8 @@ import meshroom
import meshroom.core
from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core.attribute import Attribute
from meshroom.core.node import node_factory, Status
from meshroom.core.exception import UnknownNodeTypeError
from meshroom.core.node import node_factory, Status, Node, CompatibilityNode

# Replace default encoder to support Enums

@@ -207,10 +208,13 @@ class Graph(BaseObject):
            if not isinstance(nodeData, dict):
                raise RuntimeError('loadGraph error: Node is not a dict. File: {}'.format(filepath))

            n = node_factory(nodeData['nodeType'],
                             # allow simple retro-compatibility, though cache might get invalidated
                             skipInvalidAttributes=True,
                             **nodeData['attributes'])
            # retrieve version from
            # 1. nodeData: node saved from a CompatibilityNode
            # 2. nodesVersion in file header: node saved from a Node
            # 3. fallback to no version "0.0": retro-compatibility
            if "version" not in nodeData:
                nodeData["version"] = nodesVersions.get(nodeData["nodeType"], "0.0")
            n = node_factory(nodeData, nodeName)

            # Add node to the graph with raw attributes values
            self._addNode(n, nodeName)
@@ -314,7 +318,7 @@
        if name and name in self._nodes.keys():
            name = self._createUniqueNodeName(name)

        n = self.addNode(node_factory(nodeType, False, **kwargs), uniqueName=name)
        n = self.addNode(Node(nodeType, **kwargs), uniqueName=name)
        n.updateInternals()
        return n

@@ -329,6 +333,29 @@
    def node(self, nodeName):
        return self._nodes.get(nodeName)

    def upgradeNode(self, nodeName):
        """
        Upgrade the CompatibilityNode identified as 'nodeName'
        Args:
            nodeName (str): the name of the CompatibilityNode to upgrade

        Returns:
            the list of deleted input/output edges
        """
        node = self.node(nodeName)
        if not isinstance(node, CompatibilityNode):
            raise ValueError("Upgrade is only available on CompatibilityNode instances.")
        upgradedNode = node.upgrade()
        inEdges, outEdges = self.removeNode(nodeName)
        self.addNode(upgradedNode, nodeName)
        for dst, src in outEdges.items():
            try:
                self.addEdge(self.attribute(src), self.attribute(dst))
            except (KeyError, ValueError) as e:
                logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e)))

        return inEdges, outEdges

    @Slot(str, result=Attribute)
    def attribute(self, fullName):
        # type: (str) -> Attribute
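For reference, the same upgrade path can be driven from the Python API in the way the unit tests further down exercise it; a minimal sketch (the project path is hypothetical):

    from meshroom.core.graph import loadGraph
    from meshroom.core.node import CompatibilityNode

    graph = loadGraph("/path/to/project.mg")  # hypothetical project file
    names = [n.name for n in graph.nodes if isinstance(n, CompatibilityNode) and n.canUpgrade]
    for name in names:
        inEdges, outEdges = graph.upgradeNode(name)  # replaces the node and restores its output edges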
@@ -670,7 +697,8 @@ class Graph(BaseObject):
        self.header[Graph.IO.ReleaseVersion] = meshroom.__version__
        self.header[Graph.IO.FileVersion] = Graph.IO.__version__

        usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes])
        # store versions of node types present in the graph (excluding CompatibilityNode instances)
        usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])

        self.header[Graph.IO.NodesVersions] = {
            "{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
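The header written here is what loadGraph consults (see the hunk above) when a serialized node carries no version of its own. Roughly, with illustrative version numbers (key spellings as commonly seen in .mg project files, shown only as an assumption for illustration):

    header = {
        "releaseVersion": "2019.1.0",            # meshroom.__version__ at save time
        "fileVersion": "1.0",                    # Graph.IO.__version__
        "nodesVersions": {"CameraInit": "1.0"},  # one entry per node type used in the graph
    }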
@@ -9,15 +9,15 @@ import re
import shutil
import time
import uuid
from abc import ABCMeta
from collections import defaultdict

from enum import Enum

import meshroom
from meshroom.common import Signal, Variant, Property, BaseObject, Slot, ListModel, DictModel
from meshroom.core import desc, stats, hashValue
from meshroom.core import desc, stats, hashValue, pyCompatibility
from meshroom.core.attribute import attribute_factory, ListAttribute, GroupAttribute, Attribute
from meshroom.core.exception import UnknownNodeTypeError
from meshroom.core.exception import NodeUpgradeError, UnknownNodeTypeError


class Status(Enum):
@@ -282,15 +282,17 @@ class NodeChunk(BaseObject):
    statisticsFile = Property(str, statisticsFile.fget, notify=nodeFolderChanged)


class Node(BaseObject):
class BaseNode(BaseObject):
    """
    Base Abstract class for Graph nodes.
    """
    __metaclass__ = ABCMeta

    # Regexp handling complex attribute names with recursive understanding of Lists and Groups
    # i.e: a.b, a[0], a[0].b.c[1]
    attributeRE = re.compile(r'\.?(?P<name>\w+)(?:\[(?P<index>\d+)\])?')

    def __init__(self, nodeDesc, parent=None, **kwargs):
    def __init__(self, nodeType, parent=None, **kwargs):
        """
        Create a new Node instance based on the given node description.
        Any other keyword argument will be used to initialize this node's attributes.
@@ -300,10 +302,16 @@ class Node(BaseObject):
            parent (BaseObject): this Node's parent
            **kwargs: attributes values
        """
        super(Node, self).__init__(parent)
        self.nodeDesc = nodeDesc
        self.packageName = self.nodeDesc.packageName
        self.packageVersion = self.nodeDesc.packageVersion
        super(BaseNode, self).__init__(parent)
        self._nodeType = nodeType
        self.nodeDesc = None

        # instantiate node description if nodeType is valid
        if nodeType in meshroom.core.nodesDesc:
            self.nodeDesc = meshroom.core.nodesDesc[nodeType]()

        self.packageName = self.packageVersion = ""
        self._internalFolder = ""

        self._name = None  # type: str
        self.graph = None  # type: Graph
@@ -314,26 +322,21 @@
        self._size = 0
        self._attributes = DictModel(keyAttrName='name', parent=self)
        self.attributesPerUid = defaultdict(set)
        self._initFromDesc()
        for k, v in kwargs.items():
            self.attribute(k).value = v
        self._updateChunks()

    def __getattr__(self, k):
        try:
            # Throws exception if not in prototype chain
            # return object.__getattribute__(self, k) # doesn't work in python2
            return object.__getattr__(self, k)
        except AttributeError:
        except AttributeError as e:
            try:
                return self.attribute(k)
            except KeyError:
                raise AttributeError(k)
                raise e

    def getName(self):
        return self._name


    @property
    def packageFullName(self):
        return '-'.join([self.packageName, self.packageVersion])
@@ -364,29 +367,13 @@
    def getAttributes(self):
        return self._attributes

    def _initFromDesc(self):
        # Init from class and instance members

        for attrDesc in self.nodeDesc.inputs:
            assert isinstance(attrDesc, meshroom.core.desc.Attribute)
            self._attributes.add(attribute_factory(attrDesc, None, False, self))

        for attrDesc in self.nodeDesc.outputs:
            assert isinstance(attrDesc, meshroom.core.desc.Attribute)
            self._attributes.add(attribute_factory(attrDesc, None, True, self))

        # List attributes per uid
        for attr in self._attributes:
            for uidIndex in attr.attributeDesc.uid:
                self.attributesPerUid[uidIndex].add(attr)

    def _applyExpr(self):
        for attr in self._attributes:
            attr._applyExpr()

    @property
    def nodeType(self):
        return self.nodeDesc.__class__.__name__
        return self._nodeType

    @property
    def depth(self):
@@ -397,13 +384,7 @@ class Node(BaseObject):
        return self.graph.getDepth(self, minimal=True)

    def toDict(self):
        attributes = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
        return {
            'nodeType': self.nodeType,
            'packageName': self.packageName,
            'packageVersion': self.packageVersion,
            'attributes': {k: v for k, v in attributes.items() if v is not None},  # filter empty values
        }
        pass

    def _computeUids(self):
        """ Compute node uids by combining associated attributes' uids. """
@@ -471,7 +452,7 @@
        """ Delete this Node internal folder.
            Status will be reset to Status.NONE
        """
        if os.path.exists(self.internalFolder):
        if self.internalFolder and os.path.exists(self.internalFolder):
            shutil.rmtree(self.internalFolder)
        self.updateStatusFromCache()

@@ -508,25 +489,7 @@
            chunk.updateStatisticsFromCache()

    def _updateChunks(self):
        """ Update Node's computation task splitting into NodeChunks based on its description """
        self.setSize(self.nodeDesc.size.computeSize(self))
        if self.isParallelized:
            try:
                ranges = self.nodeDesc.parallelization.getRanges(self)
                if len(ranges) != len(self._chunks):
                    self._chunks.setObjectList([NodeChunk(self, range) for range in ranges])
                else:
                    for chunk, range in zip(self._chunks, ranges):
                        chunk.range = range
            except RuntimeError:
                # TODO: set node internal status to error
                logging.warning("Invalid Parallelization on node {}".format(self._name))
                self._chunks.clear()
        else:
            if len(self._chunks) != 1:
                self._chunks.setObjectList([NodeChunk(self, desc.Range())])
            else:
                self._chunks[0].range = desc.Range()
        pass

    def updateInternals(self, cacheDir=None):
        """ Update Node's internal parameters and output attributes.
@@ -558,7 +521,7 @@ class Node(BaseObject):

    @property
    def internalFolder(self):
        return self.nodeDesc.internalFolder.format(**self._cmdVars)
        return self._internalFolder.format(**self._cmdVars)

    def updateStatusFromCache(self):
        """
@@ -623,42 +586,286 @@
    size = Property(int, getSize, notify=sizeChanged)


def node_factory(nodeType, skipInvalidAttributes=False, **attributes):
class Node(BaseNode):
    """
    Create a new Node of type NodeType and initialize its attributes with given kwargs.

    Args:
        nodeType (str): name of the node description class
        skipInvalidAttributes (bool): whether to skip attributes not defined in
            or incompatible with nodeType's description.
        attributes (): serialized nodes attributes

    Raises:
        UnknownNodeTypeError if nodeType is unknown
    A standard Graph node based on a node type.
    """
    try:
        nodeDesc = meshroom.core.nodesDesc[nodeType]()
    except KeyError:
        # unknown node type
    def __init__(self, nodeType, parent=None, **kwargs):
        super(Node, self).__init__(nodeType, parent, **kwargs)

        if not self.nodeDesc:
            raise UnknownNodeTypeError(nodeType)

    if skipInvalidAttributes:
        # compare given attributes with the ones from node desc
        descAttrNames = set([attr.name for attr in nodeDesc.inputs])
        attrNames = set([name for name in attributes.keys()])
        invalidAttributes = list(attrNames.difference(descAttrNames))
        commonAttributes = list(attrNames.intersection(descAttrNames))
        # compare value types for common attributes
        for attr in [attr for attr in nodeDesc.inputs if attr.name in commonAttributes]:
        self.packageName = self.nodeDesc.packageName
        self.packageVersion = self.nodeDesc.packageVersion
        self._internalFolder = self.nodeDesc.internalFolder

        for attrDesc in self.nodeDesc.inputs:
            self._attributes.add(attribute_factory(attrDesc, None, False, self))

        for attrDesc in self.nodeDesc.outputs:
            self._attributes.add(attribute_factory(attrDesc, None, True, self))

        # List attributes per uid
        for attr in self._attributes:
            for uidIndex in attr.attributeDesc.uid:
                self.attributesPerUid[uidIndex].add(attr)

        # initialize attribute values
        for k, v in kwargs.items():
            attr = self.attribute(k)
            if attr.isInput:
                self.attribute(k).value = v

    def toDict(self):
        inputs = {k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isInput}
        outputs = ({k: v.getExportValue() for k, v in self._attributes.objects.items() if v.isOutput})

        return {
            'nodeType': self.nodeType,
            'parallelization': {
                'blockSize': self.nodeDesc.parallelization.blockSize if self.isParallelized else 0,
                'size': self.size,
                'split': self.nbParallelizationBlocks
            },
            'uids': self._uids,
            'internalFolder': self._internalFolder,
            'attributes': {k: v for k, v in inputs.items() if v is not None},  # filter empty values
            'outputs': outputs,
        }

    def _updateChunks(self):
        """ Update Node's computation task splitting into NodeChunks based on its description """
        self.setSize(self.nodeDesc.size.computeSize(self))
        if self.isParallelized:
            try:
                attr.validateValue(attributes[attr.name])
            except:
                invalidAttributes.append(attr.name)
                ranges = self.nodeDesc.parallelization.getRanges(self)
                if len(ranges) != len(self._chunks):
                    self._chunks.setObjectList([NodeChunk(self, range) for range in ranges])
                else:
                    for chunk, range in zip(self._chunks, ranges):
                        chunk.range = range
            except RuntimeError:
                # TODO: set node internal status to error
                logging.warning("Invalid Parallelization on node {}".format(self._name))
                self._chunks.clear()
        else:
            if len(self._chunks) != 1:
                self._chunks.setObjectList([NodeChunk(self, desc.Range())])
            else:
                self._chunks[0].range = desc.Range()

    if invalidAttributes and skipInvalidAttributes:
        # filter out invalid attributes
        logging.info("Skipping invalid attributes initialization for {}: {}".format(nodeType, invalidAttributes))
        for attr in invalidAttributes:
            del attributes[attr]

    return Node(nodeDesc, **attributes)

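The dict built by Node.toDict() above is what gets stored per node in the project file and later fed back to node_factory; an illustrative example with made-up values:

    serialized = {
        "nodeType": "FeatureExtraction",                        # illustrative type name
        "parallelization": {"blockSize": 40, "size": 120, "split": 3},
        "uids": {0: "3f2a9c"},                                   # uid index -> hash (shortened here)
        "internalFolder": "{cache}/{nodeType}/{uid0}/",          # unresolved folder template
        "attributes": {"input": "/data/cameraInit.sfm"},         # input values, empty ones filtered out
        "outputs": {"output": "{cache}/{nodeType}/{uid0}/"},
    }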
class CompatibilityIssue(Enum):
    """
    Enum describing compatibility issues when deserializing a Node.
    """
    UnknownIssue = 0  # unknown issue fallback
    UnknownNodeType = 1  # the node type has no corresponding description class
    VersionConflict = 2  # mismatch between node's description version and serialized node data
    DescriptionConflict = 3  # mismatch between node's description attributes and serialized node data
    UidConflict = 4  # mismatch between computed uids and uids stored in serialized node data

class CompatibilityNode(BaseNode):
    """
    Fallback BaseNode subclass to instantiate Nodes having compatibility issues with current type description.
    CompatibilityNode creates an 'empty-shell' exposing the deserialized node as-is,
    with all its inputs and precomputed outputs.
    """
    def __init__(self, nodeType, nodeDict, issue=CompatibilityIssue.UnknownIssue, parent=None):
        super(CompatibilityNode, self).__init__(nodeType, parent)

        self.issue = issue
        self.nodeDict = nodeDict

        self.inputs = nodeDict.get("inputs", {})
        self.outputs = nodeDict.get("outputs", {})
        self._internalFolder = self.nodeDict.get("internalFolder", "")
        self._uids = self.nodeDict.get("uids", {})

        # restore parallelization settings
        self.parallelization = self.nodeDict.get("parallelization", {})
        self.splitCount = self.parallelization.get("split", 1)
        self.setSize(self.parallelization.get("size", 1))

        # inputs matching current type description
        self._commonInputs = []
        # create input attributes
        for attrName, value in self.inputs.items():
            matchDesc = self._addAttribute(attrName, value, False)
            # store attributes that could be used during node upgrade
            if matchDesc:
                self._commonInputs.append(attrName)
        # create outputs attributes
        for attrName, value in self.outputs.items():
            self._addAttribute(attrName, value, True)

        # create NodeChunks matching serialized parallelization settings
        self._chunks.setObjectList([
            NodeChunk(self, desc.Range(i, blockSize=self.parallelization.get("blockSize", 0)))
            for i in range(self.splitCount)
        ])

    @staticmethod
    def attributeDescFromValue(attrName, value, isOutput):
        """
        Generate an attribute description (desc.Attribute) that best matches 'value'.

        Args:
            attrName (str): the name of the attribute
            value: the value of the attribute
            isOutput (bool): whether the attribute is an output

        Returns:
            desc.Attribute: the generated attribute description
        """
        params = {
            "name": attrName, "label": attrName,
            "description": "Incompatible parameter",
            "value": value, "uid": (),
            "group": "incompatible"
        }
        if isinstance(value, bool):
            return desc.BoolParam(**params)
        if isinstance(value, int):
            return desc.IntParam(range=None, **params)
        elif isinstance(value, float):
            return desc.FloatParam(range=None, **params)
        elif isinstance(value, pyCompatibility.basestring):
            if isOutput or os.path.isabs(value) or Attribute.isLinkExpression(value):
                return desc.File(**params)
            else:
                return desc.StringParam(**params)
        # handle any other type of parameters (List/Group) as Strings
        return desc.StringParam(**params)

    @staticmethod
    def attributeDescFromName(refAttributes, name, value):
        """
        Try to find a matching attribute description in refAttributes for given attribute 'name' and 'value'.

        Args:
            refAttributes ([Attribute]): reference Attributes to look for a description
            name (str): attribute's name
            value: attribute's value

        Returns:
            desc.Attribute: an attribute description from refAttributes if a match is found, None otherwise.
        """
        # from original node description based on attribute's name
        attrDesc = next((d for d in refAttributes if d.name == name), None)
        if attrDesc:
            # ensure value is valid for this description
            try:
                attrDesc.validateValue(value)
            except ValueError:
                attrDesc = None
        return attrDesc

    def _addAttribute(self, name, val, isOutput):
        """
        Add a new attribute on this node.

        Args:
            name (str): the name of the attribute
            val: the attribute's value
            isOutput: whether the attribute is an output

        Returns:
            bool: whether the attribute exists in the node description
        """
        attrDesc = None
        if self.nodeDesc:
            refAttrs = self.nodeDesc.outputs if isOutput else self.nodeDesc.inputs
            attrDesc = CompatibilityNode.attributeDescFromName(refAttrs, name, val)
        matchDesc = attrDesc is not None
        if not matchDesc:
            attrDesc = CompatibilityNode.attributeDescFromValue(name, val, isOutput)
        attribute = attribute_factory(attrDesc, val, isOutput, self)
        self._attributes.add(attribute)
        return matchDesc

    def toDict(self):
        """
        Return the original serialized node that generated a compatibility issue.
        """
        return self.nodeDict

    @property
    def canUpgrade(self):
        """ Return whether the node can be upgraded.
            This is the case when the underlying node type has a corresponding description. """
        return self.nodeDesc is not None

    def upgrade(self):
        """
        Return a new Node instance based on original node type with common inputs initialized.
        """
        if not self.canUpgrade:
            raise NodeUpgradeError(self.name, "no matching node type")
        # TODO: use upgrade method of node description if available
        return Node(self.nodeType, **{key: value for key, value in self.inputs.items() if key in self._commonInputs})

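To illustrate the fallback mapping implemented by attributeDescFromValue above (attribute names here are arbitrary):

    from meshroom.core.node import CompatibilityNode

    CompatibilityNode.attributeDescFromValue("enabled", True, False)           # -> desc.BoolParam
    CompatibilityNode.attributeDescFromValue("iterations", 3, False)           # -> desc.IntParam
    CompatibilityNode.attributeDescFromValue("threshold", 0.5, False)          # -> desc.FloatParam
    CompatibilityNode.attributeDescFromValue("image", "/data/img.jpg", False)  # -> desc.File (absolute path)
    CompatibilityNode.attributeDescFromValue("mode", "normal", False)          # -> desc.StringParam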
def node_factory(nodeDict, name=None):
    """
    Create a node instance by deserializing the given node data.
    If the serialized data matches the corresponding node type description, a Node instance is created.
    If any compatibility issue occurs, a NodeCompatibility instance is created instead.

    Args:
        nodeDict (dict): the serialization of the node
        name (str): (optional) the node's name

    Returns:
        BaseNode: the created node
    """
    nodeType = nodeDict["nodeType"]
    # get node inputs/outputs
    if "inputs" not in nodeDict:
        # retro-compatibility: inputs were previously saved as "attributes"
        nodeDict["inputs"] = nodeDict.get("attributes", {})

    inputs = nodeDict.get("inputs", {})
    outputs = nodeDict.get("outputs", {})
    version = nodeDict.get("version", None)
    internalFolder = nodeDict.get("internalFolder", None)

    compatibilityIssue = None

    nodeDesc = None
    try:
        nodeDesc = meshroom.core.nodesDesc[nodeType]
    except KeyError:
        # unknown node type
        compatibilityIssue = CompatibilityIssue.UnknownNodeType

    if nodeDesc:
        # compare serialized node version with current node version
        currentNodeVersion = meshroom.core.nodeVersion(nodeDesc)
        # if both versions are available, check for incompatibility in major version
        if version and currentNodeVersion and version.split('.')[0] != currentNodeVersion.split('.')[0]:
            compatibilityIssue = CompatibilityIssue.VersionConflict
        # in other cases, check attributes compatibility between serialized node and its description
        else:
            descAttrNames = set([attr.name for attr in nodeDesc.inputs + nodeDesc.outputs])
            attrNames = set([name for name in list(inputs.keys()) + list(outputs.keys())])
            if attrNames != descAttrNames:
                compatibilityIssue = CompatibilityIssue.DescriptionConflict

    # no compatibility issues: instantiate a Node
    if compatibilityIssue is None:
        n = Node(nodeType, **inputs)
    # otherwise, instantiate a CompatibilityNode
    else:
        logging.warning("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name))
        n = CompatibilityNode(nodeType, nodeDict, compatibilityIssue)
        # retro-compatibility: no internal folder saved
        # can't spawn meaningful CompatibilityNode with precomputed outputs
        # => automatically try to perform node upgrade
        if not internalFolder and nodeDesc:
            logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name))
            n = n.upgrade()

    return n

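As a usage note, feeding node_factory an old-style entry (inputs stored under "attributes", no serialized outputs) exercises the retro-compatibility branch above; a sketch with hypothetical values:

    from meshroom.core.node import node_factory

    legacy = {"nodeType": "CameraInit", "attributes": {}}
    node = node_factory(legacy, "CameraInit_1")
    # If 'CameraInit' is registered, the attribute mismatch yields a CompatibilityNode, but since
    # no internalFolder was serialized the factory immediately upgrades it back to a regular Node.
    # If the type is unknown, a CompatibilityNode exposing the raw dict is returned instead.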
@@ -7,6 +7,7 @@ from PySide2.QtCore import Property, Signal

from meshroom.core.attribute import ListAttribute, Attribute
from meshroom.core.graph import GraphModification
from meshroom.core.node import node_factory


class UndoCommand(QUndoCommand):
@@ -125,8 +126,8 @@ class RemoveNodeCommand(GraphCommand):

    def undoImpl(self):
        with GraphModification(self.graph):
            node = self.graph.addNewNode(nodeType=self.nodeDict["nodeType"],
                                         name=self.nodeName, **self.nodeDict["attributes"])
            node = node_factory(self.nodeDict, self.nodeName)
            self.graph.addNode(node, self.nodeName)
            assert (node.getName() == self.nodeName)
            # recreate out edges deleted on node removal
            for dstAttr, srcAttr in self.outEdges.items():
@@ -356,8 +356,7 @@ class Reconstruction(UIGraph):
        # If cameraInit is None (i.e: SfM augmentation):
        # * create an uninitialized node
        # * wait for the result before actually creating new nodes in the graph (see onIntrinsicsAvailable)
        attributes = cameraInit.toDict()["attributes"] if cameraInit else {}
        cameraInitCopy = node_factory("CameraInit", **attributes)
        cameraInitCopy = node_factory(cameraInit.toDict())

        try:
            self.setBuildingIntrinsics(True)
tests/test_compatibility.py (new file, 114 lines)
@@ -0,0 +1,114 @@
#!/usr/bin/env python
# coding:utf-8
import tempfile

import os
import pytest

import meshroom.core
from meshroom.core import desc, registerNodeType, unregisterNodeType
from meshroom.core.exception import NodeUpgradeError
from meshroom.core.graph import Graph, loadGraph
from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node


class SampleNodeV1(desc.Node):
    """ Version 1 Sample Node """
    inputs = [
        desc.File(name='input', label='Input', description='', value='', uid=[0],),
        desc.StringParam(name='paramA', label='ParamA', description='', value='', uid=[])  # No impact on UID
    ]
    outputs = [
        desc.File(name='output', label='Output', description='', value=desc.Node.internalFolder, uid=[])
    ]


class SampleNodeV2(desc.Node):
    """ Changes from V1:
        * 'input' has been renamed to 'in'
    """
    inputs = [
        desc.File(name='in', label='Input', description='', value='', uid=[0],),
        desc.StringParam(name='paramA', label='ParamA', description='', value='', uid=[])  # No impact on UID
    ]
    outputs = [
        desc.File(name='output', label='Output', description='', value=desc.Node.internalFolder, uid=[])
    ]

def test_unknown_node_type():
    """
    Test compatibility behavior for unknown node type.
    """
    registerNodeType(SampleNodeV1)
    g = Graph('')
    n = g.addNewNode("SampleNodeV1", input="/dev/null", paramA="foo")
    graphFile = os.path.join(tempfile.mkdtemp(), "test_unknown_node_type.mg")
    g.save(graphFile)
    internalFolder = n.internalFolder
    nodeName = n.name
    unregisterNodeType(SampleNodeV1)

    # reload file
    g = loadGraph(graphFile)
    os.remove(graphFile)

    assert len(g.nodes) == 1
    n = g.node(nodeName)
    # SampleNodeV1 is now an unknown type
    # check node instance type and compatibility issue type
    assert isinstance(n, CompatibilityNode)
    assert n.issue == CompatibilityIssue.UnknownNodeType
    # check if attributes are properly restored
    assert len(n.attributes) == 3
    assert n.input.isInput
    assert n.output.isOutput
    # check if internal folder
    assert n.internalFolder == internalFolder

    # upgrade can't be perform on unknown node types
    assert not n.canUpgrade
    with pytest.raises(NodeUpgradeError):
        g.upgradeNode(nodeName)

def test_description_conflict():
    """
    Test compatibility behavior for conflicting node descriptions.
    """
    registerNodeType(SampleNodeV1)

    g = Graph('')
    n = g.addNewNode("SampleNodeV1")
    graphFile = os.path.join(tempfile.mkdtemp(), "test_description_conflict.mg")
    g.save(graphFile)
    internalFolder = n.internalFolder
    nodeName = n.name

    # replace SampleNodeV1 by SampleNodeV2
    # 'SampleNodeV1' is still registered but implementation has changed
    meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2

    # reload file
    g = loadGraph(graphFile)
    os.remove(graphFile)

    assert len(g.nodes) == 1
    n = g.node(nodeName)
    # Node description clashes between what has been saved
    assert isinstance(n, CompatibilityNode)
    assert n.issue == CompatibilityIssue.DescriptionConflict
    assert len(n.attributes) == 3
    assert hasattr(n, "input")
    assert not hasattr(n, "in")
    assert n.internalFolder == internalFolder

    # perform upgrade
    g.upgradeNode(nodeName)
    n = g.node(nodeName)

    assert isinstance(n, Node)
    assert not hasattr(n, "input")
    assert hasattr(n, "in")
    # check uid has changed (not the same set of attributes)
    assert n.internalFolder != internalFolder