Merge pull request #2612 from alicevision/dev/graphIO

Refactor Graph de/serialization
This commit is contained in:
Candice Bentéjac 2025-02-12 11:49:19 +01:00 committed by GitHub
commit 91d2530401
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 1538 additions and 768 deletions

View file

@ -154,10 +154,10 @@ with meshroom.core.graph.GraphModification(graph):
# initialize template pipeline
loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items())
if args.pipeline.lower() in loweredPipelineTemplates:
graph.load(loweredPipelineTemplates[args.pipeline.lower()], setupProjectFile=False, publishOutputs=True if args.output else False)
graph.initFromTemplate(loweredPipelineTemplates[args.pipeline.lower()], publishOutputs=True if args.output else False)
else:
# custom pipeline
graph.load(args.pipeline, setupProjectFile=False, publishOutputs=True if args.output else False)
graph.initFromTemplate(args.pipeline, publishOutputs=True if args.output else False)
def parseInputs(inputs, uniqueInitNode):
"""Utility method for parsing the input and inputRecursive arguments."""

View file

@ -339,9 +339,12 @@ class Attribute(BaseObject):
elif self.isInput and Attribute.isLinkExpression(v):
# value is a link to another attribute
link = v[1:-1]
linkNode, linkAttr = link.split('.')
linkNodeName, linkAttrName = link.split('.')
try:
g.addEdge(g.node(linkNode).attribute(linkAttr), self)
node = g.node(linkNodeName)
if not node:
raise KeyError(f"Node '{linkNodeName}' not found")
g.addEdge(node.attribute(linkAttrName), self)
except KeyError as err:
logging.warning('Connect Attribute from Expression failed.')
logging.warning('Expression: "{exp}"\nError: "{err}".'.format(exp=v, err=err))

View file

