mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-08-03 00:38:41 +02:00
Merge pull request #2612 from alicevision/dev/graphIO
Refactor Graph de/serialization
This commit is contained in:
commit
91d2530401
19 changed files with 1538 additions and 768 deletions
|
@ -154,10 +154,10 @@ with meshroom.core.graph.GraphModification(graph):
|
|||
# initialize template pipeline
|
||||
loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items())
|
||||
if args.pipeline.lower() in loweredPipelineTemplates:
|
||||
graph.load(loweredPipelineTemplates[args.pipeline.lower()], setupProjectFile=False, publishOutputs=True if args.output else False)
|
||||
graph.initFromTemplate(loweredPipelineTemplates[args.pipeline.lower()], publishOutputs=True if args.output else False)
|
||||
else:
|
||||
# custom pipeline
|
||||
graph.load(args.pipeline, setupProjectFile=False, publishOutputs=True if args.output else False)
|
||||
graph.initFromTemplate(args.pipeline, publishOutputs=True if args.output else False)
|
||||
|
||||
def parseInputs(inputs, uniqueInitNode):
|
||||
"""Utility method for parsing the input and inputRecursive arguments."""
|
||||
|
|
|
@ -339,9 +339,12 @@ class Attribute(BaseObject):
|
|||
elif self.isInput and Attribute.isLinkExpression(v):
|
||||
# value is a link to another attribute
|
||||
link = v[1:-1]
|
||||
linkNode, linkAttr = link.split('.')
|
||||
linkNodeName, linkAttrName = link.split('.')
|
||||
try:
|
||||
g.addEdge(g.node(linkNode).attribute(linkAttr), self)
|
||||
node = g.node(linkNodeName)
|
||||
if not node:
|
||||
raise KeyError(f"Node '{linkNodeName}' not found")
|
||||
g.addEdge(node.attribute(linkAttrName), self)
|
||||
except KeyError as err:
|
||||
logging.warning('Connect Attribute from Expression failed.')
|
||||
logging.warning('Expression: "{exp}"\nError: "{err}".'.format(exp=v, err=err))
|
||||
|
|
|
@ -4,6 +4,7 @@ import json
|
|||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Optional
|
||||
import weakref
|
||||
from collections import defaultdict, OrderedDict
|
||||
from contextlib import contextmanager
|
||||
|
@ -16,7 +17,10 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
|
|||
from meshroom.core import Version
|
||||
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
|
||||
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
|
||||
from meshroom.core.node import nodeFactory, Status, Node, CompatibilityNode
|
||||
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer
|
||||
from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode
|
||||
from meshroom.core.nodeFactory import nodeFactory
|
||||
from meshroom.core.typing import PathLike
|
||||
|
||||
# Replace default encoder to support Enums
|
||||
|
||||
|
@ -148,6 +152,21 @@ def changeTopology(func):
|
|||
return decorator
|
||||
|
||||
|
||||
def blockNodeCallbacks(func):
|
||||
"""
|
||||
Graph methods loading serialized graph content must be decorated with 'blockNodeCallbacks',
|
||||
to avoid attribute changed callbacks defined on node descriptions to be triggered during
|
||||
this process.
|
||||
"""
|
||||
def inner(self, *args, **kwargs):
|
||||
self._loading = True
|
||||
try:
|
||||
return func(self, *args, **kwargs)
|
||||
finally:
|
||||
self._loading = False
|
||||
return inner
|
||||
|
||||
|
||||
class Graph(BaseObject):
|
||||
"""
|
||||
_________________ _________________ _________________
|
||||
|
@ -165,52 +184,6 @@ class Graph(BaseObject):
|
|||
"""
|
||||
_cacheDir = ""
|
||||
|
||||
class IO(object):
|
||||
""" Centralize Graph file keys and IO version. """
|
||||
__version__ = "2.0"
|
||||
|
||||
class Keys(object):
|
||||
""" File Keys. """
|
||||
# Doesn't inherit enum to simplify usage (Graph.IO.Keys.XX, without .value)
|
||||
Header = "header"
|
||||
NodesVersions = "nodesVersions"
|
||||
ReleaseVersion = "releaseVersion"
|
||||
FileVersion = "fileVersion"
|
||||
Graph = "graph"
|
||||
|
||||
class Features(Enum):
|
||||
""" File Features. """
|
||||
Graph = "graph"
|
||||
Header = "header"
|
||||
NodesVersions = "nodesVersions"
|
||||
PrecomputedOutputs = "precomputedOutputs"
|
||||
NodesPositions = "nodesPositions"
|
||||
|
||||
@staticmethod
|
||||
def getFeaturesForVersion(fileVersion):
|
||||
""" Return the list of supported features based on a file version.
|
||||
|
||||
Args:
|
||||
fileVersion (str, Version): the file version
|
||||
|
||||
Returns:
|
||||
tuple of Graph.IO.Features: the list of supported features
|
||||
"""
|
||||
if isinstance(fileVersion, str):
|
||||
fileVersion = Version(fileVersion)
|
||||
|
||||
features = [Graph.IO.Features.Graph]
|
||||
if fileVersion >= Version("1.0"):
|
||||
features += [Graph.IO.Features.Header,
|
||||
Graph.IO.Features.NodesVersions,
|
||||
Graph.IO.Features.PrecomputedOutputs,
|
||||
]
|
||||
|
||||
if fileVersion >= Version("1.1"):
|
||||
features += [Graph.IO.Features.NodesPositions]
|
||||
|
||||
return tuple(features)
|
||||
|
||||
def __init__(self, name, parent=None):
|
||||
super(Graph, self).__init__(parent)
|
||||
self.name = name
|
||||
|
@ -225,7 +198,6 @@ class Graph(BaseObject):
|
|||
self._nodes = DictModel(keyAttrName='name', parent=self)
|
||||
# Edges: use dst attribute as unique key since it can only have one input connection
|
||||
self._edges = DictModel(keyAttrName='dst', parent=self)
|
||||
self._importedNodes = DictModel(keyAttrName='name', parent=self)
|
||||
self._compatibilityNodes = DictModel(keyAttrName='name', parent=self)
|
||||
self.cacheDir = meshroom.core.defaultCacheFolder
|
||||
self._filepath = ''
|
||||
|
@ -233,20 +205,22 @@ class Graph(BaseObject):
|
|||
self.header = {}
|
||||
|
||||
def clear(self):
|
||||
self._clearGraphContent()
|
||||
self.header.clear()
|
||||
self._compatibilityNodes.clear()
|
||||
self._unsetFilepath()
|
||||
|
||||
def _clearGraphContent(self):
|
||||
self._edges.clear()
|
||||
# Tell QML nodes are going to be deleted
|
||||
for node in self._nodes:
|
||||
node.alive = False
|
||||
self._importedNodes.clear()
|
||||
self._nodes.clear()
|
||||
self._unsetFilepath()
|
||||
self._compatibilityNodes.clear()
|
||||
|
||||
@property
|
||||
def fileFeatures(self):
|
||||
""" Get loaded file supported features based on its version. """
|
||||
return Graph.IO.getFeaturesForVersion(self.header.get(Graph.IO.Keys.FileVersion, "0.0"))
|
||||
return GraphIO.getFeaturesForVersion(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
|
||||
|
||||
@property
|
||||
def isLoading(self):
|
||||
|
@ -259,37 +233,84 @@ class Graph(BaseObject):
|
|||
return self._saving
|
||||
|
||||
@Slot(str)
|
||||
def load(self, filepath, setupProjectFile=True, importProject=False, publishOutputs=False):
|
||||
def load(self, filepath: PathLike):
|
||||
"""
|
||||
Load a Meshroom graph ".mg" file.
|
||||
Load a Meshroom Graph ".mg" file in place.
|
||||
|
||||
Args:
|
||||
filepath: project filepath to load
|
||||
setupProjectFile: Store the reference to the project file and setup the cache directory.
|
||||
If false, it only loads the graph of the project file as a template.
|
||||
importProject: True if the project that is loaded will be imported in the current graph, instead
|
||||
of opened.
|
||||
publishOutputs: True if "Publish" nodes from templates should not be ignored.
|
||||
filepath: The path to the Meshroom Graph file to load.
|
||||
"""
|
||||
self._loading = True
|
||||
try:
|
||||
return self._load(filepath, setupProjectFile, importProject, publishOutputs)
|
||||
finally:
|
||||
self._loading = False
|
||||
self._deserialize(Graph._loadGraphData(filepath))
|
||||
self._setFilepath(filepath)
|
||||
self._fileDateVersion = os.path.getmtime(filepath)
|
||||
|
||||
def _load(self, filepath, setupProjectFile, importProject, publishOutputs):
|
||||
if not importProject:
|
||||
self.clear()
|
||||
with open(filepath) as jsonFile:
|
||||
fileData = json.load(jsonFile)
|
||||
def initFromTemplate(self, filepath: PathLike, publishOutputs: bool = False):
|
||||
"""
|
||||
Deserialize a template Meshroom Graph ".mg" file in place.
|
||||
|
||||
self.header = fileData.get(Graph.IO.Keys.Header, {})
|
||||
When initializing from a template, the internal filepath of the graph instance is not set.
|
||||
Saving the file on disk will require to specify a filepath.
|
||||
|
||||
fileVersion = self.header.get(Graph.IO.Keys.FileVersion, "0.0")
|
||||
# Retro-compatibility for all project files with the previous UID format
|
||||
if Version(fileVersion) < Version("2.0"):
|
||||
Args:
|
||||
filepath: The path to the Meshroom Graph file to load.
|
||||
publishOutputs: (optional) Whether to keep 'Publish' nodes.
|
||||
"""
|
||||
self._deserialize(Graph._loadGraphData(filepath))
|
||||
|
||||
if not publishOutputs:
|
||||
with GraphModification(self):
|
||||
for node in [node for node in self.nodes if node.nodeType == "Publish"]:
|
||||
self.removeNode(node.name)
|
||||
|
||||
@staticmethod
|
||||
def _loadGraphData(filepath: PathLike) -> dict:
|
||||
"""Deserialize the content of the Meshroom Graph file at `filepath` to a dictionnary."""
|
||||
with open(filepath) as file:
|
||||
graphData = json.load(file)
|
||||
return graphData
|
||||
|
||||
@blockNodeCallbacks
|
||||
def _deserialize(self, graphData: dict):
|
||||
"""Deserialize `graphData` in the current Graph instance.
|
||||
|
||||
Args:
|
||||
graphData: The serialized Graph.
|
||||
"""
|
||||
self.clear()
|
||||
self.header = graphData.get(GraphIO.Keys.Header, {})
|
||||
fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
|
||||
graphContent = self._normalizeGraphContent(graphData, fileVersion)
|
||||
isTemplate = self.header.get(GraphIO.Keys.Template, False)
|
||||
|
||||
with GraphModification(self):
|
||||
# iterate over nodes sorted by suffix index in their names
|
||||
for nodeName, nodeData in sorted(
|
||||
graphContent.items(), key=lambda x: self.getNodeIndexFromName(x[0])
|
||||
):
|
||||
self._deserializeNode(nodeData, nodeName, self)
|
||||
|
||||
# Create graph edges by resolving attributes expressions
|
||||
self._applyExpr()
|
||||
|
||||
# Templates are specific: they contain only the minimal amount of
|
||||
# serialized data to describe the graph structure.
|
||||
# They are not meant to be computed: therefore, we can early return here,
|
||||
# as uid conflict evaluation is only meaningful for nodes with computed data.
|
||||
if isTemplate:
|
||||
return
|
||||
|
||||
# By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the
|
||||
# nodes' links have been resolved and their UID computations are all complete.
|
||||
# It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones
|
||||
# that were computed.
|
||||
self._evaluateUidConflicts(graphContent)
|
||||
|
||||
def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict:
|
||||
graphContent = graphData.get(GraphIO.Keys.Graph, graphData)
|
||||
|
||||
if fileVersion < Version("2.0"):
|
||||
# For internal folders, all "{uid0}" keys should be replaced with "{uid}"
|
||||
updatedFileData = json.dumps(fileData).replace("{uid0}", "{uid}")
|
||||
updatedFileData = json.dumps(graphContent).replace("{uid0}", "{uid}")
|
||||
|
||||
# For fileVersion < 2.0, the nodes' UID is stored as:
|
||||
# "uids": {"0": "hashvalue"}
|
||||
|
@ -301,239 +322,124 @@ class Graph(BaseObject):
|
|||
uid = occ.split("\"")[-2] # UID is second to last element
|
||||
newUidStr = r'"uid": "{}"'.format(uid)
|
||||
updatedFileData = updatedFileData.replace(occ, newUidStr)
|
||||
fileData = json.loads(updatedFileData)
|
||||
graphContent = json.loads(updatedFileData)
|
||||
|
||||
# Older versions of Meshroom files only contained the serialized nodes
|
||||
graphData = fileData.get(Graph.IO.Keys.Graph, fileData)
|
||||
return graphContent
|
||||
|
||||
if importProject:
|
||||
self._importedNodes.clear()
|
||||
graphData = self.updateImportedProject(graphData)
|
||||
def _deserializeNode(self, nodeData: dict, nodeName: str, fromGraph: "Graph"):
|
||||
# Retrieve version info from:
|
||||
# 1. nodeData: node saved from a CompatibilityNode
|
||||
# 2. nodesVersion in file header: node saved from a Node
|
||||
# If unvailable, the "version" field will not be set in `nodeData`.
|
||||
if "version" not in nodeData:
|
||||
if version := fromGraph._getNodeTypeVersionFromHeader(nodeData["nodeType"]):
|
||||
nodeData["version"] = version
|
||||
inTemplate = fromGraph.header.get(GraphIO.Keys.Template, False)
|
||||
node = nodeFactory(nodeData, nodeName, inTemplate=inTemplate)
|
||||
self._addNode(node, nodeName)
|
||||
return node
|
||||
|
||||
if not isinstance(graphData, dict):
|
||||
raise RuntimeError('loadGraph error: Graph is not a dict. File: {}'.format(filepath))
|
||||
def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]:
|
||||
nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {})
|
||||
return nodeVersions.get(nodeType, default)
|
||||
|
||||
nodesVersions = self.header.get(Graph.IO.Keys.NodesVersions, {})
|
||||
|
||||
self._fileDateVersion = os.path.getmtime(filepath)
|
||||
|
||||
# Check whether the file was saved as a template in minimal mode
|
||||
isTemplate = self.header.get("template", False)
|
||||
|
||||
with GraphModification(self):
|
||||
# iterate over nodes sorted by suffix index in their names
|
||||
for nodeName, nodeData in sorted(graphData.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
|
||||
if not isinstance(nodeData, dict):
|
||||
raise RuntimeError('loadGraph error: Node is not a dict. File: {}'.format(filepath))
|
||||
|
||||
# retrieve version from
|
||||
# 1. nodeData: node saved from a CompatibilityNode
|
||||
# 2. nodesVersion in file header: node saved from a Node
|
||||
# 3. fallback to no version "0.0": retro-compatibility
|
||||
if "version" not in nodeData:
|
||||
nodeData["version"] = nodesVersions.get(nodeData["nodeType"], "0.0")
|
||||
|
||||
# if the node is a "Publish" node and comes from a template file, it should be ignored
|
||||
# unless publishOutputs is True
|
||||
if isTemplate and not publishOutputs and nodeData["nodeType"] == "Publish":
|
||||
continue
|
||||
|
||||
n = nodeFactory(nodeData, nodeName, template=isTemplate)
|
||||
|
||||
# Add node to the graph with raw attributes values
|
||||
self._addNode(n, nodeName)
|
||||
|
||||
if importProject:
|
||||
self._importedNodes.add(n)
|
||||
|
||||
# Create graph edges by resolving attributes expressions
|
||||
self._applyExpr()
|
||||
|
||||
if setupProjectFile:
|
||||
# Update filepath related members
|
||||
# Note: needs to be done at the end as it will trigger an updateInternals.
|
||||
self._setFilepath(filepath)
|
||||
elif not isTemplate:
|
||||
# If no filepath is being set but the graph is not a template, trigger an updateInternals either way.
|
||||
self.updateInternals()
|
||||
|
||||
# By this point, the graph has been fully loaded and an updateInternals has been triggered, so all the
|
||||
# nodes' links have been resolved and their UID computations are all complete.
|
||||
# It is now possible to check whether the UIDs stored in the graph file for each node correspond to the ones
|
||||
# that were computed.
|
||||
if not isTemplate: # UIDs are not stored in templates
|
||||
self._evaluateUidConflicts(graphData)
|
||||
try:
|
||||
self._applyExpr()
|
||||
except Exception as e:
|
||||
logging.warning(e)
|
||||
|
||||
return True
|
||||
|
||||
def _evaluateUidConflicts(self, data):
|
||||
def _evaluateUidConflicts(self, graphContent: dict):
|
||||
"""
|
||||
Compare the UIDs of all the nodes in the graph with the UID that is expected in the graph file. If there
|
||||
Compare the computed UIDs of all the nodes in the graph with the UIDs serialized in `graphContent`. If there
|
||||
are mismatches, the nodes with the unexpected UID are replaced with "UidConflict" compatibility nodes.
|
||||
Already existing nodes are removed and re-added to the graph identically to preserve all the edges,
|
||||
which may otherwise be invalidated when a node with output edges but a UID conflict is re-generated as a
|
||||
compatibility node.
|
||||
|
||||
Args:
|
||||
graphContent: The serialized Graph content.
|
||||
"""
|
||||
|
||||
def _serializedNodeUidMatchesComputedUid(nodeData: dict, node: BaseNode) -> bool:
|
||||
"""Returns whether the serialized UID matches the one computed in the `node` instance."""
|
||||
if isinstance(node, CompatibilityNode):
|
||||
return True
|
||||
serializedUid = nodeData.get("uid", None)
|
||||
computedUid = node._uid
|
||||
return serializedUid is None or computedUid is None or serializedUid == computedUid
|
||||
|
||||
uidConflictingNodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if not _serializedNodeUidMatchesComputedUid(graphContent[node.name], node)
|
||||
]
|
||||
|
||||
if not uidConflictingNodes:
|
||||
return
|
||||
|
||||
logging.warning("UID Compatibility issues found: recreating conflicting nodes as CompatibilityNodes.")
|
||||
|
||||
# A uid conflict is contagious: if a node has a uid conflict, all of its downstream nodes may be
|
||||
# impacted as well, as the uid flows through connections.
|
||||
# Therefore, we deal with conflicting uid nodes by depth: replacing a node with a CompatibilityNode restores
|
||||
# the serialized uid, which might solve "false-positives" downstream conflicts as well.
|
||||
nodesSortedByDepth = sorted(uidConflictingNodes, key=lambda node: node.minDepth)
|
||||
for node in nodesSortedByDepth:
|
||||
nodeData = graphContent[node.name]
|
||||
# Evaluate if the node uid is still conflicting at this point, or if it has been resolved by an
|
||||
# upstream node replacement.
|
||||
if _serializedNodeUidMatchesComputedUid(nodeData, node):
|
||||
continue
|
||||
expectedUid = node._uid
|
||||
compatibilityNode = nodeFactory(graphContent[node.name], node.name, expectedUid=expectedUid)
|
||||
# This operation will trigger a graph update that will recompute the uids of all nodes,
|
||||
# allowing the iterative resolution of uid conflicts.
|
||||
self.replaceNode(node.name, compatibilityNode)
|
||||
|
||||
|
||||
def importGraphContentFromFile(self, filepath: PathLike) -> list[Node]:
|
||||
"""Import the content (nodes and edges) of another Graph file into this Graph instance.
|
||||
|
||||
Args:
|
||||
data (dict): the dictionary containing all the nodes to import and their data
|
||||
"""
|
||||
for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
|
||||
node = self.node(nodeName)
|
||||
|
||||
savedUid = nodeData.get("uid", None)
|
||||
graphUid = node._uid # Node's UID from the graph itself
|
||||
|
||||
if savedUid != graphUid and graphUid is not None:
|
||||
# Different UIDs, remove the existing node from the graph and replace it with a CompatibilityNode
|
||||
logging.debug("UID conflict detected for {}".format(nodeName))
|
||||
self.removeNode(nodeName)
|
||||
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=True)
|
||||
self._addNode(n, nodeName)
|
||||
else:
|
||||
# f connecting nodes have UID conflicts and are removed/re-added to the graph, some edges may be lost:
|
||||
# the links will be erroneously updated, and any further resolution will fail.
|
||||
# Recreating the entire graph as it was ensures that all edges will be correctly preserved.
|
||||
self.removeNode(nodeName)
|
||||
n = nodeFactory(nodeData, nodeName, template=False, uidConflict=False)
|
||||
self._addNode(n, nodeName)
|
||||
|
||||
def updateImportedProject(self, data):
|
||||
"""
|
||||
Update the names and links of the project to import so that it can fit
|
||||
correctly in the existing graph.
|
||||
|
||||
Parse all the nodes from the project that is going to be imported.
|
||||
If their name already exists in the graph, replace them with new names,
|
||||
then parse all the nodes' inputs/outputs to replace the old names with
|
||||
the new ones in the links.
|
||||
|
||||
Args:
|
||||
data (dict): the dictionary containing all the nodes to import and their data
|
||||
filepath: The path to the Graph file to import.
|
||||
|
||||
Returns:
|
||||
updatedData (dict): the dictionary containing all the nodes to import with their updated names and data
|
||||
The list of newly created Nodes.
|
||||
"""
|
||||
nameCorrespondences = {} # maps the old node name to its updated one
|
||||
updatedData = {} # input data with updated node names and links
|
||||
graph = loadGraph(filepath)
|
||||
return self.importGraphContent(graph)
|
||||
|
||||
def createUniqueNodeName(nodeNames, inputName):
|
||||
"""
|
||||
Create a unique name that does not already exist in the current graph or in the list
|
||||
of nodes that will be imported.
|
||||
"""
|
||||
i = 1
|
||||
while i:
|
||||
newName = "{name}_{index}".format(name=inputName, index=i)
|
||||
if newName not in nodeNames and newName not in updatedData.keys():
|
||||
return newName
|
||||
i += 1
|
||||
|
||||
# First pass to get all the names that already exist in the graph, update them, and keep track of the changes
|
||||
for nodeName, nodeData in sorted(data.items(), key=lambda x: self.getNodeIndexFromName(x[0])):
|
||||
if not isinstance(nodeData, dict):
|
||||
raise RuntimeError('updateImportedProject error: Node is not a dict.')
|
||||
|
||||
if nodeName in self._nodes.keys() or nodeName in updatedData.keys():
|
||||
newName = createUniqueNodeName(self._nodes.keys(), nodeData["nodeType"])
|
||||
updatedData[newName] = nodeData
|
||||
nameCorrespondences[nodeName] = newName
|
||||
|
||||
else:
|
||||
updatedData[nodeName] = nodeData
|
||||
|
||||
newNames = [nodeName for nodeName in updatedData] # names of all the nodes that will be added
|
||||
|
||||
# Second pass to update all the links in the input/output attributes for every node with the new names
|
||||
for nodeName, nodeData in updatedData.items():
|
||||
nodeType = nodeData.get("nodeType", None)
|
||||
nodeDesc = meshroom.core.nodesDesc[nodeType]
|
||||
|
||||
inputs = nodeData.get("inputs", {})
|
||||
outputs = nodeData.get("outputs", {})
|
||||
|
||||
if inputs:
|
||||
inputs = self.updateLinks(inputs, nameCorrespondences)
|
||||
inputs = self.resetExternalLinks(inputs, nodeDesc.inputs, newNames)
|
||||
updatedData[nodeName]["inputs"] = inputs
|
||||
if outputs:
|
||||
outputs = self.updateLinks(outputs, nameCorrespondences)
|
||||
outputs = self.resetExternalLinks(outputs, nodeDesc.outputs, newNames)
|
||||
updatedData[nodeName]["outputs"] = outputs
|
||||
|
||||
return updatedData
|
||||
|
||||
@staticmethod
|
||||
def updateLinks(attributes, nameCorrespondences):
|
||||
@blockNodeCallbacks
|
||||
def importGraphContent(self, graph: "Graph") -> list[Node]:
|
||||
"""
|
||||
Update all the links that refer to nodes that are going to be imported and whose
|
||||
names have to be updated.
|
||||
Import the content (node and edges) of another `graph` into this Graph instance.
|
||||
|
||||
Nodes are imported with their original names if possible, otherwise a new unique name is generated
|
||||
from their node type.
|
||||
|
||||
Args:
|
||||
attributes (dict): attributes whose links need to be updated
|
||||
nameCorrespondences (dict): node names to replace in the links with the name to replace them with
|
||||
graph: The graph to import.
|
||||
|
||||
Returns:
|
||||
attributes (dict): the attributes with all the updated links
|
||||
The list of newly created Nodes.
|
||||
"""
|
||||
for key, val in attributes.items():
|
||||
for corr in nameCorrespondences.keys():
|
||||
if isinstance(val, str) and corr in val:
|
||||
attributes[key] = val.replace(corr, nameCorrespondences[corr])
|
||||
elif isinstance(val, list):
|
||||
for v in val:
|
||||
if isinstance(v, str):
|
||||
if corr in v:
|
||||
val[val.index(v)] = v.replace(corr, nameCorrespondences[corr])
|
||||
else: # the list does not contain strings, so there cannot be links to update
|
||||
break
|
||||
attributes[key] = val
|
||||
|
||||
return attributes
|
||||
def _renameClashingNodes():
|
||||
if not self.nodes:
|
||||
return
|
||||
unavailableNames = set(self.nodes.keys())
|
||||
for node in graph.nodes:
|
||||
if node._name in unavailableNames:
|
||||
node._name = self._createUniqueNodeName(node.nodeType, unavailableNames)
|
||||
unavailableNames.add(node._name)
|
||||
|
||||
@staticmethod
|
||||
def resetExternalLinks(attributes, nodeDesc, newNames):
|
||||
"""
|
||||
Reset all links to nodes that are not part of the nodes which are going to be imported:
|
||||
if there are links to nodes that are not in the list, then it means that the references
|
||||
are made to external nodes, and we want to get rid of those.
|
||||
def _importNodesAndEdges() -> list[Node]:
|
||||
importedNodes = []
|
||||
# If we import the content of the graph within itself,
|
||||
# iterate over a copy of the nodes as the graph is modified during the iteration.
|
||||
nodes = graph.nodes if graph is not self else list(graph.nodes)
|
||||
with GraphModification(self):
|
||||
for srcNode in nodes:
|
||||
node = self._deserializeNode(srcNode.toDict(), srcNode.name, graph)
|
||||
importedNodes.append(node)
|
||||
self._applyExpr()
|
||||
return importedNodes
|
||||
|
||||
Args:
|
||||
attributes (dict): attributes whose links might need to be reset
|
||||
nodeDesc (list): list with all the attributes' description (including their default value)
|
||||
newNames (list): names of the nodes that are going to be imported; no node name should be referenced
|
||||
in the links except those contained in this list
|
||||
|
||||
Returns:
|
||||
attributes (dict): the attributes with all the links referencing nodes outside those which will be imported
|
||||
reset to their default values
|
||||
"""
|
||||
for key, val in attributes.items():
|
||||
defaultValue = None
|
||||
for desc in nodeDesc:
|
||||
if desc.name == key:
|
||||
defaultValue = desc.value
|
||||
break
|
||||
|
||||
if isinstance(val, str):
|
||||
if Attribute.isLinkExpression(val) and not any(name in val for name in newNames):
|
||||
if defaultValue is not None: # prevents from not entering condition if defaultValue = ''
|
||||
attributes[key] = defaultValue
|
||||
|
||||
elif isinstance(val, list):
|
||||
removedCnt = len(val) # counter to know whether all the list entries will be deemed invalid
|
||||
tmpVal = list(val) # deep copy to ensure we iterate over the entire list (even if elements are removed)
|
||||
for v in tmpVal:
|
||||
if isinstance(v, str) and Attribute.isLinkExpression(v) and not any(name in v for name in newNames):
|
||||
val.remove(v)
|
||||
removedCnt -= 1
|
||||
if removedCnt == 0 and defaultValue is not None: # if all links were wrong, reset the attribute
|
||||
attributes[key] = defaultValue
|
||||
|
||||
return attributes
|
||||
_renameClashingNodes()
|
||||
importedNodes = _importNodesAndEdges()
|
||||
return importedNodes
|
||||
|
||||
@property
|
||||
def updateEnabled(self):
|
||||
|
@ -648,41 +554,6 @@ class Graph(BaseObject):
|
|||
|
||||
return duplicates
|
||||
|
||||
def pasteNodes(self, data, position):
|
||||
"""
|
||||
Paste node(s) in the graph with their connections. The connections can only be between
|
||||
the pasted nodes and not with the rest of the graph.
|
||||
|
||||
Args:
|
||||
data (dict): the dictionary containing the information about the nodes to paste, with their names and
|
||||
links already updated to be added to the graph
|
||||
position (list): the list of positions for each node to paste
|
||||
|
||||
Returns:
|
||||
list: the list of Node objects that were pasted and added to the graph
|
||||
"""
|
||||
nodes = []
|
||||
with GraphModification(self):
|
||||
positionCnt = 0 # always valid because we know the data is sorted the same way as the position list
|
||||
for key in sorted(data):
|
||||
nodeType = data[key].get("nodeType", None)
|
||||
if not nodeType: # this case should never occur, as the data should have been prefiltered first
|
||||
pass
|
||||
|
||||
attributes = {}
|
||||
attributes.update(data[key].get("inputs", {}))
|
||||
attributes.update(data[key].get("outputs", {}))
|
||||
attributes.update(data[key].get("internalInputs", {}))
|
||||
|
||||
node = Node(nodeType, position=position[positionCnt], **attributes)
|
||||
self._addNode(node, key)
|
||||
|
||||
nodes.append(node)
|
||||
positionCnt += 1
|
||||
|
||||
self._applyExpr()
|
||||
return nodes
|
||||
|
||||
def outEdges(self, attribute):
|
||||
""" Return the list of edges starting from the given attribute """
|
||||
# type: (Attribute,) -> [Edge]
|
||||
|
@ -746,8 +617,6 @@ class Graph(BaseObject):
|
|||
|
||||
node.alive = False
|
||||
self._nodes.remove(node)
|
||||
if node in self._importedNodes:
|
||||
self._importedNodes.remove(node)
|
||||
self.update()
|
||||
|
||||
return inEdges, outEdges, outListAttributes
|
||||
|
@ -772,18 +641,26 @@ class Graph(BaseObject):
|
|||
n.updateInternals()
|
||||
return n
|
||||
|
||||
def _createUniqueNodeName(self, inputName):
|
||||
i = 1
|
||||
while i:
|
||||
newName = "{name}_{index}".format(name=inputName, index=i)
|
||||
if newName not in self._nodes.objects:
|
||||
def _createUniqueNodeName(self, inputName: str, existingNames: Optional[set[str]] = None):
|
||||
"""Create a unique node name based on the input name.
|
||||
|
||||
Args:
|
||||
inputName: The desired node name.
|
||||
existingNames: (optional) If specified, consider this set for uniqueness check, instead of the list of nodes.
|
||||
"""
|
||||
existingNodeNames = existingNames or set(self._nodes.objects.keys())
|
||||
|
||||
idx = 1
|
||||
while idx:
|
||||
newName = f"{inputName}_{idx}"
|
||||
if newName not in existingNodeNames:
|
||||
return newName
|
||||
i += 1
|
||||
idx += 1
|
||||
|
||||
def node(self, nodeName):
|
||||
return self._nodes.get(nodeName)
|
||||
|
||||
def upgradeNode(self, nodeName):
|
||||
def upgradeNode(self, nodeName) -> Node:
|
||||
"""
|
||||
Upgrade the CompatibilityNode identified as 'nodeName'
|
||||
Args:
|
||||
|
@ -803,25 +680,49 @@ class Graph(BaseObject):
|
|||
if not isinstance(node, CompatibilityNode):
|
||||
raise ValueError("Upgrade is only available on CompatibilityNode instances.")
|
||||
upgradedNode = node.upgrade()
|
||||
with GraphModification(self):
|
||||
inEdges, outEdges, outListAttributes = self.removeNode(nodeName)
|
||||
self.addNode(upgradedNode, nodeName)
|
||||
for dst, src in outEdges.items():
|
||||
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
|
||||
# If they are not re-created first, adding their edges will lead to errors
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dst in outListAttributes.keys():
|
||||
listAttr = self.attribute(outListAttributes[dst][0])
|
||||
if isinstance(outListAttributes[dst][2], list):
|
||||
listAttr[outListAttributes[dst][1]:outListAttributes[dst][1]] = outListAttributes[dst][2]
|
||||
else:
|
||||
listAttr.insert(outListAttributes[dst][1], outListAttributes[dst][2])
|
||||
try:
|
||||
self.addEdge(self.attribute(src), self.attribute(dst))
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.warning("Failed to restore edge {} -> {}: {}".format(src, dst, str(e)))
|
||||
self.replaceNode(nodeName, upgradedNode)
|
||||
return upgradedNode
|
||||
|
||||
return upgradedNode, inEdges, outEdges, outListAttributes
|
||||
@changeTopology
|
||||
def replaceNode(self, nodeName: str, newNode: BaseNode):
|
||||
"""Replace the node idenfitied by `nodeName` with `newNode`, while restoring compatible edges.
|
||||
|
||||
Args:
|
||||
nodeName: The name of the Node to replace.
|
||||
newNode: The Node instance to replace it with.
|
||||
"""
|
||||
with GraphModification(self):
|
||||
_, outEdges, outListAttributes = self.removeNode(nodeName)
|
||||
self.addNode(newNode, nodeName)
|
||||
self._restoreOutEdges(outEdges, outListAttributes)
|
||||
|
||||
def _restoreOutEdges(self, outEdges: dict[str, str], outListAttributes):
|
||||
"""Restore output edges that were removed during a call to "removeNode".
|
||||
|
||||
Args:
|
||||
outEdges: a dictionary containing the outgoing edges removed by a call to "removeNode".
|
||||
{dstAttr.getFullNameToNode(), srcAttr.getFullNameToNode()}
|
||||
outListAttributes: a dictionary containing the values, indices and keys of attributes that were connected
|
||||
to a ListAttribute prior to the removal of all edges.
|
||||
{dstAttr.getFullNameToNode(), (dstAttr.root.getFullNameToNode(), dstAttr.index, dstAttr.value)}
|
||||
"""
|
||||
def _recreateTargetListAttributeChildren(listAttrName: str, index: int, value: Any):
|
||||
listAttr = self.attribute(listAttrName)
|
||||
if not isinstance(listAttr, ListAttribute):
|
||||
return
|
||||
if isinstance(value, list):
|
||||
listAttr[index:index] = value
|
||||
else:
|
||||
listAttr.insert(index, value)
|
||||
|
||||
for dstName, srcName in outEdges.items():
|
||||
# Re-create the entries in ListAttributes that were completely removed during the call to "removeNode"
|
||||
if dstName in outListAttributes:
|
||||
_recreateTargetListAttributeChildren(*outListAttributes[dstName])
|
||||
try:
|
||||
self.addEdge(self.attribute(srcName), self.attribute(dstName))
|
||||
except (KeyError, ValueError) as e:
|
||||
logging.warning(f"Failed to restore edge {srcName} -> {dstName}: {str(e)}")
|
||||
|
||||
def upgradeAllNodes(self):
|
||||
""" Upgrade all upgradable CompatibilityNode instances in the graph. """
|
||||
|
@ -1352,6 +1253,35 @@ class Graph(BaseObject):
|
|||
def asString(self):
|
||||
return str(self.toDict())
|
||||
|
||||
def copy(self) -> "Graph":
|
||||
"""Create a copy of this Graph instance."""
|
||||
graph = Graph("")
|
||||
graph._deserialize(self.serialize())
|
||||
return graph
|
||||
|
||||
def serialize(self, asTemplate: bool = False) -> dict:
|
||||
"""Serialize this Graph instance.
|
||||
|
||||
Args:
|
||||
asTemplate: Whether to use the template serialization.
|
||||
|
||||
Returns:
|
||||
The serialized graph data.
|
||||
"""
|
||||
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
|
||||
return SerializerClass(self).serialize()
|
||||
|
||||
def serializePartial(self, nodes: list[Node]) -> dict:
|
||||
"""Partially serialize this graph considering only the given list of `nodes`.
|
||||
|
||||
Args:
|
||||
nodes: The list of nodes to serialize.
|
||||
|
||||
Returns:
|
||||
The serialized graph data.
|
||||
"""
|
||||
return PartialGraphSerializer(self, nodes=nodes).serialize()
|
||||
|
||||
def save(self, filepath=None, setupProjectFile=True, template=False):
|
||||
"""
|
||||
Save the current Meshroom graph as a serialized ".mg" file.
|
||||
|
@ -1374,34 +1304,7 @@ class Graph(BaseObject):
|
|||
if not path:
|
||||
raise ValueError("filepath must be specified for unsaved files.")
|
||||
|
||||
self.header[Graph.IO.Keys.ReleaseVersion] = meshroom.__version__
|
||||
self.header[Graph.IO.Keys.FileVersion] = Graph.IO.__version__
|
||||
|
||||
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
|
||||
# and remove duplicates
|
||||
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
|
||||
# Convert to node types to "name: version"
|
||||
nodesVersions = {
|
||||
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
|
||||
for p in usedNodeTypes
|
||||
}
|
||||
# Sort them by name (to avoid random order changing from one save to another)
|
||||
nodesVersions = dict(sorted(nodesVersions.items()))
|
||||
# Add it the header
|
||||
self.header[Graph.IO.Keys.NodesVersions] = nodesVersions
|
||||
self.header["template"] = template
|
||||
|
||||
data = {}
|
||||
if template:
|
||||
data = {
|
||||
Graph.IO.Keys.Header: self.header,
|
||||
Graph.IO.Keys.Graph: self.getNonDefaultInputAttributes()
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
Graph.IO.Keys.Header: self.header,
|
||||
Graph.IO.Keys.Graph: self.toDict()
|
||||
}
|
||||
data = self.serialize(template)
|
||||
|
||||
with open(path, 'w') as jsonFile:
|
||||
json.dump(data, jsonFile, indent=4)
|
||||
|
@ -1412,51 +1315,6 @@ class Graph(BaseObject):
|
|||
# update the file date version
|
||||
self._fileDateVersion = os.path.getmtime(path)
|
||||
|
||||
def getNonDefaultInputAttributes(self):
|
||||
"""
|
||||
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
||||
the attributes whose value is not the default one.
|
||||
The output attributes, UIDs, parallelization parameters and internal folder are
|
||||
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
||||
|
||||
Returns:
|
||||
dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder
|
||||
and input/internal attributes with default values removed
|
||||
"""
|
||||
graph = self.toDict()
|
||||
for nodeName in graph.keys():
|
||||
node = self.node(nodeName)
|
||||
|
||||
inputKeys = list(graph[nodeName]["inputs"].keys())
|
||||
|
||||
internalInputKeys = []
|
||||
internalInputs = graph[nodeName].get("internalInputs", None)
|
||||
if internalInputs:
|
||||
internalInputKeys = list(internalInputs.keys())
|
||||
|
||||
for attrName in inputKeys:
|
||||
attribute = node.attribute(attrName)
|
||||
# check that attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del graph[nodeName]["inputs"][attrName]
|
||||
|
||||
for attrName in internalInputKeys:
|
||||
attribute = node.internalAttribute(attrName)
|
||||
# check that internal attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del graph[nodeName]["internalInputs"][attrName]
|
||||
|
||||
# If all the internal attributes are set to their default values, remove the entry
|
||||
if len(graph[nodeName]["internalInputs"]) == 0:
|
||||
del graph[nodeName]["internalInputs"]
|
||||
|
||||
del graph[nodeName]["outputs"]
|
||||
del graph[nodeName]["uid"]
|
||||
del graph[nodeName]["internalFolder"]
|
||||
del graph[nodeName]["parallelization"]
|
||||
|
||||
return graph
|
||||
|
||||
def _setFilepath(self, filepath):
|
||||
"""
|
||||
Set the internal filepath of this Graph.
|
||||
|
@ -1615,11 +1473,6 @@ class Graph(BaseObject):
|
|||
def edges(self):
|
||||
return self._edges
|
||||
|
||||
@property
|
||||
def importedNodes(self):
|
||||
"""" Return the list of nodes that were added to the graph with the latest 'Import Project' action. """
|
||||
return self._importedNodes
|
||||
|
||||
@property
|
||||
def cacheDir(self):
|
||||
return self._cacheDir
|
||||
|
@ -1660,7 +1513,7 @@ class Graph(BaseObject):
|
|||
filepathChanged = Signal()
|
||||
filepath = Property(str, lambda self: self._filepath, notify=filepathChanged)
|
||||
isSaving = Property(bool, isSaving.fget, constant=True)
|
||||
fileReleaseVersion = Property(str, lambda self: self.header.get(Graph.IO.Keys.ReleaseVersion, "0.0"),
|
||||
fileReleaseVersion = Property(str, lambda self: self.header.get(GraphIO.Keys.ReleaseVersion, "0.0"),
|
||||
notify=filepathChanged)
|
||||
fileDateVersion = Property(float, fileDateVersion.fget, fileDateVersion.fset, notify=filepathChanged)
|
||||
cacheDirChanged = Signal()
|
||||
|
|
231
meshroom/core/graphIO.py
Normal file
231
meshroom/core/graphIO.py
Normal file
|
@ -0,0 +1,231 @@
|
|||
from enum import Enum
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
import meshroom
|
||||
from meshroom.core import Version
|
||||
from meshroom.core.attribute import Attribute, GroupAttribute, ListAttribute
|
||||
from meshroom.core.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from meshroom.core.graph import Graph
|
||||
|
||||
|
||||
class GraphIO:
|
||||
"""Centralize Graph file keys and IO version."""
|
||||
|
||||
__version__ = "2.0"
|
||||
|
||||
class Keys(object):
|
||||
"""File Keys."""
|
||||
|
||||
# Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value)
|
||||
Header = "header"
|
||||
NodesVersions = "nodesVersions"
|
||||
ReleaseVersion = "releaseVersion"
|
||||
FileVersion = "fileVersion"
|
||||
Graph = "graph"
|
||||
Template = "template"
|
||||
|
||||
class Features(Enum):
|
||||
"""File Features."""
|
||||
|
||||
Graph = "graph"
|
||||
Header = "header"
|
||||
NodesVersions = "nodesVersions"
|
||||
PrecomputedOutputs = "precomputedOutputs"
|
||||
NodesPositions = "nodesPositions"
|
||||
|
||||
@staticmethod
|
||||
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
|
||||
"""Return the list of supported features based on a file version.
|
||||
|
||||
Args:
|
||||
fileVersion (str, Version): the file version
|
||||
|
||||
Returns:
|
||||
tuple of GraphIO.Features: the list of supported features
|
||||
"""
|
||||
if isinstance(fileVersion, str):
|
||||
fileVersion = Version(fileVersion)
|
||||
|
||||
features = [GraphIO.Features.Graph]
|
||||
if fileVersion >= Version("1.0"):
|
||||
features += [
|
||||
GraphIO.Features.Header,
|
||||
GraphIO.Features.NodesVersions,
|
||||
GraphIO.Features.PrecomputedOutputs,
|
||||
]
|
||||
|
||||
if fileVersion >= Version("1.1"):
|
||||
features += [GraphIO.Features.NodesPositions]
|
||||
|
||||
return tuple(features)
|
||||
|
||||
|
||||
class GraphSerializer:
|
||||
"""Standard Graph serializer."""
|
||||
|
||||
def __init__(self, graph: "Graph") -> None:
|
||||
self._graph = graph
|
||||
|
||||
def serialize(self) -> dict:
|
||||
"""
|
||||
Serialize the Graph.
|
||||
"""
|
||||
return {
|
||||
GraphIO.Keys.Header: self.serializeHeader(),
|
||||
GraphIO.Keys.Graph: self.serializeContent(),
|
||||
}
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[Node]:
|
||||
return self._graph.nodes
|
||||
|
||||
def serializeHeader(self) -> dict:
|
||||
"""Build and return the graph serialization header.
|
||||
|
||||
The header contains metadata about the graph, such as the:
|
||||
- version of the software used to create it.
|
||||
- version of the file format.
|
||||
- version of the nodes types used in the graph.
|
||||
- template flag.
|
||||
"""
|
||||
header: dict[str, Any] = {}
|
||||
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
|
||||
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
|
||||
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
|
||||
return header
|
||||
|
||||
def _getNodeTypesVersions(self) -> dict[str, str]:
|
||||
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
|
||||
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
|
||||
nodeTypesVersions = {
|
||||
nodeType.__name__: version
|
||||
for nodeType in nodeTypes
|
||||
if (version := meshroom.core.nodeVersion(nodeType)) is not None
|
||||
}
|
||||
# Sort them by name (to avoid random order changing from one save to another).
|
||||
return dict(sorted(nodeTypesVersions.items()))
|
||||
|
||||
def serializeContent(self) -> dict:
|
||||
"""Graph content serialization logic."""
|
||||
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
|
||||
|
||||
def serializeNode(self, node: Node) -> dict:
|
||||
"""Node serialization logic."""
|
||||
return node.toDict()
|
||||
|
||||
|
||||
class TemplateGraphSerializer(GraphSerializer):
|
||||
"""Serializer for serializing a graph as a template."""
|
||||
|
||||
def serializeHeader(self) -> dict:
|
||||
header = super().serializeHeader()
|
||||
header[GraphIO.Keys.Template] = True
|
||||
return header
|
||||
|
||||
def serializeNode(self, node: Node) -> dict:
|
||||
"""Adapt node serialization to template graphs.
|
||||
|
||||
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
||||
the attributes whose value is not the default one.
|
||||
The output attributes, UIDs, parallelization parameters and internal folder are
|
||||
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
||||
"""
|
||||
# For now, implemented as a post-process to update the default serialization.
|
||||
nodeData = super().serializeNode(node)
|
||||
|
||||
inputKeys = list(nodeData["inputs"].keys())
|
||||
|
||||
internalInputKeys = []
|
||||
internalInputs = nodeData.get("internalInputs", None)
|
||||
if internalInputs:
|
||||
internalInputKeys = list(internalInputs.keys())
|
||||
|
||||
for attrName in inputKeys:
|
||||
attribute = node.attribute(attrName)
|
||||
# check that attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del nodeData["inputs"][attrName]
|
||||
|
||||
for attrName in internalInputKeys:
|
||||
attribute = node.internalAttribute(attrName)
|
||||
# check that internal attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del nodeData["internalInputs"][attrName]
|
||||
|
||||
# If all the internal attributes are set to their default values, remove the entry
|
||||
if len(nodeData["internalInputs"]) == 0:
|
||||
del nodeData["internalInputs"]
|
||||
|
||||
del nodeData["outputs"]
|
||||
del nodeData["uid"]
|
||||
del nodeData["internalFolder"]
|
||||
del nodeData["parallelization"]
|
||||
|
||||
return nodeData
|
||||
|
||||
|
||||
class PartialGraphSerializer(GraphSerializer):
|
||||
"""Serializer to serialize a partial graph (a subset of nodes)."""
|
||||
|
||||
def __init__(self, graph: "Graph", nodes: list[Node]):
|
||||
super().__init__(graph)
|
||||
self._nodes = nodes
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[Node]:
|
||||
"""Override to consider only the subset of nodes."""
|
||||
return self._nodes
|
||||
|
||||
def serializeNode(self, node: Node) -> dict:
|
||||
"""Adapt node serialization to partial graph serialization."""
|
||||
# NOTE: For now, implemented as a post-process to the default serialization.
|
||||
nodeData = super().serializeNode(node)
|
||||
|
||||
# Override input attributes with custom serialization logic, to handle attributes
|
||||
# connected to nodes that are not in the list of nodes to serialize.
|
||||
for attributeName in nodeData["inputs"]:
|
||||
nodeData["inputs"][attributeName] = self._serializeAttribute(node.attribute(attributeName))
|
||||
|
||||
# Clear UID for non-compatibility nodes, as the custom attribute serialization
|
||||
# can be impacting the UID by removing connections to missing nodes.
|
||||
if not node.isCompatibilityNode:
|
||||
del nodeData["uid"]
|
||||
|
||||
return nodeData
|
||||
|
||||
def _serializeAttribute(self, attribute: Attribute) -> Any:
|
||||
"""
|
||||
Serialize `attribute` (recursively for list/groups) and deal with attributes being connected
|
||||
to nodes that are not part of the partial list of nodes to serialize.
|
||||
"""
|
||||
linkParam = attribute.getLinkParam()
|
||||
|
||||
if linkParam is not None:
|
||||
# Use standard link serialization if upstream node is part of the serialization.
|
||||
if linkParam.node in self.nodes:
|
||||
return attribute.getExportValue()
|
||||
# Skip link serialization otherwise.
|
||||
# If part of a list, this entry can be discarded.
|
||||
if isinstance(attribute.root, ListAttribute):
|
||||
return None
|
||||
# Otherwise, return the default value for this attribute.
|
||||
return attribute.defaultValue()
|
||||
|
||||
if isinstance(attribute, ListAttribute):
|
||||
# Recusively serialize each child of the ListAttribute, skipping those for which the attribute
|
||||
# serialization logic above returns None.
|
||||
return [
|
||||
exportValue
|
||||
for child in attribute
|
||||
if (exportValue := self._serializeAttribute(child)) is not None
|
||||
]
|
||||
|
||||
if isinstance(attribute, GroupAttribute):
|
||||
# Recursively serialize each child of the group attribute.
|
||||
return {name: self._serializeAttribute(child) for name, child in attribute.value.items()}
|
||||
|
||||
return attribute.getExportValue()
|
||||
|
||||
|
|
@ -1608,7 +1608,8 @@ class CompatibilityNode(BaseNode):
|
|||
# Make a deepcopy of nodeDict to handle CompatibilityNode duplication
|
||||
# and be able to change modified inputs (see CompatibilityNode.toDict)
|
||||
self.nodeDict = copy.deepcopy(nodeDict)
|
||||
self.version = Version(self.nodeDict.get("version", None))
|
||||
version = self.nodeDict.get("version")
|
||||
self.version = Version(version) if version else None
|
||||
|
||||
self._inputs = self.nodeDict.get("inputs", {})
|
||||
self._internalInputs = self.nodeDict.get("internalInputs", {})
|
||||
|
@ -1668,7 +1669,17 @@ class CompatibilityNode(BaseNode):
|
|||
elif isinstance(value, float):
|
||||
return desc.FloatParam(range=None, **params)
|
||||
elif isinstance(value, str):
|
||||
if isOutput or os.path.isabs(value) or Attribute.isLinkExpression(value):
|
||||
if isOutput or os.path.isabs(value):
|
||||
return desc.File(**params)
|
||||
elif Attribute.isLinkExpression(value):
|
||||
# Do not consider link expression as a valid default desc value.
|
||||
# When the link expression is applied and transformed to an actual link,
|
||||
# the systems resets the value using `Attribute.resetToDefaultValue` to indicate
|
||||
# that this link expression has been handled.
|
||||
# If the link expression is stored as the default value, it will never be cleared,
|
||||
# leading to unexpected behavior where the link expression on a CompatibilityNode
|
||||
# could be evaluated several times and/or incorrectly.
|
||||
params["value"] = ""
|
||||
return desc.File(**params)
|
||||
else:
|
||||
return desc.StringParam(**params)
|
||||
|
@ -1851,113 +1862,3 @@ class CompatibilityNode(BaseNode):
|
|||
canUpgrade = Property(bool, canUpgrade.fget, constant=True)
|
||||
issueDetails = Property(str, issueDetails.fget, constant=True)
|
||||
|
||||
|
||||
def nodeFactory(nodeDict, name=None, template=False, uidConflict=False):
|
||||
"""
|
||||
Create a node instance by deserializing the given node data.
|
||||
If the serialized data matches the corresponding node type description, a Node instance is created.
|
||||
If any compatibility issue occurs, a NodeCompatibility instance is created instead.
|
||||
|
||||
Args:
|
||||
nodeDict (dict): the serialization of the node
|
||||
name (str): (optional) the node's name
|
||||
template (bool): (optional) true if the node is part of a template, false otherwise
|
||||
uidConflict (bool): (optional) true if a UID conflict has been detected externally on that node
|
||||
|
||||
Returns:
|
||||
BaseNode: the created node
|
||||
"""
|
||||
nodeType = nodeDict["nodeType"]
|
||||
|
||||
# Retro-compatibility: inputs were previously saved as "attributes"
|
||||
if "inputs" not in nodeDict and "attributes" in nodeDict:
|
||||
nodeDict["inputs"] = nodeDict["attributes"]
|
||||
del nodeDict["attributes"]
|
||||
|
||||
# Get node inputs/outputs
|
||||
inputs = nodeDict.get("inputs", {})
|
||||
internalInputs = nodeDict.get("internalInputs", {})
|
||||
outputs = nodeDict.get("outputs", {})
|
||||
version = nodeDict.get("version", None)
|
||||
internalFolder = nodeDict.get("internalFolder", None)
|
||||
position = Position(*nodeDict.get("position", []))
|
||||
uid = nodeDict.get("uid", None)
|
||||
|
||||
compatibilityIssue = None
|
||||
|
||||
nodeDesc = None
|
||||
try:
|
||||
nodeDesc = meshroom.core.nodesDesc[nodeType]
|
||||
except KeyError:
|
||||
# Unknown node type
|
||||
compatibilityIssue = CompatibilityIssue.UnknownNodeType
|
||||
|
||||
# Unknown node type should take precedence over UID conflict, as it cannot be resolved
|
||||
if uidConflict and nodeDesc:
|
||||
compatibilityIssue = CompatibilityIssue.UidConflict
|
||||
|
||||
if nodeDesc and not uidConflict: # if uidConflict, there is no need to look for another compatibility issue
|
||||
# Compare serialized node version with current node version
|
||||
currentNodeVersion = meshroom.core.nodeVersion(nodeDesc)
|
||||
# If both versions are available, check for incompatibility in major version
|
||||
if version and currentNodeVersion and Version(version).major != Version(currentNodeVersion).major:
|
||||
compatibilityIssue = CompatibilityIssue.VersionConflict
|
||||
# In other cases, check attributes compatibility between serialized node and its description
|
||||
else:
|
||||
# Check that the node has the exact same set of inputs/outputs as its description, except
|
||||
# if the node is described in a template file, in which only non-default parameters are saved;
|
||||
# do not perform that check for internal attributes because there is no point in
|
||||
# raising compatibility issues if their number differs: in that case, it is only useful
|
||||
# if some internal attributes do not exist or are invalid
|
||||
if not template and (sorted([attr.name for attr in nodeDesc.inputs
|
||||
if not isinstance(attr, desc.PushButtonParam)]) != sorted(inputs.keys()) or
|
||||
sorted([attr.name for attr in nodeDesc.outputs if not attr.isDynamicValue]) !=
|
||||
sorted(outputs.keys())):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
|
||||
# Check whether there are any internal attributes that are invalidating in the node description: if there
|
||||
# are, then check that these internal attributes are part of nodeDict; if they are not, a compatibility
|
||||
# issue must be raised to warn the user, as this will automatically change the node's UID
|
||||
if not template:
|
||||
invalidatingIntInputs = []
|
||||
for attr in nodeDesc.internalInputs:
|
||||
if attr.invalidate:
|
||||
invalidatingIntInputs.append(attr.name)
|
||||
for attr in invalidatingIntInputs:
|
||||
if attr not in internalInputs.keys():
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
|
||||
# Verify that all inputs match their descriptions
|
||||
for attrName, value in inputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.inputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
# Verify that all internal inputs match their description
|
||||
for attrName, value in internalInputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.internalInputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
# Verify that all outputs match their descriptions
|
||||
for attrName, value in outputs.items():
|
||||
if not CompatibilityNode.attributeDescFromName(nodeDesc.outputs, attrName, value):
|
||||
compatibilityIssue = CompatibilityIssue.DescriptionConflict
|
||||
break
|
||||
|
||||
if compatibilityIssue is None:
|
||||
node = Node(nodeType, position, uid=uid, **inputs, **internalInputs, **outputs)
|
||||
else:
|
||||
logging.debug("Compatibility issue detected for node '{}': {}".format(name, compatibilityIssue.name))
|
||||
node = CompatibilityNode(nodeType, nodeDict, position, compatibilityIssue)
|
||||
# Retro-compatibility: no internal folder saved
|
||||
# can't spawn meaningful CompatibilityNode with precomputed outputs
|
||||
# => automatically try to perform node upgrade
|
||||
if not internalFolder and nodeDesc:
|
||||
logging.warning("No serialized output data: performing automatic upgrade on '{}'".format(name))
|
||||
node = node.upgrade()
|
||||
# If the node comes from a template file and there is a conflict, it should be upgraded anyway unless it is
|
||||
# an "unknown node type" conflict (in which case the upgrade would fail)
|
||||
elif template and compatibilityIssue is not CompatibilityIssue.UnknownNodeType:
|
||||
node = node.upgrade()
|
||||
|
||||
return node
|
||||
|
|
201
meshroom/core/nodeFactory.py
Normal file
201
meshroom/core/nodeFactory.py
Normal file
|
@ -0,0 +1,201 @@
|
|||
import logging
|
||||
from typing import Any, Iterable, Optional, Union
|
||||
|
||||
import meshroom.core
|
||||
from meshroom.core import Version, desc
|
||||
from meshroom.core.node import CompatibilityIssue, CompatibilityNode, Node, Position
|
||||
|
||||
|
||||
def nodeFactory(
|
||||
nodeData: dict,
|
||||
name: Optional[str] = None,
|
||||
inTemplate: bool = False,
|
||||
expectedUid: Optional[str] = None,
|
||||
) -> Union[Node, CompatibilityNode]:
|
||||
"""
|
||||
Create a node instance by deserializing the given node data.
|
||||
If the serialized data matches the corresponding node type description, a Node instance is created.
|
||||
If any compatibility issue occurs, a NodeCompatibility instance is created instead.
|
||||
|
||||
Args:
|
||||
nodeData: The serialized Node data.
|
||||
name: The node's name.
|
||||
inTemplate: True if the node is created as part of a graph template.
|
||||
expectedUid: The expected UID of the node within the context of a Graph.
|
||||
|
||||
Returns:
|
||||
The created Node instance.
|
||||
"""
|
||||
return _NodeCreator(nodeData, name, inTemplate, expectedUid).create()
|
||||
|
||||
|
||||
class _NodeCreator:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
nodeData: dict,
|
||||
name: Optional[str] = None,
|
||||
inTemplate: bool = False,
|
||||
expectedUid: Optional[str] = None,
|
||||
):
|
||||
self.nodeData = nodeData
|
||||
self.name = name
|
||||
self.inTemplate = inTemplate
|
||||
self.expectedUid = expectedUid
|
||||
|
||||
self._normalizeNodeData()
|
||||
|
||||
self.nodeType = self.nodeData["nodeType"]
|
||||
self.inputs = self.nodeData.get("inputs", {})
|
||||
self.internalInputs = self.nodeData.get("internalInputs", {})
|
||||
self.outputs = self.nodeData.get("outputs", {})
|
||||
self.version = self.nodeData.get("version", None)
|
||||
self.internalFolder = self.nodeData.get("internalFolder")
|
||||
self.position = Position(*self.nodeData.get("position", []))
|
||||
self.uid = self.nodeData.get("uid", None)
|
||||
self.nodeDesc = meshroom.core.nodesDesc.get(self.nodeType, None)
|
||||
|
||||
def create(self) -> Union[Node, CompatibilityNode]:
|
||||
compatibilityIssue = self._checkCompatibilityIssues()
|
||||
if compatibilityIssue:
|
||||
node = self._createCompatibilityNode(compatibilityIssue)
|
||||
node = self._tryUpgradeCompatibilityNode(node)
|
||||
else:
|
||||
node = self._createNode()
|
||||
return node
|
||||
|
||||
def _normalizeNodeData(self):
|
||||
"""Consistency fixes for backward compatibility with older serialized data."""
|
||||
# Inputs were previously saved as "attributes".
|
||||
if "inputs" not in self.nodeData and "attributes" in self.nodeData:
|
||||
self.nodeData["inputs"] = self.nodeData["attributes"]
|
||||
del self.nodeData["attributes"]
|
||||
|
||||
def _checkCompatibilityIssues(self) -> Optional[CompatibilityIssue]:
|
||||
if self.nodeDesc is None:
|
||||
return CompatibilityIssue.UnknownNodeType
|
||||
|
||||
if not self._checkUidCompatibility():
|
||||
return CompatibilityIssue.UidConflict
|
||||
|
||||
if not self._checkVersionCompatibility():
|
||||
return CompatibilityIssue.VersionConflict
|
||||
|
||||
if not self._checkDescriptionCompatibility():
|
||||
return CompatibilityIssue.DescriptionConflict
|
||||
|
||||
return None
|
||||
|
||||
def _checkUidCompatibility(self) -> bool:
|
||||
return self.expectedUid is None or self.expectedUid == self.uid
|
||||
|
||||
def _checkVersionCompatibility(self) -> bool:
|
||||
# Special case: a node with a version set to None indicates
|
||||
# that it has been created from the current version of the node type.
|
||||
nodeCreatedFromCurrentVersion = self.version is None
|
||||
if nodeCreatedFromCurrentVersion:
|
||||
return True
|
||||
nodeTypeCurrentVersion = meshroom.core.nodeVersion(self.nodeDesc)
|
||||
# If the node type has not current version information, assume compatibility.
|
||||
if nodeTypeCurrentVersion is None:
|
||||
return True
|
||||
return Version(self.version).major == Version(nodeTypeCurrentVersion).major
|
||||
|
||||
def _checkDescriptionCompatibility(self) -> bool:
|
||||
# Only perform strict attribute name matching for non-template graphs,
|
||||
# since only non-default-value input attributes are serialized in templates.
|
||||
if not self.inTemplate:
|
||||
if not self._checkAttributesNamesMatchDescription():
|
||||
return False
|
||||
|
||||
return self._checkAttributesAreCompatibleWithDescription()
|
||||
|
||||
def _checkAttributesNamesMatchDescription(self) -> bool:
|
||||
return (
|
||||
self._checkInputAttributesNames()
|
||||
and self._checkOutputAttributesNames()
|
||||
and self._checkInternalAttributesNames()
|
||||
)
|
||||
|
||||
def _checkAttributesAreCompatibleWithDescription(self) -> bool:
|
||||
return (
|
||||
self._checkAttributesCompatibility(self.nodeDesc.inputs, self.inputs)
|
||||
and self._checkAttributesCompatibility(self.nodeDesc.internalInputs, self.internalInputs)
|
||||
and self._checkAttributesCompatibility(self.nodeDesc.outputs, self.outputs)
|
||||
)
|
||||
|
||||
def _checkInputAttributesNames(self) -> bool:
|
||||
def serializedInput(attr: desc.Attribute) -> bool:
|
||||
"""Filter that excludes not-serialized desc input attributes."""
|
||||
if isinstance(attr, desc.PushButtonParam):
|
||||
# PushButtonParam are not serialized has they do not hold a value.
|
||||
return False
|
||||
return True
|
||||
|
||||
refAttributes = filter(serializedInput, self.nodeDesc.inputs)
|
||||
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.inputs)
|
||||
|
||||
def _checkOutputAttributesNames(self) -> bool:
|
||||
def serializedOutput(attr: desc.Attribute) -> bool:
|
||||
"""Filter that excludes not-serialized desc output attributes."""
|
||||
if attr.isDynamicValue:
|
||||
# Dynamic outputs values are not serialized with the node,
|
||||
# as their value is written in the computed output data.
|
||||
return False
|
||||
return True
|
||||
|
||||
refAttributes = filter(serializedOutput, self.nodeDesc.outputs)
|
||||
return self._checkAttributesNamesStrictlyMatch(refAttributes, self.outputs)
|
||||
|
||||
def _checkInternalAttributesNames(self) -> bool:
|
||||
invalidatingDescAttributes = [attr.name for attr in self.nodeDesc.internalInputs if attr.invalidate]
|
||||
return all(attr in self.internalInputs.keys() for attr in invalidatingDescAttributes)
|
||||
|
||||
def _checkAttributesNamesStrictlyMatch(
|
||||
self, descAttributes: Iterable[desc.Attribute], attributesDict: dict[str, Any]
|
||||
) -> bool:
|
||||
refNames = sorted([attr.name for attr in descAttributes])
|
||||
attrNames = sorted(attributesDict.keys())
|
||||
return refNames == attrNames
|
||||
|
||||
def _checkAttributesCompatibility(
|
||||
self, descAttributes: list[desc.Attribute], attributesDict: dict[str, Any]
|
||||
) -> bool:
|
||||
return all(
|
||||
CompatibilityNode.attributeDescFromName(descAttributes, attrName, value) is not None
|
||||
for attrName, value in attributesDict.items()
|
||||
)
|
||||
|
||||
def _createNode(self) -> Node:
|
||||
logging.info(f"Creating node '{self.name}'")
|
||||
return Node(
|
||||
self.nodeType,
|
||||
position=self.position,
|
||||
uid=self.uid,
|
||||
**self.inputs,
|
||||
**self.internalInputs,
|
||||
**self.outputs,
|
||||
)
|
||||
|
||||
def _createCompatibilityNode(self, compatibilityIssue) -> CompatibilityNode:
|
||||
logging.warning(f"Compatibility issue detected for node '{self.name}': {compatibilityIssue.name}")
|
||||
return CompatibilityNode(
|
||||
self.nodeType, self.nodeData, position=self.position, issue=compatibilityIssue
|
||||
)
|
||||
|
||||
def _tryUpgradeCompatibilityNode(self, node: CompatibilityNode) -> Union[Node, CompatibilityNode]:
|
||||
"""Handle possible upgrades of CompatibilityNodes, when no computed data is associated to the Node."""
|
||||
if node.issue == CompatibilityIssue.UnknownNodeType:
|
||||
return node
|
||||
|
||||
# Nodes in templates are not meant to hold computation data.
|
||||
if self.inTemplate:
|
||||
logging.warning(f"Compatibility issue in template: performing automatic upgrade on '{self.name}'")
|
||||
return node.upgrade()
|
||||
|
||||
# Backward compatibility: "internalFolder" was not serialized.
|
||||
if not self.internalFolder:
|
||||
logging.warning(f"No serialized output data: performing automatic upgrade on '{self.name}'")
|
||||
return node.upgrade()
|
||||
|
||||
return node
|
8
meshroom/core/typing.py
Normal file
8
meshroom/core/typing.py
Normal file
|
@ -0,0 +1,8 @@
|
|||
"""
|
||||
Common typing aliases used in Meshroom.
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
PathLike = Union[Path, str]
|
|
@ -6,8 +6,10 @@ from PySide6.QtGui import QUndoCommand, QUndoStack
|
|||
from PySide6.QtCore import Property, Signal
|
||||
|
||||
from meshroom.core.attribute import ListAttribute, Attribute
|
||||
from meshroom.core.graph import GraphModification
|
||||
from meshroom.core.node import nodeFactory, Position
|
||||
from meshroom.core.graph import Graph, GraphModification
|
||||
from meshroom.core.node import Position, CompatibilityIssue
|
||||
from meshroom.core.nodeFactory import nodeFactory
|
||||
from meshroom.core.typing import PathLike
|
||||
|
||||
|
||||
class UndoCommand(QUndoCommand):
|
||||
|
@ -168,19 +170,7 @@ class RemoveNodeCommand(GraphCommand):
|
|||
node = nodeFactory(self.nodeDict, self.nodeName)
|
||||
self.graph.addNode(node, self.nodeName)
|
||||
assert (node.getName() == self.nodeName)
|
||||
# recreate out edges deleted on node removal
|
||||
for dstAttr, srcAttr in self.outEdges.items():
|
||||
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dstAttr in self.outListAttributes.keys():
|
||||
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
|
||||
if isinstance(self.outListAttributes[dstAttr][2], list):
|
||||
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
|
||||
else:
|
||||
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
|
||||
|
||||
self.graph.addEdge(self.graph.attribute(srcAttr),
|
||||
self.graph.attribute(dstAttr))
|
||||
self.graph._restoreOutEdges(self.outEdges, self.outListAttributes)
|
||||
|
||||
|
||||
class DuplicateNodesCommand(GraphCommand):
|
||||
|
@ -209,15 +199,27 @@ class PasteNodesCommand(GraphCommand):
|
|||
"""
|
||||
Handle node pasting in a Graph.
|
||||
"""
|
||||
def __init__(self, graph, data, position=None, parent=None):
|
||||
def __init__(self, graph: "Graph", data: dict, position: Position, parent=None):
|
||||
super(PasteNodesCommand, self).__init__(graph, parent)
|
||||
self.data = data
|
||||
self.position = position
|
||||
self.nodeNames = []
|
||||
self.nodeNames: list[str] = []
|
||||
|
||||
def redoImpl(self):
|
||||
data = self.graph.updateImportedProject(self.data)
|
||||
nodes = self.graph.pasteNodes(data, self.position)
|
||||
graph = Graph("")
|
||||
try:
|
||||
graph._deserialize(self.data)
|
||||
except:
|
||||
return False
|
||||
|
||||
boundingBoxCenter = self._boundingBoxCenter(graph.nodes)
|
||||
offset = Position(self.position.x - boundingBoxCenter.x, self.position.y - boundingBoxCenter.y)
|
||||
|
||||
for node in graph.nodes:
|
||||
node.position = Position(node.position.x + offset.x, node.position.y + offset.y)
|
||||
|
||||
nodes = self.graph.importGraphContent(graph)
|
||||
|
||||
self.nodeNames = [node.name for node in nodes]
|
||||
self.setText("Paste Node{} ({})".format("s" if len(self.nodeNames) > 1 else "", ", ".join(self.nodeNames)))
|
||||
return nodes
|
||||
|
@ -226,12 +228,31 @@ class PasteNodesCommand(GraphCommand):
|
|||
for name in self.nodeNames:
|
||||
self.graph.removeNode(name)
|
||||
|
||||
def _boundingBox(self, nodes) -> tuple[int, int, int, int]:
|
||||
if not nodes:
|
||||
return (0, 0, 0 , 0)
|
||||
|
||||
minX = maxX = nodes[0].x
|
||||
minY = maxY = nodes[0].y
|
||||
|
||||
for node in nodes[1:]:
|
||||
minX = min(minX, node.x)
|
||||
minY = min(minY, node.y)
|
||||
maxX = max(maxX, node.x)
|
||||
maxY = max(maxY, node.y)
|
||||
|
||||
return (minX, minY, maxX, maxY)
|
||||
|
||||
def _boundingBoxCenter(self, nodes):
|
||||
minX, minY, maxX, maxY = self._boundingBox(nodes)
|
||||
return Position((minX + maxX) / 2, (minY + maxY) / 2)
|
||||
|
||||
class ImportProjectCommand(GraphCommand):
|
||||
"""
|
||||
Handle the import of a project into a Graph.
|
||||
"""
|
||||
def __init__(self, graph, filepath=None, position=None, yOffset=0, parent=None):
|
||||
|
||||
def __init__(self, graph: Graph, filepath: PathLike, position=None, yOffset=0, parent=None):
|
||||
super(ImportProjectCommand, self).__init__(graph, parent)
|
||||
self.filepath = filepath
|
||||
self.importedNames = []
|
||||
|
@ -239,9 +260,8 @@ class ImportProjectCommand(GraphCommand):
|
|||
self.yOffset = yOffset
|
||||
|
||||
def redoImpl(self):
|
||||
status = self.graph.load(self.filepath, setupProjectFile=False, importProject=True)
|
||||
importedNodes = self.graph.importedNodes
|
||||
self.setText("Import Project ({} nodes)".format(importedNodes.count))
|
||||
importedNodes = self.graph.importGraphContentFromFile(self.filepath)
|
||||
self.setText(f"Import Project ({len(importedNodes)} nodes)")
|
||||
|
||||
lowestY = 0
|
||||
for node in self.graph.nodes:
|
||||
|
@ -419,37 +439,24 @@ class UpgradeNodeCommand(GraphCommand):
|
|||
super(UpgradeNodeCommand, self).__init__(graph, parent)
|
||||
self.nodeDict = node.toDict()
|
||||
self.nodeName = node.getName()
|
||||
self.outEdges = {}
|
||||
self.outListAttributes = {}
|
||||
self.compatibilityIssue = None
|
||||
self.setText("Upgrade Node {}".format(self.nodeName))
|
||||
|
||||
def redoImpl(self):
|
||||
if not self.graph.node(self.nodeName).canUpgrade:
|
||||
if not (node := self.graph.node(self.nodeName)).canUpgrade:
|
||||
return False
|
||||
upgradedNode, _, self.outEdges, self.outListAttributes = self.graph.upgradeNode(self.nodeName)
|
||||
return upgradedNode
|
||||
self.compatibilityIssue = node.issue
|
||||
return self.graph.upgradeNode(self.nodeName)
|
||||
|
||||
def undoImpl(self):
|
||||
# delete upgraded node
|
||||
self.graph.removeNode(self.nodeName)
|
||||
expectedUid = None
|
||||
if self.compatibilityIssue == CompatibilityIssue.UidConflict:
|
||||
expectedUid = self.graph.node(self.nodeName)._uid
|
||||
|
||||
# recreate compatibility node
|
||||
with GraphModification(self.graph):
|
||||
# We come back from an upgrade, so we enforce uidConflict=True as there was a uid conflict before
|
||||
node = nodeFactory(self.nodeDict, name=self.nodeName, uidConflict=True)
|
||||
self.graph.addNode(node, self.nodeName)
|
||||
# recreate out edges
|
||||
for dstAttr, srcAttr in self.outEdges.items():
|
||||
# if edges were connected to ListAttributes, recreate their corresponding entry in said ListAttribute
|
||||
# 0 = attribute name, 1 = attribute index, 2 = attribute value
|
||||
if dstAttr in self.outListAttributes.keys():
|
||||
listAttr = self.graph.attribute(self.outListAttributes[dstAttr][0])
|
||||
if isinstance(self.outListAttributes[dstAttr][2], list):
|
||||
listAttr[self.outListAttributes[dstAttr][1]:self.outListAttributes[dstAttr][1]] = self.outListAttributes[dstAttr][2]
|
||||
else:
|
||||
listAttr.insert(self.outListAttributes[dstAttr][1], self.outListAttributes[dstAttr][2])
|
||||
|
||||
self.graph.addEdge(self.graph.attribute(srcAttr),
|
||||
self.graph.attribute(dstAttr))
|
||||
node = nodeFactory(self.nodeDict, name=self.nodeName, expectedUid=expectedUid)
|
||||
self.graph.replaceNode(self.nodeName, node)
|
||||
|
||||
|
||||
class EnableGraphUpdateCommand(GraphCommand):
|
||||
|
|
|
@ -25,6 +25,7 @@ from meshroom.core import sessionUid
|
|||
from meshroom.common.qt import QObjectListModel
|
||||
from meshroom.core.attribute import Attribute, ListAttribute
|
||||
from meshroom.core.graph import Graph, Edge
|
||||
from meshroom.core.graphIO import GraphIO
|
||||
|
||||
from meshroom.core.taskManager import TaskManager
|
||||
|
||||
|
@ -396,7 +397,7 @@ class UIGraph(QObject):
|
|||
self.updateChunks()
|
||||
|
||||
# perform auto-layout if graph does not provide nodes positions
|
||||
if Graph.IO.Features.NodesPositions not in self._graph.fileFeatures:
|
||||
if GraphIO.Features.NodesPositions not in self._graph.fileFeatures:
|
||||
self._layout.reset()
|
||||
# clear undo-stack after layout
|
||||
self._undoStack.clear()
|
||||
|
@ -451,17 +452,21 @@ class UIGraph(QObject):
|
|||
self.stopExecution()
|
||||
self._chunksMonitor.stop()
|
||||
|
||||
@Slot(str, result=bool)
|
||||
def loadGraph(self, filepath, setupProjectFile=True, publishOutputs=False):
|
||||
g = Graph('')
|
||||
status = True
|
||||
@Slot(str)
|
||||
def loadGraph(self, filepath):
|
||||
g = Graph("")
|
||||
if filepath:
|
||||
status = g.load(filepath, setupProjectFile, importProject=False, publishOutputs=publishOutputs)
|
||||
g.load(filepath)
|
||||
if not os.path.exists(g.cacheDir):
|
||||
os.mkdir(g.cacheDir)
|
||||
g.fileDateVersion = os.path.getmtime(filepath)
|
||||
self.setGraph(g)
|
||||
return status
|
||||
|
||||
@Slot(str, bool, result=bool)
|
||||
def initFromTemplate(self, filepath, publishOutputs=False):
|
||||
graph = Graph("")
|
||||
if filepath:
|
||||
graph.initFromTemplate(filepath, publishOutputs=publishOutputs)
|
||||
self.setGraph(graph)
|
||||
|
||||
@Slot(QUrl, result="QVariantList")
|
||||
@Slot(QUrl, QPoint, result="QVariantList")
|
||||
|
@ -1045,126 +1050,43 @@ class UIGraph(QObject):
|
|||
"""
|
||||
if not self._nodeSelection.hasSelection():
|
||||
return ""
|
||||
serializedSelection = {node.name: node.toDict() for node in self.iterSelectedNodes()}
|
||||
return json.dumps(serializedSelection, indent=4)
|
||||
graphData = self._graph.serializePartial(self.getSelectedNodes())
|
||||
return json.dumps(graphData, indent=4)
|
||||
|
||||
@Slot(str, QPoint, bool, result=list)
|
||||
def pasteNodes(self, clipboardContent, position=None, centerPosition=False) -> list[Node]:
|
||||
@Slot(str, QPoint, result=list)
|
||||
def pasteNodes(self, serializedData: str, position: Optional[QPoint]=None) -> list[Node]:
|
||||
"""
|
||||
Parse the content of the clipboard to see whether it contains
|
||||
valid node descriptions. If that is the case, the nodes described
|
||||
in the clipboard are built with the available information.
|
||||
Otherwise, nothing is done.
|
||||
Import string-serialized graph content `serializedData` in the current graph, optionally at the given
|
||||
`position`.
|
||||
If the `serializedData` does not contain valid serialized graph data, nothing is done.
|
||||
|
||||
This function does not need to be preceded by a call to "getSelectedNodesContent".
|
||||
Any clipboard content that contains at least a node type with a valid JSON
|
||||
formatting (dictionary form with double quotes around the keys and values)
|
||||
can be used to generate a node.
|
||||
This method can be used with the result of "getSelectedNodesContent".
|
||||
But it also accepts any serialized content that matches the graph data or graph content format.
|
||||
|
||||
For example, it is enough to have:
|
||||
{"nodeName_1": {"nodeType":"CameraInit"}, "nodeName_2": {"nodeType":"FeatureMatching"}}
|
||||
in the clipboard to create a default CameraInit and a default FeatureMatching nodes.
|
||||
in `serializedData` to create a default CameraInit and a default FeatureMatching nodes.
|
||||
|
||||
Args:
|
||||
clipboardContent (str): the string contained in the clipboard, that may or may not contain valid
|
||||
node information
|
||||
position (QPoint): the position of the mouse in the Graph Editor when the function was called
|
||||
centerPosition (bool): whether the provided position is not the top-left corner of the pasting
|
||||
zone, but its center
|
||||
serializedData: The string-serialized graph data.
|
||||
position: The position where to paste the nodes. If None, the nodes are pasted at (0, 0).
|
||||
|
||||
Returns:
|
||||
list: the list of Node objects that were pasted and added to the graph
|
||||
"""
|
||||
if not clipboardContent:
|
||||
return
|
||||
|
||||
try:
|
||||
d = json.loads(clipboardContent)
|
||||
except ValueError as e:
|
||||
raise ValueError(e)
|
||||
graphData = json.loads(serializedData)
|
||||
except json.JSONDecodeError:
|
||||
logging.warning("Content is not a valid JSON string.")
|
||||
return []
|
||||
|
||||
if not isinstance(d, dict):
|
||||
raise ValueError("The clipboard does not contain a valid node. Cannot paste it.")
|
||||
pos = Position(position.x(), position.y()) if position else Position(0, 0)
|
||||
result = self.push(commands.PasteNodesCommand(self._graph, graphData, pos))
|
||||
if result is False:
|
||||
logging.warning("Content is not a valid graph data.")
|
||||
return []
|
||||
return result
|
||||
|
||||
# If the clipboard contains a header, then a whole file is contained in the clipboard
|
||||
# Extract the "graph" part and paste it all, ignore the rest
|
||||
if d.get("header", None):
|
||||
d = d.get("graph", None)
|
||||
if not d:
|
||||
return
|
||||
|
||||
if isinstance(position, QPoint):
|
||||
position = Position(position.x(), position.y())
|
||||
if self.hoveredNode:
|
||||
# If a node is hovered, add an offset to prevent complete occlusion
|
||||
position = Position(position.x + self.layout.gridSpacing, position.y + self.layout.gridSpacing)
|
||||
|
||||
# Get the position of the first node in a zone whose top-left corner is the mouse and the bottom-right
|
||||
# corner the (x, y) coordinates, with x the maximum of all the nodes' position along the x-axis, and y the
|
||||
# maximum of all the nodes' position along the y-axis. All nodes with a position will be placed relatively
|
||||
# to the first node within that zone.
|
||||
firstNodePos = None
|
||||
minX = 0
|
||||
maxX = 0
|
||||
minY = 0
|
||||
maxY = 0
|
||||
for key in sorted(d):
|
||||
nodeType = d[key].get("nodeType", None)
|
||||
if not nodeType:
|
||||
raise ValueError("Invalid node description: no provided node type for '{}'".format(key))
|
||||
|
||||
pos = d[key].get("position", None)
|
||||
if pos:
|
||||
if not firstNodePos:
|
||||
firstNodePos = pos
|
||||
minX = pos[0]
|
||||
maxX = pos[0]
|
||||
minY = pos[1]
|
||||
maxY = pos[1]
|
||||
else:
|
||||
if minX > pos[0]:
|
||||
minX = pos[0]
|
||||
if maxX < pos[0]:
|
||||
maxX = pos[0]
|
||||
if minY > pos[1]:
|
||||
minY = pos[1]
|
||||
if maxY < pos[1]:
|
||||
maxY = pos[1]
|
||||
|
||||
# Ensure there will not be an error if no node has a specified position
|
||||
if not firstNodePos:
|
||||
firstNodePos = [0, 0]
|
||||
|
||||
# Position of the first node within the zone
|
||||
position = Position(position.x + firstNodePos[0] - minX, position.y + firstNodePos[1] - minY)
|
||||
|
||||
if centerPosition: # Center the zone around the mouse's position (mouse's position might be artificial)
|
||||
maxX = maxX + self.layout.nodeWidth # maxX and maxY are the position of the furthest node's top-left corner
|
||||
maxY = maxY + self.layout.nodeHeight # We want the position of the furthest node's bottom-right corner
|
||||
position = Position(position.x - ((maxX - minX) / 2), position.y - ((maxY - minY) / 2))
|
||||
|
||||
finalPosition = None
|
||||
prevPosition = None
|
||||
positions = []
|
||||
|
||||
for key in sorted(d):
|
||||
currentPosition = d[key].get("position", None)
|
||||
if not finalPosition:
|
||||
finalPosition = position
|
||||
else:
|
||||
if prevPosition and currentPosition:
|
||||
# If the nodes both have a position, recreate the distance between them with a different
|
||||
# starting point
|
||||
x = finalPosition.x + (currentPosition[0] - prevPosition[0])
|
||||
y = finalPosition.y + (currentPosition[1] - prevPosition[1])
|
||||
finalPosition = Position(x, y)
|
||||
else:
|
||||
# If either the current node or previous one lacks a position, use a custom one
|
||||
finalPosition = Position(finalPosition.x + self.layout.gridSpacing + self.layout.nodeWidth, finalPosition.y)
|
||||
prevPosition = currentPosition
|
||||
positions.append(finalPosition)
|
||||
|
||||
return self.push(commands.PasteNodesCommand(self.graph, d, position=positions))
|
||||
|
||||
undoStack = Property(QObject, lambda self: self._undoStack, constant=True)
|
||||
graphChanged = Signal()
|
||||
|
|
|
@ -185,7 +185,7 @@ Page {
|
|||
nameFilters: ["Meshroom Graphs (*.mg)"]
|
||||
onAccepted: {
|
||||
// Open the template as a regular file
|
||||
if (_reconstruction.loadUrl(currentFile, true, true)) {
|
||||
if (_reconstruction.load(currentFile)) {
|
||||
MeshroomApp.addRecentProjectFile(currentFile.toString())
|
||||
}
|
||||
}
|
||||
|
@ -400,7 +400,7 @@ Page {
|
|||
text: "Reload File"
|
||||
|
||||
onClicked: {
|
||||
_reconstruction.loadUrl(_reconstruction.graph.filepath)
|
||||
_reconstruction.load(_reconstruction.graph.filepath)
|
||||
fileModifiedDialog.close()
|
||||
}
|
||||
}
|
||||
|
@ -705,7 +705,7 @@ Page {
|
|||
MenuItem {
|
||||
onTriggered: ensureSaved(function() {
|
||||
openRecentMenu.dismiss()
|
||||
if (_reconstruction.loadUrl(modelData["path"])) {
|
||||
if (_reconstruction.load(modelData["path"])) {
|
||||
MeshroomApp.addRecentProjectFile(modelData["path"])
|
||||
} else {
|
||||
MeshroomApp.removeRecentProjectFile(modelData["path"])
|
||||
|
|
|
@ -82,25 +82,18 @@ Item {
|
|||
|
||||
/// Paste content of clipboard to graph editor and create new node if valid
|
||||
function pasteNodes() {
|
||||
var finalPosition = undefined
|
||||
var centerPosition = false
|
||||
let finalPosition = undefined;
|
||||
if (mouseArea.containsMouse) {
|
||||
if (uigraph.hoveredNode !== null) {
|
||||
var node = nodeDelegate(uigraph.hoveredNode)
|
||||
finalPosition = Qt.point(node.mousePosition.x + node.x, node.mousePosition.y + node.y)
|
||||
} else {
|
||||
finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY)
|
||||
}
|
||||
finalPosition = mapToItem(draggable, mouseArea.mouseX, mouseArea.mouseY);
|
||||
} else {
|
||||
finalPosition = getCenterPosition()
|
||||
centerPosition = true
|
||||
finalPosition = getCenterPosition();
|
||||
}
|
||||
|
||||
var copiedContent = Clipboard.getText()
|
||||
var nodes = uigraph.pasteNodes(copiedContent, finalPosition, centerPosition)
|
||||
const copiedContent = Clipboard.getText();
|
||||
const nodes = uigraph.pasteNodes(copiedContent, finalPosition);
|
||||
if (nodes.length > 0) {
|
||||
uigraph.selectedNode = nodes[0]
|
||||
uigraph.selectNodes(nodes)
|
||||
uigraph.selectedNode = nodes[0];
|
||||
uigraph.selectNodes(nodes);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -389,7 +389,7 @@ Page {
|
|||
} else {
|
||||
// Open project
|
||||
mainStack.push("Application.qml")
|
||||
if (_reconstruction.loadUrl(modelData["path"])) {
|
||||
if (_reconstruction.load(modelData["path"])) {
|
||||
MeshroomApp.addRecentProjectFile(modelData["path"])
|
||||
} else {
|
||||
MeshroomApp.removeRecentProjectFile(modelData["path"])
|
||||
|
|
|
@ -128,7 +128,7 @@ ApplicationWindow {
|
|||
if (mainStack.currentItem instanceof Homepage) {
|
||||
mainStack.push("Application.qml")
|
||||
}
|
||||
if (_reconstruction.loadUrl(currentFile)) {
|
||||
if (_reconstruction.load(currentFile)) {
|
||||
MeshroomApp.addRecentProjectFile(currentFile.toString())
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,6 +5,7 @@ import os
|
|||
from collections.abc import Iterable
|
||||
from multiprocessing.pool import ThreadPool
|
||||
from threading import Thread
|
||||
from typing import Callable
|
||||
|
||||
from PySide6.QtCore import QObject, Slot, Property, Signal, QUrl, QSizeF, QPoint
|
||||
from PySide6.QtGui import QMatrix4x4, QMatrix3x3, QQuaternion, QVector3D, QVector2D
|
||||
|
@ -534,17 +535,24 @@ class Reconstruction(UIGraph):
|
|||
# - correct pipeline name but the case does not match (e.g. panoramaHDR instead of panoramaHdr)
|
||||
# - lowercase pipeline name given through the "New Pipeline" menu
|
||||
loweredPipelineTemplates = dict((k.lower(), v) for k, v in meshroom.core.pipelineTemplates.items())
|
||||
if p.lower() in loweredPipelineTemplates:
|
||||
self.load(loweredPipelineTemplates[p.lower()], setupProjectFile=False)
|
||||
else:
|
||||
# use the user-provided default project file
|
||||
self.load(p, setupProjectFile=False)
|
||||
filepath = loweredPipelineTemplates.get(p.lower(), p)
|
||||
return self._loadWithErrorReport(self.initFromTemplate, filepath)
|
||||
|
||||
@Slot(str, result=bool)
|
||||
def load(self, filepath, setupProjectFile=True, publishOutputs=False):
|
||||
@Slot(QUrl, result=bool)
|
||||
def load(self, url):
|
||||
if isinstance(url, QUrl):
|
||||
# depending how the QUrl has been initialized,
|
||||
# toLocalFile() may return the local path or an empty string
|
||||
localFile = url.toLocalFile() or url.toString()
|
||||
else:
|
||||
localFile = url
|
||||
return self._loadWithErrorReport(self.loadGraph, localFile)
|
||||
|
||||
def _loadWithErrorReport(self, loadFunction: Callable[[str], None], filepath: str):
|
||||
logging.info(f"Load project file: '{filepath}'")
|
||||
try:
|
||||
status = super(Reconstruction, self).loadGraph(filepath, setupProjectFile, publishOutputs)
|
||||
loadFunction(filepath)
|
||||
# warn about pre-release projects being automatically upgraded
|
||||
if Version(self._graph.fileReleaseVersion).major == "0":
|
||||
self.warning.emit(Message(
|
||||
|
@ -554,8 +562,8 @@ class Reconstruction(UIGraph):
|
|||
"Open it with the corresponding version of Meshroom to recover your data."
|
||||
))
|
||||
self.setActive(True)
|
||||
return status
|
||||
except FileNotFoundError as e:
|
||||
return True
|
||||
except FileNotFoundError:
|
||||
self.error.emit(
|
||||
Message(
|
||||
"No Such File",
|
||||
|
@ -564,8 +572,7 @@ class Reconstruction(UIGraph):
|
|||
)
|
||||
)
|
||||
logging.error("Error while loading '{}': No Such File.".format(filepath))
|
||||
return False
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
import traceback
|
||||
trace = traceback.format_exc()
|
||||
self.error.emit(
|
||||
|
@ -577,20 +584,8 @@ class Reconstruction(UIGraph):
|
|||
)
|
||||
logging.error("Error while loading '{}'.".format(filepath))
|
||||
logging.error(trace)
|
||||
return False
|
||||
|
||||
@Slot(QUrl, result=bool)
|
||||
@Slot(QUrl, bool, bool, result=bool)
|
||||
def loadUrl(self, url, setupProjectFile=True, publishOutputs=False):
|
||||
if isinstance(url, (QUrl)):
|
||||
# depending how the QUrl has been initialized,
|
||||
# toLocalFile() may return the local path or an empty string
|
||||
localFile = url.toLocalFile()
|
||||
if not localFile:
|
||||
localFile = url.toString()
|
||||
else:
|
||||
localFile = url
|
||||
return self.load(localFile, setupProjectFile, publishOutputs)
|
||||
return False
|
||||
|
||||
def onGraphChanged(self):
|
||||
""" React to the change of the internal graph. """
|
||||
|
@ -860,7 +855,7 @@ class Reconstruction(UIGraph):
|
|||
)
|
||||
)
|
||||
else:
|
||||
return self.loadUrl(filesByType["meshroomScenes"][0])
|
||||
return self.load(filesByType["meshroomScenes"][0])
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@ import tempfile
|
|||
import os
|
||||
|
||||
import copy
|
||||
from typing import Type
|
||||
import pytest
|
||||
|
||||
import meshroom.core
|
||||
|
@ -12,6 +13,8 @@ from meshroom.core.exception import GraphCompatibilityError, NodeUpgradeError
|
|||
from meshroom.core.graph import Graph, loadGraph
|
||||
from meshroom.core.node import CompatibilityNode, CompatibilityIssue, Node
|
||||
|
||||
from .utils import registeredNodeTypes, overrideNodeTypeVersion
|
||||
|
||||
|
||||
SampleGroupV1 = [
|
||||
desc.IntParam(name="a", label="a", description="", value=0, range=None),
|
||||
|
@ -156,6 +159,12 @@ class SampleInputNodeV2(desc.InputNode):
|
|||
]
|
||||
|
||||
|
||||
|
||||
def replaceNodeTypeDesc(nodeType: str, nodeDesc: Type[desc.Node]):
|
||||
"""Change the `nodeDesc` associated to `nodeType`."""
|
||||
meshroom.core.nodesDesc[nodeType] = nodeDesc
|
||||
|
||||
|
||||
def test_unknown_node_type():
|
||||
"""
|
||||
Test compatibility behavior for unknown node type.
|
||||
|
@ -218,8 +227,7 @@ def test_description_conflict():
|
|||
g.save(graphFile)
|
||||
|
||||
# reload file as-is, ensure no compatibility issue is detected (no CompatibilityNode instances)
|
||||
g = loadGraph(graphFile)
|
||||
assert all(isinstance(n, Node) for n in g.nodes)
|
||||
loadGraph(graphFile, strictCompatibility=True)
|
||||
|
||||
# offset node types register to create description conflicts
|
||||
# each node type name now reference the next one's implementation
|
||||
|
@ -247,7 +255,7 @@ def test_description_conflict():
|
|||
assert not hasattr(compatNode, "in")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV2)
|
||||
|
||||
assert list(upgradedNode.attributes.keys()) == ["in", "paramA", "output"]
|
||||
|
@ -262,7 +270,7 @@ def test_description_conflict():
|
|||
assert hasattr(compatNode, "paramA")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV3)
|
||||
|
||||
assert not hasattr(upgradedNode, "paramA")
|
||||
|
@ -275,7 +283,7 @@ def test_description_conflict():
|
|||
assert not hasattr(compatNode, "paramA")
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV4)
|
||||
|
||||
assert hasattr(upgradedNode, "paramA")
|
||||
|
@ -295,7 +303,7 @@ def test_description_conflict():
|
|||
assert isinstance(elt, next(a for a in SampleGroupV1 if a.name == elt.name).__class__)
|
||||
|
||||
# perform upgrade
|
||||
upgradedNode = g.upgradeNode(nodeName)[0]
|
||||
upgradedNode = g.upgradeNode(nodeName)
|
||||
assert isinstance(upgradedNode, Node) and isinstance(upgradedNode.nodeDesc, SampleNodeV5)
|
||||
|
||||
assert hasattr(upgradedNode, "paramA")
|
||||
|
@ -399,20 +407,250 @@ def test_conformUpgrade():
|
|||
|
||||
class TestGraphLoadingWithStrictCompatibility:
|
||||
|
||||
def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk):
|
||||
registerNodeType(SampleNodeV1)
|
||||
registerNodeType(SampleNodeV2)
|
||||
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save()
|
||||
|
||||
# Replace saved node description by V2
|
||||
meshroom.core.nodesDesc[SampleNodeV1.__name__] = SampleNodeV2
|
||||
def test_failsOnUnknownNodeType(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([SampleNodeV1]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save()
|
||||
|
||||
with pytest.raises(GraphCompatibilityError):
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
unregisterNodeType(SampleNodeV1)
|
||||
unregisterNodeType(SampleNodeV2)
|
||||
|
||||
def test_failsOnNodeDescriptionCompatibilityIssue(self, graphSavedOnDisk):
|
||||
|
||||
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save()
|
||||
|
||||
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
|
||||
|
||||
with pytest.raises(GraphCompatibilityError):
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
|
||||
class TestGraphTemplateLoading:
|
||||
|
||||
def test_failsOnUnknownNodeTypeError(self, graphSavedOnDisk):
|
||||
|
||||
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save(template=True)
|
||||
|
||||
with pytest.raises(GraphCompatibilityError):
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
def test_loadsIfIncompatibleNodeHasDefaultAttributeValues(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save(template=True)
|
||||
|
||||
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
|
||||
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
def test_loadsIfValueSetOnCompatibleAttribute(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
node = graph.addNewNode(SampleNodeV1.__name__, paramA="foo")
|
||||
graph.save(template=True)
|
||||
|
||||
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath, strictCompatibility=True)
|
||||
assert loadedGraph.nodes.get(node.name).paramA.value == "foo"
|
||||
|
||||
def test_loadsIfValueSetOnIncompatibleAttribute(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([SampleNodeV1, SampleNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
graph.addNewNode(SampleNodeV1.__name__, input="foo")
|
||||
graph.save(template=True)
|
||||
|
||||
replaceNodeTypeDesc(SampleNodeV1.__name__, SampleNodeV2)
|
||||
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
class TestVersionConflict:
|
||||
|
||||
def test_loadingConflictingNodeVersionCreatesCompatibilityNodes(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SampleNodeV1]):
|
||||
with overrideNodeTypeVersion(SampleNodeV1, "1.0"):
|
||||
node = graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save()
|
||||
|
||||
with overrideNodeTypeVersion(SampleNodeV1, "2.0"):
|
||||
otherGraph = Graph("")
|
||||
otherGraph.load(graph.filepath)
|
||||
|
||||
assert len(otherGraph.compatibilityNodes) == 1
|
||||
assert otherGraph.node(node.name).issue is CompatibilityIssue.VersionConflict
|
||||
|
||||
def test_loadingUnspecifiedNodeVersionAssumesCurrentVersion(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SampleNodeV1]):
|
||||
graph.addNewNode(SampleNodeV1.__name__)
|
||||
graph.save()
|
||||
|
||||
with overrideNodeTypeVersion(SampleNodeV1, "2.0"):
|
||||
otherGraph = Graph("")
|
||||
otherGraph.load(graph.filepath)
|
||||
|
||||
assert len(otherGraph.compatibilityNodes) == 0
|
||||
|
||||
|
||||
class UidTestingNodeV1(desc.Node):
|
||||
inputs = [
|
||||
desc.File(name="input", label="Input", description="", value="", invalidate=True),
|
||||
]
|
||||
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
|
||||
|
||||
|
||||
class UidTestingNodeV2(desc.Node):
|
||||
"""
|
||||
Changes from SampleNodeBV1:
|
||||
* 'param' has been added
|
||||
"""
|
||||
|
||||
inputs = [
|
||||
desc.File(name="input", label="Input", description="", value="", invalidate=True),
|
||||
desc.ListAttribute(
|
||||
name="param",
|
||||
label="Param",
|
||||
elementDesc=desc.File(
|
||||
name="file",
|
||||
label="File",
|
||||
description="",
|
||||
value="",
|
||||
),
|
||||
description="",
|
||||
),
|
||||
]
|
||||
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
|
||||
|
||||
|
||||
class UidTestingNodeV3(desc.Node):
|
||||
"""
|
||||
Changes from SampleNodeBV2:
|
||||
* 'input' is not invalidating the UID.
|
||||
"""
|
||||
|
||||
inputs = [
|
||||
desc.File(name="input", label="Input", description="", value="", invalidate=False),
|
||||
desc.ListAttribute(
|
||||
name="param",
|
||||
label="Param",
|
||||
elementDesc=desc.File(
|
||||
name="file",
|
||||
label="File",
|
||||
description="",
|
||||
value="",
|
||||
),
|
||||
description="",
|
||||
),
|
||||
]
|
||||
outputs = [desc.File(name="output", label="Output", description="", value=desc.Node.internalFolder)]
|
||||
|
||||
|
||||
class TestUidConflict:
|
||||
def test_changingInvalidateOnAttributeDescCreatesUidConflict(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([UidTestingNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
node = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
|
||||
graph.save()
|
||||
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
|
||||
|
||||
with pytest.raises(GraphCompatibilityError):
|
||||
loadGraph(graph.filepath, strictCompatibility=True)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath)
|
||||
loadedNode = loadedGraph.node(node.name)
|
||||
assert isinstance(loadedNode, CompatibilityNode)
|
||||
assert loadedNode.issue == CompatibilityIssue.UidConflict
|
||||
|
||||
def test_uidConflictingNodesPreserveConnectionsOnGraphLoad(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([UidTestingNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
nodeB = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
|
||||
nodeB.param.append("")
|
||||
graph.addEdge(nodeA.output, nodeB.param.at(0))
|
||||
|
||||
graph.save()
|
||||
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath)
|
||||
assert len(loadedGraph.compatibilityNodes) == 2
|
||||
|
||||
loadedNodeA = loadedGraph.node(nodeA.name)
|
||||
loadedNodeB = loadedGraph.node(nodeB.name)
|
||||
|
||||
assert loadedNodeB.param.at(0).linkParam == loadedNodeA.output
|
||||
|
||||
def test_upgradingConflictingNodesPreserveConnections(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([UidTestingNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
nodeB = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
|
||||
# Double-connect nodeA.output to nodeB, on both a single attribute and a list attribute
|
||||
nodeB.param.append("")
|
||||
graph.addEdge(nodeA.output, nodeB.param.at(0))
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
graph.save()
|
||||
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
|
||||
|
||||
def checkNodeAConnectionsToNodeB():
|
||||
loadedNodeA = loadedGraph.node(nodeA.name)
|
||||
loadedNodeB = loadedGraph.node(nodeB.name)
|
||||
return (
|
||||
loadedNodeB.param.at(0).linkParam == loadedNodeA.output
|
||||
and loadedNodeB.input.linkParam == loadedNodeA.output
|
||||
)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath)
|
||||
loadedGraph.upgradeNode(nodeA.name)
|
||||
|
||||
assert checkNodeAConnectionsToNodeB()
|
||||
loadedGraph.upgradeNode(nodeB.name)
|
||||
|
||||
assert checkNodeAConnectionsToNodeB()
|
||||
assert len(loadedGraph.compatibilityNodes) == 0
|
||||
|
||||
|
||||
def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughConnection(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([UidTestingNodeV1, UidTestingNodeV2]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
nodeB = graph.addNewNode(UidTestingNodeV1.__name__)
|
||||
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
graph.save()
|
||||
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath)
|
||||
assert len(loadedGraph.compatibilityNodes) == 1
|
||||
|
||||
def test_uidConflictDoesNotPropagateToValidDownstreamNodeThroughListConnection(self, graphSavedOnDisk):
|
||||
with registeredNodeTypes([UidTestingNodeV2, UidTestingNodeV3]):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
nodeA = graph.addNewNode(UidTestingNodeV2.__name__)
|
||||
nodeB = graph.addNewNode(UidTestingNodeV3.__name__)
|
||||
|
||||
nodeB.param.append("")
|
||||
graph.addEdge(nodeA.output, nodeB.param.at(0))
|
||||
|
||||
graph.save()
|
||||
replaceNodeTypeDesc(UidTestingNodeV2.__name__, UidTestingNodeV3)
|
||||
|
||||
loadedGraph = loadGraph(graph.filepath)
|
||||
assert len(loadedGraph.compatibilityNodes) == 1
|
||||
|
|
364
tests/test_graphIO.py
Normal file
364
tests/test_graphIO.py
Normal file
|
@ -0,0 +1,364 @@
|
|||
import json
|
||||
from textwrap import dedent
|
||||
|
||||
from meshroom.core import desc
|
||||
from meshroom.core.graph import Graph
|
||||
from meshroom.core.node import CompatibilityIssue
|
||||
|
||||
from .utils import registeredNodeTypes, overrideNodeTypeVersion
|
||||
|
||||
|
||||
class SimpleNode(desc.Node):
|
||||
inputs = [
|
||||
desc.File(name="input", label="Input", description="", value=""),
|
||||
]
|
||||
outputs = [
|
||||
desc.File(name="output", label="Output", description="", value=""),
|
||||
]
|
||||
|
||||
|
||||
class NodeWithListAttributes(desc.Node):
|
||||
inputs = [
|
||||
desc.ListAttribute(
|
||||
name="listInput",
|
||||
label="List Input",
|
||||
description="",
|
||||
elementDesc=desc.File(name="file", label="File", description="", value=""),
|
||||
exposed=True,
|
||||
),
|
||||
desc.GroupAttribute(
|
||||
name="group",
|
||||
label="Group",
|
||||
description="",
|
||||
groupDesc=[
|
||||
desc.ListAttribute(
|
||||
name="listInput",
|
||||
label="List Input",
|
||||
description="",
|
||||
elementDesc=desc.File(name="file", label="File", description="", value=""),
|
||||
exposed=True,
|
||||
),
|
||||
],
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def compareGraphsContent(graphA: Graph, graphB: Graph) -> bool:
|
||||
"""Returns whether the content (node and deges) of two graphs are considered identical.
|
||||
|
||||
Similar nodes: nodes with the same name, type and compatibility status.
|
||||
Similar edges: edges with the same source and destination attribute names.
|
||||
"""
|
||||
|
||||
def _buildNodesSet(graph: Graph):
|
||||
return set([(node.name, node.nodeType, node.isCompatibilityNode) for node in graph.nodes])
|
||||
|
||||
def _buildEdgesSet(graph: Graph):
|
||||
return set([(edge.src.fullName, edge.dst.fullName) for edge in graph.edges])
|
||||
|
||||
nodesSetA, edgesSetA = _buildNodesSet(graphA), _buildEdgesSet(graphA)
|
||||
nodesSetB, edgesSetB = _buildNodesSet(graphB), _buildEdgesSet(graphB)
|
||||
|
||||
return nodesSetA == nodesSetB and edgesSetA == edgesSetB
|
||||
|
||||
|
||||
class TestImportGraphContent:
|
||||
def test_importEmptyGraph(self):
|
||||
graph = Graph("")
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContent(graph)
|
||||
|
||||
assert len(nodes) == 0
|
||||
assert len(graph.nodes) == 0
|
||||
|
||||
def test_importGraphWithSingleNode(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importGraphWithSeveralNodes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importingGraphWithNodesAndEdges(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_edgeRemappingOnImportingGraphSeveralTimes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
def test_edgeRemappingOnImportingGraphWithUnkownNodeTypesSeveralTimes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert len(otherGraph.nodes) == 4
|
||||
assert len(otherGraph.compatibilityNodes) == 4
|
||||
assert len(otherGraph.edges) == 2
|
||||
|
||||
def test_importGraphWithUnknownNodeTypesCreatesCompatibilityNodes(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
otherGraph = Graph("")
|
||||
importedNode = otherGraph.importGraphContent(graph)
|
||||
|
||||
assert len(importedNode) == 1
|
||||
assert importedNode[0].isCompatibilityNode
|
||||
|
||||
def test_importGraphContentInPlace(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
|
||||
graph.importGraphContent(graph)
|
||||
|
||||
assert len(graph.nodes) == 4
|
||||
|
||||
def test_importGraphContentFromFile(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
graph.save()
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
|
||||
|
||||
assert len(nodes) == 2
|
||||
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importGraphContentFromFileWithCompatibilityNodes(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA_1 = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeA_2 = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA_1.output, nodeA_2.input)
|
||||
graph.save()
|
||||
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
|
||||
|
||||
assert len(nodes) == 2
|
||||
assert len(otherGraph.compatibilityNodes) == 2
|
||||
assert not compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_importingDifferentNodeVersionCreatesCompatibilityNodes(self, graphSavedOnDisk):
|
||||
graph: Graph = graphSavedOnDisk
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
with overrideNodeTypeVersion(SimpleNode, "1.0"):
|
||||
node = graph.addNewNode(SimpleNode.__name__)
|
||||
graph.save()
|
||||
|
||||
with overrideNodeTypeVersion(SimpleNode, "2.0"):
|
||||
otherGraph = Graph("")
|
||||
nodes = otherGraph.importGraphContentFromFile(graph.filepath)
|
||||
|
||||
assert len(nodes) == 1
|
||||
assert len(otherGraph.compatibilityNodes) == 1
|
||||
assert otherGraph.node(node.name).issue is CompatibilityIssue.VersionConflict
|
||||
|
||||
class TestGraphPartialSerialization:
|
||||
def test_emptyGraph(self):
|
||||
graph = Graph("")
|
||||
serializedGraph = graph.serializePartial([])
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph._deserialize(serializedGraph)
|
||||
assert compareGraphsContent(graph, otherGraph)
|
||||
|
||||
def test_serializeAllNodesIsSimilarToStandardSerialization(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
partialSerializedGraph = graph.serializePartial([nodeA, nodeB])
|
||||
standardSerializedGraph = graph.serialize()
|
||||
|
||||
graphA = Graph("")
|
||||
graphA._deserialize(partialSerializedGraph)
|
||||
|
||||
graphB = Graph("")
|
||||
graphB._deserialize(standardSerializedGraph)
|
||||
|
||||
assert compareGraphsContent(graph, graphA)
|
||||
assert compareGraphsContent(graphA, graphB)
|
||||
|
||||
def test_listAttributeToListAttributeConnectionIsSerialized(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([NodeWithListAttributes]):
|
||||
nodeA = graph.addNewNode(NodeWithListAttributes.__name__)
|
||||
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
|
||||
|
||||
graph.addEdge(nodeA.listInput, nodeB.listInput)
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph._deserialize(graph.serializePartial([nodeA, nodeB]))
|
||||
|
||||
assert otherGraph.node(nodeB.name).listInput.linkParam == otherGraph.node(nodeA.name).listInput
|
||||
|
||||
def test_singleNodeWithInputConnectionFromNonSerializedNodeRemovesEdge(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
serializedGraph = graph.serializePartial([nodeB])
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph._deserialize(serializedGraph)
|
||||
|
||||
assert len(otherGraph.compatibilityNodes) == 0
|
||||
assert len(otherGraph.nodes) == 1
|
||||
assert len(otherGraph.edges) == 0
|
||||
|
||||
def test_serializeSingleNodeWithInputConnectionToListAttributeRemovesListEntry(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
|
||||
|
||||
nodeB.listInput.append("")
|
||||
graph.addEdge(nodeA.output, nodeB.listInput.at(0))
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph._deserialize(graph.serializePartial([nodeB]))
|
||||
|
||||
assert len(otherGraph.node(nodeB.name).listInput) == 0
|
||||
|
||||
def test_serializeSingleNodeWithInputConnectionToNestedListAttributeRemovesListEntry(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode, NodeWithListAttributes]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(NodeWithListAttributes.__name__)
|
||||
|
||||
nodeB.group.listInput.append("")
|
||||
graph.addEdge(nodeA.output, nodeB.group.listInput.at(0))
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph._deserialize(graph.serializePartial([nodeB]))
|
||||
|
||||
assert len(otherGraph.node(nodeB.name).group.listInput) == 0
|
||||
|
||||
|
||||
class TestGraphCopy:
|
||||
def test_graphCopyIsIdenticalToOriginalGraph(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
graphCopy = graph.copy()
|
||||
assert compareGraphsContent(graph, graphCopy)
|
||||
|
||||
def test_graphCopyWithUnknownNodeTypesDiffersFromOriginalGraph(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
nodeA = graph.addNewNode(SimpleNode.__name__)
|
||||
nodeB = graph.addNewNode(SimpleNode.__name__)
|
||||
|
||||
graph.addEdge(nodeA.output, nodeB.input)
|
||||
|
||||
graphCopy = graph.copy()
|
||||
assert not compareGraphsContent(graph, graphCopy)
|
||||
|
||||
|
||||
class TestImportGraphContentFromMinimalGraphData:
|
||||
def test_nodeWithoutVersionInfoIsUpgraded(self):
|
||||
graph = Graph("")
|
||||
|
||||
with (
|
||||
registeredNodeTypes([SimpleNode]),
|
||||
overrideNodeTypeVersion(SimpleNode, "2.0"),
|
||||
):
|
||||
sampleGraphContent = dedent("""
|
||||
{
|
||||
"SimpleNode_1": { "nodeType": "SimpleNode" }
|
||||
}
|
||||
""")
|
||||
graph._deserialize(json.loads(sampleGraphContent))
|
||||
|
||||
assert len(graph.nodes) == 1
|
||||
assert len(graph.compatibilityNodes) == 0
|
||||
|
||||
def test_connectionsToMissingNodesAreDiscarded(self):
|
||||
graph = Graph("")
|
||||
|
||||
with registeredNodeTypes([SimpleNode]):
|
||||
sampleGraphContent = dedent("""
|
||||
{
|
||||
"SimpleNode_1": {
|
||||
"nodeType": "SimpleNode", "inputs": { "input": "{NotSerializedNode.output}" }
|
||||
}
|
||||
}
|
||||
""")
|
||||
graph._deserialize(json.loads(sampleGraphContent))
|
|
@ -431,3 +431,28 @@ class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
|
|||
assert nodeB.affectedInput.value == 0
|
||||
|
||||
|
||||
class TestAttributeCallbackBehaviorOnGraphImport:
|
||||
@classmethod
|
||||
def setup_class(cls):
|
||||
registerNodeType(NodeWithAttributeChangedCallback)
|
||||
|
||||
@classmethod
|
||||
def teardown_class(cls):
|
||||
unregisterNodeType(NodeWithAttributeChangedCallback)
|
||||
|
||||
def test_importingGraphDoesNotTriggerAttributeChangedCallbacks(self):
|
||||
graph = Graph("")
|
||||
|
||||
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
|
||||
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
|
||||
|
||||
graph.addEdge(nodeA.affectedInput, nodeB.input)
|
||||
|
||||
nodeA.input.value = 5
|
||||
nodeB.affectedInput.value = 2
|
||||
|
||||
otherGraph = Graph("")
|
||||
otherGraph.importGraphContent(graph)
|
||||
|
||||
assert otherGraph.node(nodeB.name).affectedInput.value == 2
|
||||
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
from meshroom.core.graph import Graph
|
||||
from meshroom.core import pipelineTemplates, Version
|
||||
from meshroom.core.node import CompatibilityIssue, CompatibilityNode
|
||||
from meshroom.core.graphIO import GraphIO
|
||||
|
||||
import json
|
||||
import meshroom
|
||||
|
@ -24,13 +25,13 @@ def test_templateVersions():
|
|||
with open(path) as jsonFile:
|
||||
fileData = json.load(jsonFile)
|
||||
|
||||
graphData = fileData.get(Graph.IO.Keys.Graph, fileData)
|
||||
graphData = fileData.get(GraphIO.Keys.Graph, fileData)
|
||||
|
||||
assert isinstance(graphData, dict)
|
||||
|
||||
header = fileData.get(Graph.IO.Keys.Header, {})
|
||||
header = fileData.get(GraphIO.Keys.Header, {})
|
||||
assert header.get("template", False)
|
||||
nodesVersions = header.get(Graph.IO.Keys.NodesVersions, {})
|
||||
nodesVersions = header.get(GraphIO.Keys.NodesVersions, {})
|
||||
|
||||
for _, nodeData in graphData.items():
|
||||
nodeType = nodeData["nodeType"]
|
||||
|
|
28
tests/utils.py
Normal file
28
tests/utils.py
Normal file
|
@ -0,0 +1,28 @@
|
|||
from contextlib import contextmanager
|
||||
from unittest.mock import patch
|
||||
from typing import Type
|
||||
|
||||
import meshroom
|
||||
from meshroom.core import registerNodeType, unregisterNodeType
|
||||
from meshroom.core import desc
|
||||
|
||||
@contextmanager
|
||||
def registeredNodeTypes(nodeTypes: list[Type[desc.Node]]):
|
||||
for nodeType in nodeTypes:
|
||||
registerNodeType(nodeType)
|
||||
|
||||
yield
|
||||
|
||||
for nodeType in nodeTypes:
|
||||
unregisterNodeType(nodeType)
|
||||
|
||||
@contextmanager
|
||||
def overrideNodeTypeVersion(nodeType: Type[desc.Node], version: str):
|
||||
"""Helper context manager to override the version of a given node type."""
|
||||
unpatchedFunc = meshroom.core.nodeVersion
|
||||
with patch.object(
|
||||
meshroom.core,
|
||||
"nodeVersion",
|
||||
side_effect=lambda type: version if type is nodeType else unpatchedFunc(type),
|
||||
):
|
||||
yield
|
Loading…
Add table
Add a link
Reference in a new issue