@ -4,6 +4,7 @@ import json
import logging
import os
import re
from typing import Any, Optional
import weakref
from collections import defaultdict, OrderedDict
from contextlib import contextmanager
@ -16,7 +17,10 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer
from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
# Replace default encoder to support Enums
@ -148,6 +152,21 @@ def changeTopology(func):
return decorator
def blockNodeCallbacks(func):
"""
Graph methods loading serialized graph content must be decorated with 'blockNodeCallbacks',
to avoid attribute changed callbacks defined on node descriptions to be triggered during
this process.
"""
def inner(self, *args, **kwargs):
self._loading = True
try:
return func(self, *args, **kwargs)
finally:
self._loading = False
return inner
class Graph(BaseObject):
"""
_________________ _________________ _________________
@ -165,52 +184,6 @@ class Graph(BaseObject):
"""
_cacheDir = ""
class IO(object):
""" Centralize Graph file keys and IO version. """
__version__ = "2.0"
class Keys(object):
""" File Keys. """
# Doesn't inherit enum to simplify usage (Graph.IO.Keys.XX, without .value)
Header = "header"
NodesVersions = "nodesVersions"
ReleaseVersion = "releaseVersion"
FileVersion = "fileVersion"
Graph = "graph"
class Features(Enum):
""" File Features. """
Graph = "graph"
Header = "header"
NodesVersions = "nodesVersions"
PrecomputedOutputs = "precomputedOutputs"
NodesPositions = "nodesPositions"
@staticmethod
def getFeaturesForVersion(fileVersion):
""" Return the list of supported features based on a file version.
Args:
fileVersion (str, Version): the file version
Returns:
tuple of Graph.IO.Features: the list of supported features
"""
if isinstance(fileVersion, str):
fileVersion = Version(fileVersion)
features = [Graph.IO.Features.Graph]
if fileVersion >= Version("1.0"):
features += [Graph.IO.Features.Header,
Graph.IO.Features.NodesVersions,
Graph.IO.Features.PrecomputedOutputs,
]
if fileVersion >= Version("1.1"):
features += [Graph.IO.Features.NodesPositions]
return tuple(features)
def __init__(self, name, parent=None):
super(Graph, self).__init__(parent)
self.name = name
@ -225,7 +198,6 @@ class Graph(BaseObject):
self._nodes = DictModel(keyAttrName='name', parent=self)
# Edges: use dst attribute as unique key since it can only have one input connection
self._edges = DictModel(keyAttrName='dst', parent=self)
self._importedNodes = DictModel(keyAttrName='name', parent=self)
self._compatibilityNodes = DictModel(keyAttrName='name', parent=self)
self.cacheDir = meshroom.core.defaultCacheFolder
self._filepath = ''
@ -233,20 +205,22 @@ class Graph(BaseObject):
self.header = {}
def clear(self):
self._clearGraphContent()
self.header.clear()
self._compatibilityNodes.clear()
self._unsetFilepath()
def _clearGraphContent(self):
self._edges.clear()
# Tell QML nodes are going to be deleted
for node in self._nodes:
node.alive = False
self._importedNodes.clear()
self._nodes.clear()
self._unsetFilepath()
self._compatibilityNodes.clear()
@property
def fileFeatures(self):
""" Get loaded file supported features based on its version. """
return Graph.IO.getFeaturesForVersion(self.header.get(Graph.IO.Keys.FileVersion, "0.0"))
return GraphIO.getFeaturesForVersion(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
@property
def isLoading(self):
@ -259,37 +233,84 @@ class Graph(BaseObject):
return self._saving
@Slot(str)
def load(self, filepath, setupProjectFile=True, importProject=False, publishOutputs=False):
def load(self, filepath: PathLike):
"""
Load a Meshroom graph ".mg" file.
Load a Meshroom Graph ".mg" file in place.
Args:
filepath: project filepath to load
setupProjectFile: Store the reference to the project file and setup the cache directory.
If false, it only loads the graph of the project file as a template.
importProject: True if the project that is loaded will be imported in the current graph, instead
of opened.
publishOutputs: True if "Publish" nodes from templates should not be ignored.
filepath: The path to the Meshroom Graph file to load.
"""
self._loading = True
try:
return self._load(filepath, setupProjectFile, importProject, publishOutputs)
finally:
self._loading = False
self._deserialize(Graph._loadGraphData(filepath))
self._setFilepath(filepath)
self._fileDateVersion = os.path.getmtime(filepath)
def _load(self, filepath, setupProjectFile, importProject, publishOutputs):
if not importProject:
self.clear()
with open(filepath) as jsonFile:
fileData = json.load(jsonFile)
def initFromTemplate(self, filepath: PathLike, publishOutputs: bool = False):
"""
Deserialize a template Meshroom Graph ".mg" file in place.
self.header = fileData.get(Graph.IO.Keys.Header, {})
When initializing from a template, the internal filepath of the graph instance is not set.
Saving the file on disk will require to specify a filepath.
fileVersion = self.header.get(Graph.IO.Keys.FileVersion, "0.0")
# Retro-compatibility for all project files with the previous UID format
if Version(fileVersion) < Version("2.0"):
Args:
filepath: The path to the Meshroom Graph file to load.
publishOutputs: (optional) Whether to keep 'Publish' nodes.
"""
self._deserialize(Graph._loadGraphData(filepath))
if not publishOutputs:
with GraphModification(self):
for node in [node for node in self.nodes if node.nodeType == "Publish"]:
self.removeNode(node.name)
@staticmethod
def _loadGraphData(filepath: PathLike) -> dict:
"""Deserialize the content of the Meshroom Graph file at `filepath` to a dictionnary."""
with open(filepath) as file:
graphData = json.load(file)
return graphData
@blockNodeCallbacks
def _deserialize(self, graphData: dict):
"""Deserialize `graphData` in the current Graph instance.
Args:
graphData: The serialized Graph.
"""
self.clear()
self.header = graphData.get(GraphIO.Keys.Header, {})
fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
graphContent = self._normalizeGraphContent(graphData, fileVersion)
isTemplate = self.header.get(GraphIO.Keys.Template, False)
with GraphModification(self):
# iterate over nodes sorted by suffix index in their names
for nodeName, nodeData in sorted(
graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])
):
self._deserializeNode(nodeData, nodeName, self)
# Create graph edges by resolving attributes expressions
self._applyExpr()
# Templates are specific: they contain only the minimal amount of
# serialized data to describe the graph structure.
# They are not meant to be computed: therefore, we can early return here,
# as uid conflict evaluation is only meaningful for nodes with computed data.
if isTemplate:
return
# By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the
# nodes' links have been resolved and their UID computations are all complete.
# It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones
# that were computed.
self._evaluateUidConflicts(graphContent)
def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict:
graphContent = graphData.get(GraphIO.Keys.Graph, graphData)
if fileVersion < Version("2.0"):
# For internal folders, all "{uid0}" keys should be replaced with "{uid}"
updatedFileData = json.dumps(fileData).replace("{uid0}", "{uid}")
updatedFileData = json.dumps(graphContent).replace("{uid0}", "{uid}")
# For fileVersion < 2.0, the nodes' UID is stored as:
# "uids": {"0": "hashvalue"}
@ -301,239 +322,124 @@ class Graph(BaseObject):
uid = occ.split("\"")[-2] # UID is second to last element
newUidStr = r'"uid": "{}"'.format(uid)
updatedFileData = updatedFileData.replace(occ, newUidStr)
fileData = json.loads(updatedFileData)
graphContent = json.loads(updatedFileData)
# Older versions of Meshroom files only contained the serialized nodes
graphData = fileData.get(Graph.IO.Keys.Graph, fileData)
return graphContent
if importProject:
self._importedNodes.clear()
graphData = self.updateImportedProject(graphData)
def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"):
# Retrieve version info from:
# 1. nodeData: node saved from a CompatibilityNode
# 2. nodesVersion in file header: node saved from a Node
# If unvailable, the "version" field will not be set in `nodeData`.
if "version" not in nodeData:
if version := fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"]):
nodeData["version"] = version
inTemplate = fromGraph.header.get(GraphIO.Keys.Template, False)
node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate)
self._addNode(node, nodeName)
return node
if not isinstance(graphData, dict):
raise RuntimeError('loadGraph error: Graph is not a dict. File: {}'.format(filepath))
def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]:
nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {})
return nodeVersions.get(nodeType, default)
nodesVersions = self.header.get(Graph.IO.Keys.NodesVersions, {})
self._fileDateVersion = os.path.getmtime(filepath)
# Check whether the file was saved as a template in minimal mode
isTemplate = self.header.get("template", False)
with GraphModification(self):
# iterate over nodes sorted by suffix index in their names
for nodeName, nodeData in sorted(graphData.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
if not isinstance(nodeData, dict):
raise RuntimeError('loadGraph error: Node is not a dict. File: {}'.format(filepath))
# retrieve version from
# 1. nodeData: node saved from a CompatibilityNode
# 2. nodesVersion in file header: node saved from a Node
# 3. fallback to no version "0.0": retro-compatibility
if "version" not in nodeData:
nodeData["version"] = nodesVersions.get(nodeData["nodeType"], "0.0")
# if the node is a "Publish" node and comes from a template file, it should be ignored
# unless publishOutputs is True
if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish":
continue
n = nodeFactory(nodeData, nodeName, template=isTemplate)
# Add node to the graph with raw attributes values
self._addNode(n, nodeName)
if importProject:
self._importedNodes.add(n)
# Create graph edges by resolving attributes expressions
self._applyExpr()
if setupProjectFile:
# Update filepath related members
# Note: needs to be done at the end as it will trigger an updateInternals.
self._setFilepath(filepath)
elif not isTemplate:
# If no filepath is being set but the graph is not a template, trigger an updateInternals either way.
self.updateInternals()
# By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the
# nodes' links have been resolved and their UID computations are all complete.
# It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones
# that were computed.
if not isTemplate: # UIDs are not stored in templates
self._evaluateUidConflicts(graphData)
try:
self._applyExpr()
except Exception as e:
logging.warning(e)
return True
def _evaluateUidConflicts(self, data):
def _evaluateUidConflicts(self, graphContent: dict):
"""
Compare the UIDs of all the nodes in the graph with the UID that is expected in the graph file. If there
Compare the computed UIDs of all the nodes in the graph with the UIDs serialized in `graphContent`. If there
are mismatches, the nodes with the unexpected UID are replaced with "UidConflict" compatibility nodes.
Already existing nodes are removed and re-added to the graph identically to preserve all the edges,
which may otherwise be invalidated when a node with output edges but a UID conflict is re-generated as a
compatibility node.
Args:
graphContent: The serialized Graph content.
"""
def _serializedNodeUidMatchesComputedUid(nodeData: dict, node: BaseNode) -> bool:
"""Returns whether the serialized UID matches the one computed in the `node` instance."""
if isinstance(node, CompatibilityNode):
return True
serializedUid = nodeData.get("uid", None)
computedUid = node._uid
return serializedUid is None or computedUid is None or serializedUid == computedUid
uidConflictingNodes = [
node
for node in self.nodes
if not _serializedNodeUidMatchesComputedUid(graphContent[node.name], node)
]
if not uidConflictingNodes:
return
logging.warning("UID Compatibility issues found: recreating conflicting nodes as CompatibilityNodes.")
# A uid conflict is contagious: if a node has a uid conflict, all of its downstream nodes may be
# impacted as well, as the uid flows through connections.
# Therefore, we deal with conflicting uid nodes by depth: replacing a node with a CompatibilityNode restores
# the serialized uid, which might solve "false-positives" downstream conflicts as well.
nodesSortedByDepth = sorted(uidConflictingNodes, key=lambda node: node.minDepth)
for node in nodesSortedByDepth:
nodeData = graphContent[node.name]
# Evaluate if the node uid is still conflicting at this point, or if it has been resolved by an
# upstream node replacement.
if _serializedNodeUidMatchesComputedUid(nodeData, node):
continue
expectedUid = node._uid
compatibilityNode = nodeFactory(graphContent[node.name], node.name, expectedUid=expectedUid)
# This operation will trigger a graph update that will recompute the uids of all nodes,
# allowing the iterative resolution of uid conflicts.
self.replaceNode(node.name, compatibilityNode)
def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]:
"""Import the content (nodes and edges) of another Graph file into this Graph instance.
Args:
data (dict): the dictionary containing all the nodes to import and their data
"""
for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
node = self.node(nodeName)
savedUid = nodeData.get("uid", None)
graphUid = node._uid # Node's UID from the graph itself
if savedUid != graphUid and graphUid is not None:
# Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode
logging.debug("UID conflict detected for {}".format(nodeName))
self.removeNode(nodeName)
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=True)
self._addNode(n, nodeName)
else:
# f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost:
# the links will be erroneously updated, and any further resolution will fail.
# Recreating the entire graph as it was ensures that all edges will be correctly preserved.
self.removeNode(nodeName)
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=False)
self._addNode(n, nodeName)
def updateImportedProject(self, data):
"""
Update the names and links of the project to import so that it can fit
correctly in the existing graph.
Parse all the nodes from the project that is going to be imported.
If their name already exists in the graph, replace them with new names,
then parse all the nodes' inputs/outputs to replace the old names with
the new ones in the links.
Args:
data (dict): the dictionary containing all the nodes to import and their data
filepath: The path to the Graph file to import.
Returns:
updatedData (dict): the dictionary containing all the nodes to import with their updated names and data
The list of newly created Nodes.
"""
nameCorrespondences = {} # maps the old node name to its updated one
updatedData = {} # input data with updated node names and links
graph = loadGraph(filepath)
return self.importGraphContent(graph)
def createUniqueNodeName(nodeNames, inputName):
"""
Create a unique name that does not already exist in the current graph or in the list
of nodes that will be imported.
"""
i = 1
while i:
newName = "{name}_{index}".format(name=inputName, index=i)
if newName not in nodeNames and newName not in updatedData.keys():
return newName
i += 1
# First pass to get all the names that already exist in the graph, update them, and keep track of the changes
for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
if not isinstance(nodeData, dict):
raise RuntimeError('updateImportedProject error: Node is not a dict.')
if nodeName in self._nodes.keys() or nodeName in updatedData.keys():
newName = createUniqueNodeName(self._nodes.keys(), nodeData["nodeType"])
updatedData[newName] = nodeData
nameCorrespondences[nodeName] = newName
else:
updatedData[nodeName] = nodeData
newNames = [nodeName for nodeName in updatedData] # names of all the nodes that will be added
# Second pass to update all the links in the input/output attributes for every node with the new names
for nodeName, nodeData in updatedData.items():
nodeType = nodeData.get("nodeType", None)
nodeDesc = meshroom.core.nodesDesc[nodeType]
inputs = nodeData.get("inputs", {})
outputs = nodeData.get("outputs", {})
if inputs:
inputs = self.updateLinks(inputs, nameCorrespondences)
inputs = self.resetExternalLinks(inputs, nodeDesc.inputs, newNames)
updatedData[nodeName]["inputs"] = inputs
if outputs:
outputs = self.updateLinks(outputs, nameCorrespondences)
outputs = self.resetExternalLinks(outputs, nodeDesc.outputs, newNames)
updatedData[nodeName]["outputs"] = outputs
return updatedData
@staticmethod
def updateLinks(attributes, nameCorrespondences):
@blockNodeCallbacks
def importGraphContent(self, graph: "Graph") -> list[Node]:
"""
Update all the links that refer to nodes that are going to be imported and whose
names have to be updated.
Import the content (node and edges) of another `graph` into this Graph instance.
Nodes are imported with their original names if possible, otherwise a new unique name is generated
from their node type.
Args:
attributes (dict): attributes whose links need to be updated
nameCorrespondences (dict): node names to replace in the links with the name to replace them with
graph: The graph to import.
Returns:
attributes (dict): the attributes with all the updated links
The list of newly created Nodes.
"""
for key, val in attributes.items():
for corr in nameCorrespondences.keys():
if isinstance(val, str) and corr in val:
attributes[key] = val.replace(corr, nameCorrespondences[corr])
elif isinstance(val, list):
for v in val:
if isinstance(v, str):
if corr in v:
val[val.index(v)] = v.replace(corr, nameCorrespondences[corr])
else: # the list does not contain strings, so there cannot be links to update
break
attributes[key] = val
return attributes
def _renameClashingNodes():
if not self.nodes:
return
unavailableNames = set(self.nodes.keys())
for node in graph.nodes:
if node._name in unavailableNames:
node._name = self._createUniqueNodeName(node.nodeType, unavailableNames)
unavailableNames.add(node._name)
@staticmethod
def resetExternalLinks(attributes, nodeDesc, newNames):
"""
Reset all links to nodes that are not part of the nodes which are going to be imported:
if there are links to nodes that are not in the list, then it means that the references
are made to external nodes, and we want to get rid of those.
def _importNodesAndEdges() -> list[Node]:
importedNodes = []
# If we import the content of the graph within itself,
# iterate over a copy of the nodes as the graph is modified during the iteration.
nodes = graph.nodes if graph is not self else list(graph.nodes)
with GraphModification(self):
for srcNode in nodes:
node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph)
importedNodes.append(node)
self._applyExpr()
return importedNodes
Args:
attributes (dict): attributes whose links might need to be reset
nodeDesc (list): list with all the attributes' description (including their default value)
newNames (list): names of the nodes that are going to be imported; no node name should be referenced
in the links except those contained in this list
Returns:
attributes (dict): the attributes with all the links referencing nodes outside those which will be imported
reset to their default values
"""
for key, val in attributes.items():
defaultValue = None
for desc in nodeDesc:
if desc.name == key:
defaultValue = desc.value
break
if isinstance(val, str):
if Attribute.isLinkExpression(val) and not any(name in val for name in newNames):
if defaultValue is not None: # prevents from not entering condition if defaultValue = ''
attributes[key] = defaultValue
elif isinstance(val, list):
removedCnt = len(val) # counter to know whether all the list entries will be deemed invalid
tmpVal = list(val) # deep copy to ensure we iterate over the entire list (even if elements are removed)
for v in tmpVal:
if isinstance(v, str) and Attribute.isLinkExpression(v) and not any(name in v for name in newNames):
val.remove(v)
removedCnt -= 1
if removedCnt == 0 and defaultValue is not None: # if all links were wrong, reset the attribute
attributes[key] = defaultValue
return attributes
_renameClashingNodes()
importedNodes = _importNodesAndEdges()
return importedNodes
@property
def updateEnabled(self):
@ -648,41 +554,6 @@ class Graph(BaseObject):
return duplicates
def pasteNodes(self, data, position):
"""
Paste node(s) in the graph with their connections. The connections can only be between
the pasted nodes and not with the rest of the graph.
Args:
data (dict): the dictionary containing the information about the nodes to paste, with their names and
links already updated to be added to the graph
position (list): the list of positions for each node to paste
Returns:
list: the list of Node objects that were pasted and added to the graph
"""
nodes = []
with GraphModification(self):
positionCnt = 0 # always valid because we know the data is sorted the same way as the position list
for key in sorted(data):
nodeType = data[key].get("nodeType", None)
if not nodeType: # this case should never occur, as the data should have been prefiltered first
pass
attributes = {}
attributes.update(data[key].get("inputs", {}))
attributes.update(data[key].get("outputs", {}))
attributes.update(data[key].get("internalInputs", {}))
node = Node(nodeType, position=position[positionCnt], **attributes)
self._addNode(node, key)
nodes.append(node)
positionCnt += 1
self._applyExpr()
return nodes
def outEdges(self, attribute):
""" Return the list of edges starting from the given attribute """
# type: (Attribute,) -> [Edge]
@ -746,8 +617,6 @@ class Graph(BaseObject):
node.alive = False
self._nodes.remove(node)
if node in self._importedNodes:
self._importedNodes.remove(node)
self.update()
return inEdges, outEdges, outListAttributes
@ -772,18 +641,26 @@ class Graph(BaseObject):
n.updateInternals()
return n
def _createUniqueNodeName(self, inputName):
i = 1
while i:
newName = "{name}_{index}".format(name=inputName, index=i)
if newName not in self._nodes.objects:
def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None):
"""Create a unique node name based on the input name.
Args:
inputName: The desired node name.
existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes.
"""
existingNodeNames = existingNames or set(self._nodes.objects.keys())
idx = 1
while idx:
newName = f"{inputName}_{idx}"
if newName not in existingNodeNames:
return newName
i += 1
idx += 1
def node(self, nodeName):
return self._nodes.get(nodeName)
def upgradeNode(self, nodeName):
def upgradeNode(self, nodeName) -> Node:
"""
Upgrade the CompatibilityNode identified as 'nodeName'
Args:
@ -803,25 +680,49 @@ class Graph(BaseObject):
if not isinstance(node, CompatibilityNode):
raise ValueError("Upgrade is only available on CompatibilityNode instances.")
upgradedNode = node.upgrade()
with GraphModification(self):
inEdges, outEdges, outListAttributes = self.removeNode(nodeName)
self.addNode(upgradedNode, nodeName)
for dst, src in outEdges.items():
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
# If they are not re-created first, adding their edges will lead to errors
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dst in outListAttributes.keys():
listAttr = self.attribute(outListAttributes[dst][0])
if isinstance(outListAttributes[dst][2], list):
listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2]
else:
listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2])
try:
self.addEdge(self.attribute(src), self.attribute(dst))
except (KeyError, ValueError) as e:
logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e)))
self.replaceNode(nodeName, upgradedNode)
return upgradedNode
return upgradedNode, inEdges, outEdges, outListAttributes
@changeTopology
def replaceNode(self, nodeName: str, newNode: BaseNode):
"""Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges.
Args:
nodeName: The name of the Node to replace.
newNode: The Node instance to replace it with.
"""
with GraphModification(self):
_, outEdges, outListAttributes = self.removeNode(nodeName)
self.addNode(newNode, nodeName)
self._restoreOutEdges(outEdges, outListAttributes)
def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes):
"""Restore output edges that were removed during a call to "removeNode".
Args:
outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode".
{dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()}
outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected
to a ListAttribute prior to the removal of all edges.
{dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)}
"""
def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any):
listAttr = self.attribute(listAttrName)
if not isinstance(listAttr, ListAttribute):
return
if isinstance(value, list):
listAttr[index:index] = value
else:
listAttr.insert(index, value)
for dstName, srcName in outEdges.items():
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
if dstName in outListAttributes:
_recreateTargetListAttributeChildren(*outListAttributes[dstName])
try:
self.addEdge(self.attribute(srcName), self.attribute(dstName))
except (KeyError, ValueError) as e:
logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}")
def upgradeAllNodes(self):
""" Upgrade all upgradable CompatibilityNode instances in the graph. """
@ -1352,6 +1253,35 @@ class Graph(BaseObject):
def asString(self):
return str(self.toDict())
def copy(self) -> "Graph":
"""Create a copy of this Graph instance."""
graph = Graph("")
graph._deserialize(self.serialize())
return graph
def serialize(self, asTemplate: bool = False) -> dict:
"""Serialize this Graph instance.
Args:
asTemplate: Whether to use the template serialization.
Returns:
The serialized graph data.
"""
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
return SerializerClass(self).serialize()
def serializePartial(self, nodes: list[Node]) -> dict:
"""Partially serialize this graph considering only the given list of `nodes`.
Args:
nodes: The list of nodes to serialize.
Returns:
The serialized graph data.
"""
return PartialGraphSerializer(self, nodes=nodes).serialize()
def save(self, filepath=None, setupProjectFile=True, template=False):
"""
Save the current Meshroom graph as a serialized ".mg" file.
@ -1374,34 +1304,7 @@ class Graph(BaseObject):
if not path:
raise ValueError("filepath must be specified for unsaved files.")
self.header[Graph.IO.Keys.ReleaseVersion] = meshroom.__version__
self.header[Graph.IO.Keys.FileVersion] = Graph.IO.__version__
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
# and remove duplicates
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
# Convert to node types to "name: version"
nodesVersions = {
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
for p in usedNodeTypes
}
# Sort them by name (to avoid random order changing from one save to another)
nodesVersions = dict(sorted(nodesVersions.items()))
# Add it the header
self.header[Graph.IO.Keys.NodesVersions] = nodesVersions
self.header["template"] = template
data = {}
if template:
data = {
Graph.IO.Keys.Header: self.header,
Graph.IO.Keys.Graph: self.getNonDefaultInputAttributes()
}
else:
data = {
Graph.IO.Keys.Header: self.header,
Graph.IO.Keys.Graph: self.toDict()
}
data = self.serialize(template)
with open(path, 'w') as jsonFile:
json.dump(data, jsonFile, indent=4)
@ -1412,51 +1315,6 @@ class Graph(BaseObject):
# update the file date version
self._fileDateVersion = os.path.getmtime(path)
def getNonDefaultInputAttributes(self):
"""
Instead of getting all the inputs and internal attribute keys, only get the keys of
the attributes whose value is not the default one.
The output attributes, UIDs, parallelization parameters and internal folder are
not relevant for templates, so they are explicitly removed from the returned dictionary.
Returns:
dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder
and input/internal attributes with default values removed
"""
graph = self.toDict()
for nodeName in graph.keys():
node = self.node(nodeName)
inputKeys = list(graph[nodeName]["inputs"].keys())
internalInputKeys = []
internalInputs = graph[nodeName].get("internalInputs", None)
if internalInputs:
internalInputKeys = list(internalInputs.keys())
for attrName in inputKeys:
attribute = node.attribute(attrName)
# check that attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del graph[nodeName]["inputs"][attrName]
for attrName in internalInputKeys:
attribute = node.internalAttribute(attrName)
# check that internal attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del graph[nodeName]["internalInputs"][attrName]
# If all the internal attributes are set to their default values, remove the entry
if len(graph[nodeName]["internalInputs"]) == 0:
del graph[nodeName]["internalInputs"]
del graph[nodeName]["outputs"]
del graph[nodeName]["uid"]
del graph[nodeName]["internalFolder"]
del graph[nodeName]["parallelization"]
return graph
def _setFilepath(self, filepath):
"""
Set the internal filepath of this Graph.
@ -1615,11 +1473,6 @@ class Graph(BaseObject):
def edges(self):
return self._edges
@property
def importedNodes(self):
"""" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """
return self._importedNodes
@property
def cacheDir(self):
return self._cacheDir
@ -1660,7 +1513,7 @@ class Graph(BaseObject):
filepathChanged = Signal()
filepath = Property(str, lambda self: self._filepath, notify=filepathChanged)
isSaving = Property(bool, isSaving.fget, constant=True)
fileReleaseVersion = Property(str, lambda self: self.header.get(Graph.IO.Keys.ReleaseVersion, "0.0"),
fileReleaseVersion = Property(str, lambda self: self.header.get(GraphIO.Keys.ReleaseVersion, "0.0"),
notify=filepathChanged)
fileDateVersion = Property(float, fileDateVersion.fget, fileDateVersion.fset, notify=filepathChanged)
cacheDirChanged = Signal()

231
meshroom/core/graphIO.py Normal file
View file

@ -0,0 +1,231 @@
from enum import Enum
from typing import Any, TYPE_CHECKING, Union
import meshroom
from meshroom.core import Version
from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute
from meshroom.core.node import Node
if TYPE_CHECKING:
from meshroom.core.graph import Graph
class GraphIO:
"""Centralize Graph file keys and IO version."""
__version__ = "2.0"
class Keys(object):
"""File Keys."""
# Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value)
Header = "header"
NodesVersions = "nodesVersions"
ReleaseVersion = "releaseVersion"
FileVersion = "fileVersion"
Graph = "graph"
Template = "template"
class Features(Enum):
"""File Features."""
Graph = "graph"
Header = "header"
NodesVersions = "nodesVersions"
PrecomputedOutputs = "precomputedOutputs"
NodesPositions = "nodesPositions"
@staticmethod
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
"""Return the list of supported features based on a file version.
Args:
fileVersion (str, Version): the file version
Returns:
tuple of GraphIO.Features: the list of supported features
"""
if isinstance(fileVersion, str):
fileVersion = Version(fileVersion)
features = [GraphIO.Features.Graph]
if fileVersion >= Version("1.0"):
features += [
GraphIO.Features.Header,
GraphIO.Features.NodesVersions,
GraphIO.Features.PrecomputedOutputs,
]
if fileVersion >= Version("1.1"):
features += [GraphIO.Features.NodesPositions]
return tuple(features)
class GraphSerializer:
"""Standard Graph serializer."""
def __init__(self, graph: "Graph") -> None:
self._graph = graph
def serialize(self) -> dict:
"""
Serialize the Graph.
"""
return {
GraphIO.Keys.Header: self.serializeHeader(),
GraphIO.Keys.Graph: self.serializeContent(),
}
@property
def nodes(self) -> list[Node]:
return self._graph.nodes
def serializeHeader(self) -> dict:
"""Build and return the graph serialization header.
The header contains metadata about the graph, such as the:
- version of the software used to create it.
- version of the file format.
- version of the nodes types used in the graph.
- template flag.
"""
header: dict[str, Any] = {}
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
return header
def _getNodeTypesVersions(self) -> dict[str, str]:
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
nodeTypesVersions = {
nodeType.__name__: version
for nodeType in nodeTypes
if (version := meshroom.core.nodeVersion(nodeType)) is not None
}
# Sort them by name (to avoid random order changing from one save to another).
return dict(sorted(nodeTypesVersions.items()))
def serializeContent(self) -> dict:
"""Graph content serialization logic."""
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
def serializeNode(self, node: Node) -> dict:
"""Node serialization logic."""
return node.toDict()
class TemplateGraphSerializer(GraphSerializer):
"""Serializer for serializing a graph as a template."""
def serializeHeader(self) -> dict:
header = super().serializeHeader()
header[GraphIO.Keys.Template] = True
return header
def serializeNode(self, node: Node) -> dict:
"""Adapt node serialization to template graphs.
Instead of getting all the inputs and internal attribute keys, only get the keys of
the attributes whose value is not the default one.
The output attributes, UIDs, parallelization parameters and internal folder are
not relevant for templates, so they are explicitly removed from the returned dictionary.
"""
# For now, implemented as a post-process to update the default serialization.
nodeData = super().serializeNode(node)
inputKeys = list(nodeData["inputs"].keys())
internalInputKeys = []
internalInputs = nodeData.get("internalInputs", None)
if internalInputs:
internalInputKeys = list(internalInputs.keys())
for attrName in inputKeys:
attribute = node.attribute(attrName)
# check that attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["inputs"][attrName]
for attrName in internalInputKeys:
attribute = node.internalAttribute(attrName)
# check that internal attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["internalInputs"][attrName]
# If all the internal attributes are set to their default values, remove the entry
if len(nodeData["internalInputs"]) == 0:
del nodeData["internalInputs"]
del nodeData["outputs"]
del nodeData["uid"]
del nodeData["internalFolder"]
del nodeData["parallelization"]
return nodeData
class PartialGraphSerializer(GraphSerializer):
"""Serializer to serialize a partial graph (a subset of nodes)."""
def __init__(self, graph: "Graph", nodes: list[Node]):
super().__init__(graph)
self._nodes = nodes
@property
def nodes(self) -> list[Node]:
"""Override to consider only the subset of nodes."""
return self._nodes
def serializeNode(self, node: Node) -> dict:
"""Adapt node serialization to partial graph serialization."""
# NOTE: For now, implemented as a post-process to the default serialization.
nodeData = super().serializeNode(node)
# Override input attributes with custom serialization logic, to handle attributes
# connected to nodes that are not in the list of nodes to serialize.
for attributeName in nodeData["inputs"]:
nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName))
# Clear UID for non-compatibility nodes, as the custom attribute serialization
# can be impacting the UID by removing connections to missing nodes.
if not node.isCompatibilityNode:
del nodeData["uid"]
return nodeData
def _serializeAttribute(self, attribute: Attribute) -> Any:
"""
Serialize `attribute` (recursively for list/groups) and deal with attributes being connected
to nodes that are not part of the partial list of nodes to serialize.
"""
linkParam = attribute.getLinkParam()
if linkParam is not None:
# Use standard link serialization if upstream node is part of the serialization.
if linkParam.node in self.nodes:
return attribute.getExportValue()
# Skip link serialization otherwise.
# If part of a list, this entry can be discarded.
if isinstance(attribute.root, ListAttribute):
return None
# Otherwise, return the default value for this attribute.
return attribute.defaultValue()
if isinstance(attribute, ListAttribute):
# Recusively serialize each child of the ListAttribute, skipping those for which the attribute
# serialization logic above returns None.
return [
exportValue
for child in attribute
if (exportValue := self._serializeAttribute(child)) is not None
]
if isinstance(attribute, GroupAttribute):
# Recursively serialize each child of the group attribute.
return {name: self._serializeAttribute(child) for name, child in attribute.value.items()}
return attribute.getExportValue()

View file

@ -1608,7 +1608,8 @@ class CompatibilityNode(BaseNode):
# Make a deepcopy of nodeDict to handle CompatibilityNode duplication
# and be able to change modified inputs (see CompatibilityNode.toDict)
self.nodeDict = copy.deepcopy(nodeDict)
self.version = Version(self.nodeDict.get("version", None))
version = self.nodeDict.get("version")
self.version = Version(version) if version else None
self._inputs = self.nodeDict.get("inputs", {})
self._internalInputs = self.nodeDict.get("internalInputs", {})
@ -1668,7 +1669,17 @@ class CompatibilityNode(BaseNode):
elif isinstance(value, float):
return desc.FloatParam(range=None, **params)
elif isinstance(value, str):
if isOutput or os.path.isabs(value) or Attribute.isLinkExpression(value):
if isOutput or os.path.isabs(value):
return desc.File(**params)
elif Attribute.isLinkExpression(value):
# Do not consider link expression as a valid default desc value.
# When the link expression is applied and transformed to an actual link,
# the systems resets the value using `Attribute.resetToDefaultValue` to indicate
# that this link expression has been handled.
# If the link expression is stored as the default value, it will never be cleared,
# leading to unexpected behavior where the link expression on a CompatibilityNode
# could be evaluated several times and/or incorrectly.
params["value"] = ""
return desc.File(**params)
else:
return desc.StringParam(**params)
@ -1851,113 +1862,3 @@ class CompatibilityNode(BaseNode):
canUpgrade = Property(bool, canUpgrade.fget, constant=True)
issueDetails = Property(str, issueDetails.fget, constant=True)
def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
"""
Create a node instance by deserializing the given node data.
If the serialized data matches the corresponding node type description, a Node instance is created.
If any compatibility issue occurs, a NodeCompatibility instance is created instead.
Args:
nodeDict (dict): the serialization of the node
name (str): (optional) the node's name
template (bool): (optional) true if the node is part of a template, false otherwise
uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node
Returns:
BaseNode: the created node
"""
nodeType = nodeDict["nodeType"]
# Retro-compatibility: inputs were previously saved as "attributes"
if "inputs" not in nodeDict and "attributes" in nodeDict:
nodeDict["inputs"] = nodeDict["attributes"]
del nodeDict["attributes"]
# Get node inputs/outputs
inputs = nodeDict.get("inputs", {})
internalInputs = nodeDict.get("internalInputs", {})
outputs = nodeDict.get("outputs", {})
version = nodeDict.get("version", None)
internalFolder = nodeDict.get("internalFolder", None)
position = Position(*nodeDict.get("position", []))
uid = nodeDict.get("uid", None)
compatibilityIssue = None
nodeDesc = None
try:
nodeDesc = meshroom.core.nodesDesc[nodeType]
except KeyError:
# Unknown node type
compatibilityIssue = CompatibilityIssue.UnknownNodeType
# Unknown node type should take precedence over UID conflict, as it cannot be resolved
if uidConflict and nodeDesc:
compatibilityIssue = CompatibilityIssue.UidConflict
if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue
# Compare serialized node version with current node version
currentNodeVersion = meshroom.core.nodeVersion(nodeDesc)
# If both versions are available, check for incompatibility in major version
if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major:
compatibilityIssue = CompatibilityIssue.VersionConflict
# In other cases, check attributes compatibility between serialized node and its description
else:
# Check that the node has the exact same set of inputs/outputs as its description, except
# if the node is described in a template file, in which only non-default parameters are saved;
# do not perform that check for internal attributes because there is no point in
# raising compatibility issues if their number differs: in that case, it is only useful
# if some internal attributes do not exist or are invalid
if not template and (sorted([attr.name for attr in nodeDesc.inputs
if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or
sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) !=
sorted(outputs.keys())):
compatibilityIssue = CompatibilityIssue.DescriptionConflict
# Check whether there are any internal attributes that are invalidating in the node description: if there
# are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility
# issue must be raised to warn the user, as this will automatically change the node's UID
if not template:
invalidatingIntInputs = []
for attr in nodeDesc.internalInputs:
if attr.invalidate:
invalidatingIntInputs.append(attr.name)
for attr in invalidatingIntInputs:
if attr not in internalInputs.keys():
compatibilityIssue = CompatibilityIssue.DescriptionConflict
break
# Verify that all inputs match their descriptions
for attrName, value in inputs.items():
if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value):
compatibilityIssue = CompatibilityIssue.DescriptionConflict
break
# Verify that all internal inputs match their description
for attrName, value in internalInputs.items():
if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value):
compatibilityIssue = CompatibilityIssue.DescriptionConflict
break
# Verify that all outputs match their descriptions
for attrName, value in outputs.items():
if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value):
compatibilityIssue = CompatibilityIssue.DescriptionConflict
break
if compatibilityIssue is None:
node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs)
else:
logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name))
node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue)
# Retro-compatibility: no internal folder saved
# can't spawn meaningful CompatibilityNode with precomputed outputs
# => automatically try to perform node upgrade
if not internalFolder and nodeDesc:
logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name))
node = node.upgrade()
# If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is
# an "unknown node type" conflict (in which case the upgrade would fail)
elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType:
node = node.upgrade()
return node

View file

@ -0,0 +1,201 @@
import logging
from typing import Any, Iterable, Optional, Union
import meshroom.core
from meshroom.core import Version, desc
from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position
def nodeFactory(
nodeData: dict,
name: Optional[str] = None,
inTemplate: bool = False,
expectedUid: Optional[str] = None,
) -> Union[Node, CompatibilityNode]:
"""
Create a node instance by deserializing the given node data.
If the serialized data matches the corresponding node type description, a Node instance is created.
If any compatibility issue occurs, a NodeCompatibility instance is created instead.
Args:
nodeData: The serialized Node data.
name: The node's name.
inTemplate: True if the node is created as part of a graph template.
expectedUid: The expected UID of the node within the context of a Graph.
Returns:
The created Node instance.
"""
return _NodeCreator(nodeData, name, inTemplate, expectedUid).create()
class _NodeCreator:
def __init__(
self,
nodeData: dict,
name: Optional[str] = None,
inTemplate: bool = False,
expectedUid: Optional[str] = None,
):
self.nodeData = nodeData
self.name = name
self.inTemplate = inTemplate
self.expectedUid = expectedUid
self._normalizeNodeData()
self.nodeType = self.nodeData["nodeType"]
self.inputs = self.nodeData.get("inputs", {})
self.internalInputs = self.nodeData.get("internalInputs", {})
self.outputs = self.nodeData.get("outputs", {})
self.version = self.nodeData.get("version", None)
self.internalFolder = self.nodeData.get("internalFolder")
self.position = Position(*self.nodeData.get("position", []))
self.uid = self.nodeData.get("uid", None)
self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None)
def create(self) -> Union[Node, CompatibilityNode]:
compatibilityIssue = self._checkCompatibilityIssues()
if compatibilityIssue:
node = self._createCompatibilityNode(compatibilityIssue)
node = self._tryUpgradeCompatibilityNode(node)
else:
node = self._createNode()
return node
def _normalizeNodeData(self):
"""Consistency fixes for backward compatibility with older serialized data."""
# Inputs were previously saved as "attributes".
if "inputs" not in self.nodeData and "attributes" in self.nodeData:
self.nodeData["inputs"] = self.nodeData["attributes"]
del self.nodeData["attributes"]
def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]:
if self.nodeDesc is None:
return CompatibilityIssue.UnknownNodeType
if not self._checkUidCompatibility():
return CompatibilityIssue.UidConflict
if not self._checkVersionCompatibility():
return CompatibilityIssue.VersionConflict
if not self._checkDescriptionCompatibility():
return CompatibilityIssue.DescriptionConflict
return None
def _checkUidCompatibility(self) -> bool:
return self.expectedUid is None or self.expectedUid == self.uid
def _checkVersionCompatibility(self) -> bool:
# Special case: a node with a version set to None indicates
# that it has been created from the current version of the node type.
nodeCreatedFromCurrentVersion = self.version is None
if nodeCreatedFromCurrentVersion:
return True
nodeTypeCurrentVersion = meshroom.core.nodeVersion(self.nodeDesc)
# If the node type has not current version information, assume compatibility.
if nodeTypeCurrentVersion is None:
return True
return Version(self.version).major == Version(nodeTypeCurrentVersion).major
def _checkDescriptionCompatibility(self) -> bool:
# Only perform strict attribute name matching for non-template graphs,
# since only non-default-value input attributes are serialized in templates.
if not self.inTemplate:
if not self._checkAttributesNamesMatchDescription():
return False
return self._checkAttributesAreCompatibleWithDescription()
def _checkAttributesNamesMatchDescription(self) -> bool:
return (
self._checkInputAttributesNames()
and self._checkOutputAttributesNames()
and self._checkInternalAttributesNames()
)
def _checkAttributesAreCompatibleWithDescription(self) -> bool:
return (
self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs)
and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs)
and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs)
)
def _checkInputAttributesNames(self) -> bool:
def serializedInput(attr: desc.Attribute) -> bool:
"""Filter that excludes not-serialized desc input attributes."""
if isinstance(attr, desc.PushButtonParam):
# PushButtonParam are not serialized has they do not hold a value.
return False
return True
refAttributes = filter(serializedInput, self.nodeDesc.inputs)
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.inputs)
def _checkOutputAttributesNames(self) -> bool:
def serializedOutput(attr: desc.Attribute) -> bool:
"""Filter that excludes not-serialized desc output attributes."""
if attr.isDynamicValue:
# Dynamic outputs values are not serialized with the node,
# as their value is written in the computed output data.
return False
return True
refAttributes = filter(serializedOutput, self.nodeDesc.outputs)
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.outputs)
def _checkInternalAttributesNames(self) -> bool:
invalidatingDescAttributes = [attr.name for attr in self.nodeDesc.internalInputs if attr.invalidate]
return all(attr in self.internalInputs.keys() for attr in invalidatingDescAttributes)
def _checkAttributesNamesStrictlyMatch(
self, descAttributes: Iterable[desc.Attribute], attributesDict: dict[str, Any]
) -> bool:
refNames = sorted([attr.name for attr in descAttributes])
attrNames = sorted(attributesDict.keys())
return refNames == attrNames
def _checkAttributesCompatibility(
self, descAttributes: list[desc.Attribute], attributesDict: dict[str, Any]
) -> bool:
return all(
CompatibilityNode.attributeDescFromName(descAttributes, attrName, value) is not None
for attrName, value in attributesDict.items()
)
def _createNode(self) -> Node:
logging.info(f"Creating node '{self.name}'")
return Node(
self.nodeType,
position=self.position,
uid=self.uid,
**self.inputs,
**self.internalInputs,
**self.outputs,
)
def _createCompatibilityNode(self, compatibilityIssue) -> CompatibilityNode:
logging.warning(f"Compatibility issue detected for node '{self.name}': {compatibilityIssue.name}")
return CompatibilityNode(
self.nodeType, self.nodeData, position=self.position, issue=compatibilityIssue
)
def _tryUpgradeCompatibilityNode(self, node: CompatibilityNode) -> Union[Node, CompatibilityNode]:
"""Handle possible upgrades of CompatibilityNodes, when no computed data is associated to the Node."""
if node.issue == CompatibilityIssue.UnknownNodeType:
return node
# Nodes in templates are not meant to hold computation data.
if self.inTemplate:
logging.warning(f"Compatibility issue in template: performing automatic upgrade on '{self.name}'")
return node.upgrade()
# Backward compatibility: "internalFolder" was not serialized.
if not self.internalFolder:
logging.warning(f"No serialized output data: performing automatic upgrade on '{self.name}'")
return node.upgrade()
return node

8
meshroom/core/typing.py Normal file
View file

@ -0,0 +1,8 @@
"""
Common typing aliases used in Meshroom.
"""
from pathlib import Path
from typing import Union
PathLike = Union[Path, str]

View file

@ -6,8 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack
from PySide6.QtCore import Property, Signal
from meshroom.core.attribute import ListAttribute, Attribute
from meshroom.core.graph import GraphModification
from meshroom.core.node import nodeFactory, Position
from meshroom.core.graph import Graph, GraphModification
from meshroom.core.node import Position, CompatibilityIssue
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
class UndoCommand(QUndoCommand):
@ -168,19 +170,7 @@ class RemoveNodeCommand(GraphCommand):
node = nodeFactory(self.nodeDict, self.nodeName)
self.graph.addNode(node, self.nodeName)
assert (node.getName() == self.nodeName)
# recreate out edges deleted on node removal
for dstAttr, srcAttr in self.outEdges.items():
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dstAttr in self.outListAttributes.keys():
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
if isinstance(self.outListAttributes[dstAttr][2], list):
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
else:
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
self.graph.addEdge(self.graph.attribute(srcAttr),
self.graph.attribute(dstAttr))
self.graph._restoreOutEdges(self.outEdges, self.outListAttributes)
class DuplicateNodesCommand(GraphCommand):
@ -209,15 +199,27 @@ class PasteNodesCommand(GraphCommand):
"""
Handle node pasting in a Graph.
"""
def __init__(self, graph, data, position=None, parent=None):
def __init__(self, graph: "Graph", data: dict, position: Position, parent=None):
super(PasteNodesCommand, self).__init__(graph, parent)
self.data = data
self.position = position
self.nodeNames = []
self.nodeNames: list[str] = []
def redoImpl(self):
data = self.graph.updateImportedProject(self.data)
nodes = self.graph.pasteNodes(data, self.position)
graph = Graph("")
try:
graph._deserialize(self.data)
except:
return False
boundingBoxCenter = self._boundingBoxCenter(graph.nodes)
offset = Position(self.position.x - boundingBoxCenter.x, self.position.y - boundingBoxCenter.y)
for node in graph.nodes:
node.position = Position(node.position.x + offset.x, node.position.y + offset.y)
nodes = self.graph.importGraphContent(graph)
self.nodeNames = [node.name for node in nodes]
self.setText("Paste Node{} ({})".format("s" if len(self.nodeNames) > 1 else "", ", ".join(self.nodeNames)))
return nodes
@ -226,12 +228,31 @@ class PasteNodesCommand(GraphCommand):
for name in self.nodeNames:
self.graph.removeNode(name)
def _boundingBox(self, nodes) -> tuple[int, int, int, int]:
if not nodes:
return (0, 0, 0 , 0)
minX = maxX = nodes[0].x
minY = maxY = nodes[0].y
for node in nodes[1:]:
minX = min(minX, node.x)
minY = min(minY, node.y)
maxX = max(maxX, node.x)
maxY = max(maxY, node.y)
return (minX, minY, maxX, maxY)
def _boundingBoxCenter(self, nodes):
minX, minY, maxX, maxY = self._boundingBox(nodes)
return Position((minX + maxX) / 2, (minY + maxY) / 2)
class ImportProjectCommand(GraphCommand):
"""
Handle the import of a project into a Graph.
"""
def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None):
def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None):
super(ImportProjectCommand, self).__init__(graph, parent)
self.filepath = filepath
self.importedNames = []
@ -239,9 +260,8 @@ class ImportProjectCommand(GraphCommand):
self.yOffset = yOffset
def redoImpl(self):
status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True)
importedNodes = self.graph.importedNodes
self.setText("Import Project ({} nodes)".format(importedNodes.count))
importedNodes = self.graph.importGraphContentFromFile(self.filepath)
self.setText(f"Import Project ({len(importedNodes)} nodes)")
lowestY = 0
for node in self.graph.nodes:
@ -419,37 +439,24 @@ class UpgradeNodeCommand(GraphCommand):
super(UpgradeNodeCommand, self).__init__(graph, parent)
self.nodeDict = node.toDict()
self.nodeName = node.getName()
self.outEdges = {}
self.outListAttributes = {}
self.compatibilityIssue = None
self.setText("Upgrade Node {}".format(self.nodeName))
def redoImpl(self):
if not self.graph.node(self.nodeName).canUpgrade:
if not (node := self.graph.node(self.nodeName)).canUpgrade:
return False
upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName)
return upgradedNode
self.compatibilityIssue = node.issue
return self.graph.upgradeNode(self.nodeName)
def undoImpl(self):
# delete upgraded node
self.graph.removeNode(self.nodeName)
expectedUid = None
if self.compatibilityIssue == CompatibilityIssue.UidConflict:
expectedUid = self.graph.node(self.nodeName)._uid
# recreate compatibility node
with GraphModification(self.graph):
# We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before
node = nodeFactory(self.nodeDict, name=self.nodeName, uidConflict=True)
self.graph.addNode(node, self.nodeName)
# recreate out edges
for dstAttr, srcAttr in self.outEdges.items():
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
# 0 = attribute name, 1 = attribute index, 2 = attribute value
if dstAttr in self.outListAttributes.keys():
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
if isinstance(self.outListAttributes[dstAttr][2], list):
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
else:
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
self.graph.addEdge(self.graph.attribute(srcAttr),
self.graph.attribute(dstAttr))
node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid)
self.graph.replaceNode(self.nodeName, node)
class EnableGraphUpdateCommand(GraphCommand):

View file

@ -25,6 +25,7 @@ from meshroom.core import sessionUid
from meshroom.common.qt import QObjectListModel
from meshroom.core.attribute import Attribute, ListAttribute
from meshroom.core.graph import Graph, Edge
from meshroom.core.graphIO import GraphIO
from meshroom.core.taskManager import TaskManager
@ -396,7 +397,7 @@ class UIGraph(QObject):
self.updateChunks()
# perform auto-layout if graph does not provide nodes positions
if Graph.IO.Features.NodesPositions not in self._graph.fileFeatures:
if GraphIO.Features.NodesPositions not in self._graph.fileFeatures:
self._layout.reset()
# clear undo-stack after layout
self._undoStack.clear()
@ -451,17 +452,21 @@ class UIGraph(QObject):
self.stopExecution()
self._chunksMonitor.stop()
@Slot(str, result=bool)
def loadGraph(self, filepath, setupProjectFile=True, publishOutputs=False):
g = Graph('')
status = True
@Slot(str)
def loadGraph(self, filepath):
g = Graph("")
if filepath:
status = g.load(filepath, setupProjectFile, importProject=False, publishOutputs=publishOutputs)
g.load(filepath)
if not os.path.exists(g.cacheDir):
os.mkdir(g.cacheDir)
g.fileDateVersion = os.path.getmtime(filepath)
self.setGraph(g)
return status
@Slot(str, bool, result=bool)
def initFromTemplate(self, filepath, publishOutputs=False):
graph = Graph("")
if filepath:
graph.initFromTemplate(filepath, publishOutputs=publishOutputs)
self.setGraph(graph)
@Slot(QUrl, result="QVariantList")
@Slot(QUrl, QPoint, result="QVariantList")
@ -1045,126 +1050,43 @@ class UIGraph(QObject):
"""
if not self._nodeSelection.hasSelection():
return ""
serializedSelection = {node.name: node.toDict() for node in self.iterSelectedNodes()}
return json.dumps(serializedSelection, indent=4)
graphData = self._graph.serializePartial(self.getSelectedNodes())
return json.dumps(graphData, indent=4)
@Slot(str, QPoint, bool, result=list)
def pasteNodes(self, clipboardContent, position=None, centerPosition=False) -> list[Node]:
@Slot(str, QPoint, result=list)
def pasteNodes(self, serializedData: str, position: Optional[QPoint]=None) -> list[Node]:
"""
Parse the content of the clipboard to see whether it contains
valid node descriptions. If that is the case, the nodes described
in the clipboard are built with the available information.
Otherwise, nothing is done.
Import string-serialized graph content `serializedData` in the current graph, optionally at the given
`position`.
If the `serializedData` does not contain valid serialized graph data, nothing is done.
This function does not need to be preceded by a call to "getSelectedNodesContent".
Any clipboard content that contains at least a node type with a valid JSON
formatting (dictionary form with double quotes around the keys and values)
can be used to generate a node.
This method can be used with the result of "getSelectedNodesContent".
But it also accepts any serialized content that matches the graph data or graph content format.
For example, it is enough to have:
{"nodeName_1": {"nodeType":"CameraInit"}, "nodeName_2": {"nodeType":"FeatureMatching"}}
in the clipboard to create a default CameraInit and a default FeatureMatching nodes.
in `serializedData` to create a default CameraInit and a default FeatureMatching nodes.
Args:
clipboardContent (str): the string contained in the clipboard, that may or may not contain valid
node information
position (QPoint): the position of the mouse in the Graph Editor when the function was called
centerPosition (bool): whether the provided position is not the top-left corner of the pasting
zone, but its center
serializedData: The string-serialized graph data.
position: The position where to paste the nodes. If None, the nodes are pasted at (0, 0).
Returns:
list: the list of Node objects that were pasted and added to the graph
"""
if not clipboardContent:
return
try:
d = json.loads(clipboardContent)
except ValueError as e:
raise ValueError(e)
graphData = json.loads(serializedData)
except json.JSONDecodeError:
logging.warning("Content is not a valid JSON string.")
return []
if not isinstance(d, dict):
raise ValueError("The clipboard does not contain a valid node. Cannot paste it.")
pos = Position(position.x(), position.y()) if position else Position(0, 0)
result = self.push(commands.PasteNodesCommand(self._graph, graphData, pos))
if result is False:
logging.warning("Content is not a valid graph data.")
return []
return result
# If the clipboard contains a header, then a whole file is contained in the clipboard
# Extract the "graph" part and paste it all, ignore the rest
if d.get("header", None):
d = d.get("graph", None)
if not d:
return
if isinstance(position, QPoint):
position = Position(position.x(), position.y())
if self.hoveredNode:
# If a node is hovered, add an offset to prevent complete occlusion
position = Position(position.x + self.layout.gridSpacing, position.y + self.layout.gridSpacing)
# Get the position of the first node in a zone whose top-left corner is the mouse and the bottom-right
# corner the (x, y) coordinates, with x the maximum of all the nodes' position along the x-axis, and y the
# maximum of all the nodes' position along the y-axis. All nodes with a position will be placed relatively
# to the first node within that zone.
firstNodePos = None
minX = 0
maxX = 0
minY = 0
maxY = 0
for key in sorted(d):
nodeType = d[key].get("nodeType", None)
if not nodeType:
raise ValueError("Invalid node description: no provided node type for '{}'".format(key))
pos = d[key].get("position", None)
if pos:
if not firstNodePos:
firstNodePos = pos
minX = pos[0]
maxX = pos[0]
minY = pos[1]
maxY = pos[1]
else:
if minX > pos[0]:
minX = pos[0]
if maxX < pos[0]:
maxX = pos[0]
if minY > pos[1]:
minY = pos[1]
if maxY < pos[1]:
maxY = pos[1]
# Ensure there will not be an error if no node has a specified position
if not firstNodePos:
firstNodePos = [0, 0]
# Position of the first node within the zone
position = Position(position.x + firstNodePos[0] - minX, position.y + firstNodePos[1] - minY)
if centerPosition: # Center the zone around the mouse's position (mouse's position might be artificial)
maxX = maxX + self.layout.nodeWidth # maxX and maxY are the position of the furthest node's top-left corner
maxY = maxY + self.layout.nodeHeight # We want the position of the furthest node's bottom-right corner
position = Position(position.x - ((maxX - minX) / 2), position.y - ((maxY - minY) / 2))
finalPosition = None
prevPosition = None
positions = []
for key in sorted(d):
currentPosition = d[key].get("position", None)
if not finalPosition:
finalPosition = position
else:
if prevPosition and currentPosition:
# If the nodes both have a position, recreate the distance between them with a different
# starting point
x = finalPosition.x + (currentPosition[0] - prevPosition[0])
y = finalPosition.y + (currentPosition[1] - prevPosition[1])
finalPosition = Position(x, y)
else:
# If either the current node or previous one lacks a position, use a custom one
finalPosition = Position(finalPosition.x + self.layout.gridSpacing + self.layout.nodeWidth, finalPosition.y)
prevPosition = currentPosition
positions.append(finalPosition)
return self.push(commands.PasteNodesCommand(self.graph, d, position=positions))
undoStack = Property(QObject, lambda self: self._undoStack, constant=True)
graphChanged = Signal()

View file

@ -185,7 +185,7 @@ Page {
nameFilters: ["Meshroom Graphs (*.mg)"]
onAccepted: {
// Open the template as a regular file
if (_reconstruction.loadUrl(currentFile, true, true)) {
if (_reconstruction.load(currentFile)) {
MeshroomApp.addRecentProjectFile(currentFile.toString())
}
}
@ -400,7 +400,7 @@ Page {
text: "Reload File"
onClicked: {
_reconstruction.loadUrl(_reconstruction.graph.filepath)
_reconstruction.load(_reconstruction.graph.filepath)
fileModifiedDialog.close()
}
}
@ -705,7 +705,7 @@ Page {
MenuItem {
onTriggered: ensureSaved(function() {
openRecentMenu.dismiss()
if (_reconstruction.loadUrl(modelData["path"])) {
if (_reconstruction.load(modelData["path"])) {
MeshroomApp.addRecentProjectFile(modelData["path"])
} else {
MeshroomApp.removeRecentProjectFile(modelData["path"])

View file

@ -82,25 +82,18 @@ Item {
/// Paste content of clipboard to graph editor and create new node if valid
function pasteNodes() {
var finalPosition = undefined
var centerPosition = false
let finalPosition = undefined;
if (mouseArea.containsMouse) {
if (uigraph.hoveredNode !== null) {
var node = nodeDelegate(uigraph.hoveredNode)
finalPosition = Qt.point(node.mousePosition.x + node.x, node.mousePosition.y + node.y)
} else {
finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY)
}
finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY);
} else {
finalPosition = getCenterPosition()
centerPosition = true
finalPosition = getCenterPosition();
}
var copiedContent = Clipboard.getText()
var nodes = uigraph.pasteNodes(copiedContent, finalPosition, centerPosition)
const copiedContent = Clipboard.getText();
const nodes = uigraph.pasteNodes(copiedContent, finalPosition);
if (nodes.length > 0) {
uigraph.selectedNode = nodes[0]
uigraph.selectNodes(nodes)
uigraph.selectedNode = nodes[0];
uigraph.selectNodes(nodes);
}
}

View file

@ -389,7 +389,7 @@ Page {
} else {
// Open project
mainStack.push("Application.qml")
if (_reconstruction.loadUrl(modelData["path"])) {
if (_reconstruction.load(modelData["path"])) {
MeshroomApp.addRecentProjectFile(modelData["path"])
} else {
MeshroomApp.removeRecentProjectFile(modelData["path"])

View file

@ -128,7 +128,7 @@ ApplicationWindow {
if (mainStack.currentItem instanceof Homepage) {
mainStack.push("Application.qml")
}
if (_reconstruction.loadUrl(currentFile)) {
if (_reconstruction.load(currentFile)) {
MeshroomApp.addRecentProjectFile(currentFile.toString())
}
}

View file

@ -5,6 +5,7 @@ import os
from collections.abc import Iterable
from multiprocessing.pool import ThreadPool
from threading import Thread
from typing import Callable
from PySide6.QtCore import QObject, Slot, Property, Signal, QUrl, QSizeF, QPoint
from PySide6.QtGui import QMatrix4x4, QMatrix3x3, QQuaternion, QVector3D, QVector2D
@ -534,17 +535,24 @@ class Reconstruction(UIGraph):
# - correct pipeline name but the case does not match (e.g. panoramaHDR instead of panoramaHdr)
# - lowercase pipeline name given through the "New Pipeline" menu
loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items())
if p.lower() in loweredPipelineTemplates:
self.load(loweredPipelineTemplates[p.lower()], setupProjectFile=False)
else:
# use the user-provided default project file
self.load(p, setupProjectFile=False)
filepath = loweredPipelineTemplates.get(p.lower(), p)
return self._loadWithErrorReport(self.initFromTemplate, filepath)
@Slot(str, result=bool)
def load(self, filepath, setupProjectFile=True, publishOutputs=False):
@Slot(QUrl, result=bool)
def load(self, url):
if isinstance(url, QUrl):
# depending how the QUrl has been initialized,
# toLocalFile() may return the local path or an empty string
localFile = url.toLocalFile() or url.toString()
else:
localFile = url
return self._loadWithErrorReport(self.loadGraph, localFile)
def _loadWithErrorReport(self, loadFunction: Callable[[str], None], filepath: str):
logging.info(f"Load project file: '{filepath}'")
try:
status = super(Reconstruction, self).loadGraph(filepath, setupProjectFile, publishOutputs)
loadFunction(filepath)
# warn about pre-release projects being automatically upgraded
if Version(self._graph.fileReleaseVersion).major == "0":
self.warning.emit(Message(
@ -554,8 +562,8 @@ class Reconstruction(UIGraph):
"Open it with the corresponding version of Meshroom to recover your data."
))
self.setActive(True)
return status
except FileNotFoundError as e:
return True
except FileNotFoundError:
self.error.emit(
Message(
"No Such File",
@ -564,8 +572,7 @@ class Reconstruction(UIGraph):
)
)
logging.error("Error while loading '{}': No Such File.".format(filepath))
return False
except Exception as e:
except Exception:
import traceback
trace = traceback.format_exc()
self.error.emit(
@ -577,20 +584,8 @@ class Reconstruction(UIGraph):
)
logging.error("Error while loading '{}'.".format(filepath))
logging.error(trace)
return False
@Slot(QUrl, result=bool)
@Slot(QUrl, bool, bool, result=bool)
def loadUrl(self, url, setupProjectFile=True, publishOutputs=False):
if isinstance(url, (QUrl)):
# depending how the QUrl has been initialized,
# toLocalFile() may return the local path or an empty string
localFile = url.toLocalFile()
if not localFile:
localFile = url.toString()
else:
localFile = url
return self.load(localFile, setupProjectFile, publishOutputs)
return False
def onGraphChanged(self):
""" React to the change of the internal graph. """
@ -860,7 +855,7 @@ class Reconstruction(UIGraph):
)
)
else:
return self.loadUrl(filesByType["meshroomScenes"][0])
return self.load(filesByType["meshroomScenes"][0])

View file

@ -4,6 +4,7 @@ import tempfile
import os
import copy
from typing import Type
import pytest
import meshroom.core
@ -12,6 +13,8 @@ from meshroom.core.exception import GraphCompatibilityError, NodeUpgradeError
from meshroom.core.graph import Graph, loadGraph
from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node
from .utils import registeredNodeTypes, overrideNodeTypeVersion
SampleGroupV1 = [
desc.IntParam(name="a", label="a", description="", value=0, range=None),
@ -156,6 +159,12 @@ class SampleInputNodeV2(desc.InputNode):
]
def replaceNodeTypeDesc(nodeType: str, nodeDesc: Type[desc.Node]):
"""Change the `nodeDesc` associated to `nodeType`."""
meshroom.core.nodesDesc[nodeType] = nodeDesc
def test_unknown_node_type():
"""
Test compatibility behavior for unknown node type.
@ -218,8 +227,7 @@ def test_description_conflict():
g.save(graphFile)
# reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances)
g = loadGraph(graphFile)
assert all(isinstance(n, Node) for n in g.nodes)
loadGraph(graphFile, strictCompatibility=True)
# offset node types register to create description conflicts
# each node type name now reference the next one's implementation
@ -247,7 +255,7 @@ def test_description_conflict():
assert not hasattr(compatNode, "in")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2)
assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"]
@ -262,7 +270,7 @@ def test_description_conflict():
assert hasattr(compatNode, "paramA")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3)
assert not hasattr(upgradedNode, "paramA")
@ -275,7 +283,7 @@ def test_description_conflict():
assert not hasattr(compatNode, "paramA")
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4)
assert hasattr(upgradedNode, "paramA")
@ -295,7 +303,7 @@ def test_description_conflict():
assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__)
# perform upgrade
upgradedNode = g.upgradeNode(nodeName)[0]
upgradedNode = g.upgradeNode(nodeName)
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5)
assert hasattr(upgradedNode, "paramA")
@ -399,20 +407,250 @@ def test_conformUpgrade():
class TestGraphLoadingWithStrictCompatibility:
def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk):
registerNodeType(SampleNodeV1)
registerNodeType(SampleNodeV2)
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__)
graph.save()
# Replace saved node description by V2
meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2
def test_failsOnUnknownNodeType(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1]):
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__)
graph.save()
with pytest.raises(GraphCompatibilityError):
loadGraph(graph.filepath, strictCompatibility=True)
unregisterNodeType(SampleNodeV1)
unregisterNodeType(SampleNodeV2)
def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__)
graph.save()
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
with pytest.raises(GraphCompatibilityError):
loadGraph(graph.filepath, strictCompatibility=True)
class TestGraphTemplateLoading:
def test_failsOnUnknownNodeTypeError(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__)
graph.save(template=True)
with pytest.raises(GraphCompatibilityError):
loadGraph(graph.filepath, strictCompatibility=True)
def test_loadsIfIncompatibleNodeHasDefaultAttributeValues(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__)
graph.save(template=True)
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
loadGraph(graph.filepath, strictCompatibility=True)
def test_loadsIfValueSetOnCompatibleAttribute(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
graph: Graph = graphSavedOnDisk
node = graph.addNewNode(SampleNodeV1.__name__, paramA="foo")
graph.save(template=True)
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
loadedGraph = loadGraph(graph.filepath, strictCompatibility=True)
assert loadedGraph.nodes.get(node.name).paramA.value == "foo"
def test_loadsIfValueSetOnIncompatibleAttribute(self, graphSavedOnDisk):
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
graph: Graph = graphSavedOnDisk
graph.addNewNode(SampleNodeV1.__name__, input="foo")
graph.save(template=True)
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
loadGraph(graph.filepath, strictCompatibility=True)
class TestVersionConflict:
def test_loadingConflictingNodeVersionCreatesCompatibilityNodes(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SampleNodeV1]):
with overrideNodeTypeVersion(SampleNodeV1, "1.0"):
node = graph.addNewNode(SampleNodeV1.__name__)
graph.save()
with overrideNodeTypeVersion(SampleNodeV1, "2.0"):
otherGraph = Graph("")
otherGraph.load(graph.filepath)
assert len(otherGraph.compatibilityNodes) == 1
assert otherGraph.node(node.name).issue is CompatibilityIssue.VersionConflict
def test_loadingUnspecifiedNodeVersionAssumesCurrentVersion(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SampleNodeV1]):
graph.addNewNode(SampleNodeV1.__name__)
graph.save()
with overrideNodeTypeVersion(SampleNodeV1, "2.0"):
otherGraph = Graph("")
otherGraph.load(graph.filepath)
assert len(otherGraph.compatibilityNodes) == 0
class UidTestingNodeV1(desc.Node):
inputs = [
desc.File(name="input", label="Input", description="", value="", invalidate=True),
]
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
class UidTestingNodeV2(desc.Node):
"""
Changes from SampleNodeBV1:
* 'param' has been added
"""
inputs = [
desc.File(name="input", label="Input", description="", value="", invalidate=True),
desc.ListAttribute(
name="param",
label="Param",
elementDesc=desc.File(
name="file",
label="File",
description="",
value="",
),
description="",
),
]
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
class UidTestingNodeV3(desc.Node):
"""
Changes from SampleNodeBV2:
* 'input' is not invalidating the UID.
"""
inputs = [
desc.File(name="input", label="Input", description="", value="", invalidate=False),
desc.ListAttribute(
name="param",
label="Param",
elementDesc=desc.File(
name="file",
label="File",
description="",
value="",
),
description="",
),
]
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
class TestUidConflict:
def test_changingInvalidateOnAttributeDescCreatesUidConflict(self, graphSavedOnDisk):
with registeredNodeTypes([UidTestingNodeV2]):
graph: Graph = graphSavedOnDisk
node = graph.addNewNode(UidTestingNodeV2.__name__)
graph.save()
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
with pytest.raises(GraphCompatibilityError):
loadGraph(graph.filepath, strictCompatibility=True)
loadedGraph = loadGraph(graph.filepath)
loadedNode = loadedGraph.node(node.name)
assert isinstance(loadedNode, CompatibilityNode)
assert loadedNode.issue == CompatibilityIssue.UidConflict
def test_uidConflictingNodesPreserveConnectionsOnGraphLoad(self, graphSavedOnDisk):
with registeredNodeTypes([UidTestingNodeV2]):
graph: Graph = graphSavedOnDisk
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
nodeB = graph.addNewNode(UidTestingNodeV2.__name__)
nodeB.param.append("")
graph.addEdge(nodeA.output, nodeB.param.at(0))
graph.save()
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
loadedGraph = loadGraph(graph.filepath)
assert len(loadedGraph.compatibilityNodes) == 2
loadedNodeA = loadedGraph.node(nodeA.name)
loadedNodeB = loadedGraph.node(nodeB.name)
assert loadedNodeB.param.at(0).linkParam == loadedNodeA.output
def test_upgradingConflictingNodesPreserveConnections(self, graphSavedOnDisk):
with registeredNodeTypes([UidTestingNodeV2]):
graph: Graph = graphSavedOnDisk
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
nodeB = graph.addNewNode(UidTestingNodeV2.__name__)
# Double-connect nodeA.output to nodeB, on both a single attribute and a list attribute
nodeB.param.append("")
graph.addEdge(nodeA.output, nodeB.param.at(0))
graph.addEdge(nodeA.output, nodeB.input)
graph.save()
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
def checkNodeAConnectionsToNodeB():
loadedNodeA = loadedGraph.node(nodeA.name)
loadedNodeB = loadedGraph.node(nodeB.name)
return (
loadedNodeB.param.at(0).linkParam == loadedNodeA.output
and loadedNodeB.input.linkParam == loadedNodeA.output
)
loadedGraph = loadGraph(graph.filepath)
loadedGraph.upgradeNode(nodeA.name)
assert checkNodeAConnectionsToNodeB()
loadedGraph.upgradeNode(nodeB.name)
assert checkNodeAConnectionsToNodeB()
assert len(loadedGraph.compatibilityNodes) == 0
def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, graphSavedOnDisk):
with registeredNodeTypes([UidTestingNodeV1, UidTestingNodeV2]):
graph: Graph = graphSavedOnDisk
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
nodeB = graph.addNewNode(UidTestingNodeV1.__name__)
graph.addEdge(nodeA.output, nodeB.input)
graph.save()
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
loadedGraph = loadGraph(graph.filepath)
assert len(loadedGraph.compatibilityNodes) == 1
def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection(self, graphSavedOnDisk):
with registeredNodeTypes([UidTestingNodeV2, UidTestingNodeV3]):
graph: Graph = graphSavedOnDisk
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
nodeB = graph.addNewNode(UidTestingNodeV3.__name__)
nodeB.param.append("")
graph.addEdge(nodeA.output, nodeB.param.at(0))
graph.save()
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
loadedGraph = loadGraph(graph.filepath)
assert len(loadedGraph.compatibilityNodes) == 1

364
tests/test_graphIO.py Normal file
View file

@ -0,0 +1,364 @@
import json
from textwrap import dedent
from meshroom.core import desc
from meshroom.core.graph import Graph
from meshroom.core.node import CompatibilityIssue
from .utils import registeredNodeTypes, overrideNodeTypeVersion
class SimpleNode(desc.Node):
inputs = [
desc.File(name="input", label="Input", description="", value=""),
]
outputs = [
desc.File(name="output", label="Output", description="", value=""),
]
class NodeWithListAttributes(desc.Node):
inputs = [
desc.ListAttribute(
name="listInput",
label="List Input",
description="",
elementDesc=desc.File(name="file", label="File", description="", value=""),
exposed=True,
),
desc.GroupAttribute(
name="group",
label="Group",
description="",
groupDesc=[
desc.ListAttribute(
name="listInput",
label="List Input",
description="",
elementDesc=desc.File(name="file", label="File", description="", value=""),
exposed=True,
),
],
),
]
def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
"""Returns whether the content (node and deges) of two graphs are considered identical.
Similar nodes: nodes with the same name, type and compatibility status.
Similar edges: edges with the same source and destination attribute names.
"""
def _buildNodesSet(graph: Graph):
return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes])
def _buildEdgesSet(graph: Graph):
return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges])
nodesSetA, edgesSetA = _buildNodesSet(graphA), _buildEdgesSet(graphA)
nodesSetB, edgesSetB = _buildNodesSet(graphB), _buildEdgesSet(graphB)
return nodesSetA == nodesSetB and edgesSetA == edgesSetB
class TestImportGraphContent:
def test_importEmptyGraph(self):
graph = Graph("")
otherGraph = Graph("")
nodes = otherGraph.importGraphContent(graph)
assert len(nodes) == 0
assert len(graph.nodes) == 0
def test_importGraphWithSingleNode(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importGraphWithSeveralNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_importingGraphWithNodesAndEdges(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert compareGraphsContent(graph, otherGraph)
def test_edgeRemappingOnImportingGraphSeveralTimes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
otherGraph.importGraphContent(graph)
def test_edgeRemappingOnImportingGraphWithUnkownNodeTypesSeveralTimes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
otherGraph.importGraphContent(graph)
assert len(otherGraph.nodes) == 4
assert len(otherGraph.compatibilityNodes) == 4
assert len(otherGraph.edges) == 2
def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
graph.addNewNode(SimpleNode.__name__)
otherGraph = Graph("")
importedNode = otherGraph.importGraphContent(graph)
assert len(importedNode) == 1
assert importedNode[0].isCompatibilityNode
def test_importGraphContentInPlace(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.importGraphContent(graph)
assert len(graph.nodes) == 4
def test_importGraphContentFromFile(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert compareGraphsContent(graph, otherGraph)
def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA_1.output, nodeA_2.input)
graph.save()
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 2
assert len(otherGraph.compatibilityNodes) == 2
assert not compareGraphsContent(graph, otherGraph)
def test_importingDifferentNodeVersionCreatesCompatibilityNodes(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
with registeredNodeTypes([SimpleNode]):
with overrideNodeTypeVersion(SimpleNode, "1.0"):
node = graph.addNewNode(SimpleNode.__name__)
graph.save()
with overrideNodeTypeVersion(SimpleNode, "2.0"):
otherGraph = Graph("")
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
assert len(nodes) == 1
assert len(otherGraph.compatibilityNodes) == 1
assert otherGraph.node(node.name).issue is CompatibilityIssue.VersionConflict
class TestGraphPartialSerialization:
def test_emptyGraph(self):
graph = Graph("")
serializedGraph = graph.serializePartial([])
otherGraph = Graph("")
otherGraph._deserialize(serializedGraph)
assert compareGraphsContent(graph, otherGraph)
def test_serializeAllNodesIsSimilarToStandardSerialization(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA.output, nodeB.input)
partialSerializedGraph = graph.serializePartial([nodeA, nodeB])
standardSerializedGraph = graph.serialize()
graphA = Graph("")
graphA._deserialize(partialSerializedGraph)
graphB = Graph("")
graphB._deserialize(standardSerializedGraph)
assert compareGraphsContent(graph, graphA)
assert compareGraphsContent(graphA, graphB)
def test_listAttributeToListAttributeConnectionIsSerialized(self):
graph = Graph("")
with registeredNodeTypes([NodeWithListAttributes]):
nodeA = graph.addNewNode(NodeWithListAttributes.__name__)
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
graph.addEdge(nodeA.listInput, nodeB.listInput)
otherGraph = Graph("")
otherGraph._deserialize(graph.serializePartial([nodeA, nodeB]))
assert otherGraph.node(nodeB.name).listInput.linkParam == otherGraph.node(nodeA.name).listInput
def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA.output, nodeB.input)
serializedGraph = graph.serializePartial([nodeB])
otherGraph = Graph("")
otherGraph._deserialize(serializedGraph)
assert len(otherGraph.compatibilityNodes) == 0
assert len(otherGraph.nodes) == 1
assert len(otherGraph.edges) == 0
def test_serializeSingleNodeWithInputConnectionToListAttributeRemovesListEntry(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
nodeB.listInput.append("")
graph.addEdge(nodeA.output, nodeB.listInput.at(0))
otherGraph = Graph("")
otherGraph._deserialize(graph.serializePartial([nodeB]))
assert len(otherGraph.node(nodeB.name).listInput) == 0
def test_serializeSingleNodeWithInputConnectionToNestedListAttributeRemovesListEntry(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
nodeB.group.listInput.append("")
graph.addEdge(nodeA.output, nodeB.group.listInput.at(0))
otherGraph = Graph("")
otherGraph._deserialize(graph.serializePartial([nodeB]))
assert len(otherGraph.node(nodeB.name).group.listInput) == 0
class TestGraphCopy:
def test_graphCopyIsIdenticalToOriginalGraph(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA.output, nodeB.input)
graphCopy = graph.copy()
assert compareGraphsContent(graph, graphCopy)
def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
nodeA = graph.addNewNode(SimpleNode.__name__)
nodeB = graph.addNewNode(SimpleNode.__name__)
graph.addEdge(nodeA.output, nodeB.input)
graphCopy = graph.copy()
assert not compareGraphsContent(graph, graphCopy)
class TestImportGraphContentFromMinimalGraphData:
def test_nodeWithoutVersionInfoIsUpgraded(self):
graph = Graph("")
with (
registeredNodeTypes([SimpleNode]),
overrideNodeTypeVersion(SimpleNode, "2.0"),
):
sampleGraphContent = dedent("""
{
"SimpleNode_1": { "nodeType": "SimpleNode" }
}
""")
graph._deserialize(json.loads(sampleGraphContent))
assert len(graph.nodes) == 1
assert len(graph.compatibilityNodes) == 0
def test_connectionsToMissingNodesAreDiscarded(self):
graph = Graph("")
with registeredNodeTypes([SimpleNode]):
sampleGraphContent = dedent("""
{
"SimpleNode_1": {
"nodeType": "SimpleNode", "inputs": { "input": "{NotSerializedNode.output}" }
}
}
""")
graph._deserialize(json.loads(sampleGraphContent))

View file

@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
assert nodeB.affectedInput.value == 0
class TestAttributeCallbackBehaviorOnGraphImport:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.affectedInput, nodeB.input)
nodeA.input.value = 5
nodeB.affectedInput.value = 2
otherGraph = Graph("")
otherGraph.importGraphContent(graph)
assert otherGraph.node(nodeB.name).affectedInput.value == 2

View file

@ -4,6 +4,7 @@
from meshroom.core.graph import Graph
from meshroom.core import pipelineTemplates, Version
from meshroom.core.node import CompatibilityIssue, CompatibilityNode
from meshroom.core.graphIO import GraphIO
import json
import meshroom
@ -24,13 +25,13 @@ def test_templateVersions():
with open(path) as jsonFile:
fileData = json.load(jsonFile)
graphData = fileData.get(Graph.IO.Keys.Graph, fileData)
graphData = fileData.get(GraphIO.Keys.Graph, fileData)
assert isinstance(graphData, dict)
header = fileData.get(Graph.IO.Keys.Header, {})
header = fileData.get(GraphIO.Keys.Header, {})
assert header.get("template", False)
nodesVersions = header.get(Graph.IO.Keys.NodesVersions, {})
nodesVersions = header.get(GraphIO.Keys.NodesVersions, {})
for _, nodeData in graphData.items():
nodeType = nodeData["nodeType"]

28
tests/utils.py Normal file
View file

@ -0,0 +1,28 @@
from contextlib import contextmanager
from unittest.mock import patch
from typing import Type
import meshroom
from meshroom.core import registerNodeType, unregisterNodeType
from meshroom.core import desc
@contextmanager
def registeredNodeTypes(nodeTypes: list[Type[desc.Node]]):
for nodeType in nodeTypes:
registerNodeType(nodeType)
yield
for nodeType in nodeTypes:
unregisterNodeType(nodeType)
@contextmanager
def overrideNodeTypeVersion(nodeType: Type[desc.Node], version: str):
"""Helper context manager to override the version of a given node type."""
unpatchedFunc = meshroom.core.nodeVersion
with patch.object(
meshroom.core,
"nodeVersion",
side_effect=lambda type: version if type is nodeType else unpatchedFunc(type),
):
yield