[graphIO] Introduce graph serializer classes

Move the serialization logic to dedicated serializer classes.
Implement both `GraphSerializer` and `TemplateGraphSerializer`
to cover for the existing serialization use-cases.
This commit is contained in:
Yann Lanthony 2025-02-06 16:46:04 +01:00
parent a665200c38
commit 01d67eb33d
2 changed files with 128 additions and 76 deletions

View file

@ -17,7 +17,7 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
from meshroom.core.graphIO import GraphIO
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer
from meshroom.core.node import Status, Node, CompatibilityNode
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
@ -1386,6 +1386,18 @@ class Graph(BaseObject):
def asString(self):
return str(self.toDict())
def serialize(self, asTemplate: bool = False) -> dict:
"""Serialize this Graph instance.
Args:
asTemplate: Whether to use the template serialization.
Returns:
The serialized graph data.
"""
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
return SerializerClass(self).serialize()
def save(self, filepath=None, setupProjectFile=True, template=False):
"""
Save the current Meshroom graph as a serialized ".mg" file.
@ -1408,34 +1420,7 @@ class Graph(BaseObject):
if not path:
raise ValueError("filepath must be specified for unsaved files.")
self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
# and remove duplicates
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
# Convert to node types to "name: version"
nodesVersions = {
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
for p in usedNodeTypes
}
# Sort them by name (to avoid random order changing from one save to another)
nodesVersions = dict(sorted(nodesVersions.items()))
# Add it the header
self.header[GraphIO.Keys.NodesVersions] = nodesVersions
self.header["template"] = template
data = {}
if template:
data = {
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.getNonDefaultInputAttributes()
}
else:
data = {
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.toDict()
}
data = self.serialize(template)
with open(path, 'w') as jsonFile:
json.dump(data, jsonFile, indent=4)
@ -1446,51 +1431,6 @@ class Graph(BaseObject):
# update the file date version
self._fileDateVersion = os.path.getmtime(path)
def getNonDefaultInputAttributes(self):
"""
Instead of getting all the inputs and internal attribute keys, only get the keys of
the attributes whose value is not the default one.
The output attributes, UIDs, parallelization parameters and internal folder are
not relevant for templates, so they are explicitly removed from the returned dictionary.
Returns:
dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder
and input/internal attributes with default values removed
"""
graph = self.toDict()
for nodeName in graph.keys():
node = self.node(nodeName)
inputKeys = list(graph[nodeName]["inputs"].keys())
internalInputKeys = []
internalInputs = graph[nodeName].get("internalInputs", None)
if internalInputs:
internalInputKeys = list(internalInputs.keys())
for attrName in inputKeys:
attribute = node.attribute(attrName)
# check that attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del graph[nodeName]["inputs"][attrName]
for attrName in internalInputKeys:
attribute = node.internalAttribute(attrName)
# check that internal attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del graph[nodeName]["internalInputs"][attrName]
# If all the internal attributes are set to their default values, remove the entry
if len(graph[nodeName]["internalInputs"]) == 0:
del graph[nodeName]["internalInputs"]
del graph[nodeName]["outputs"]
del graph[nodeName]["uid"]
del graph[nodeName]["internalFolder"]
del graph[nodeName]["parallelization"]
return graph
def _setFilepath(self, filepath):
"""
Set the internal filepath of this Graph.

View file

@ -1,7 +1,12 @@
from enum import Enum
from typing import Union
from typing import Any, TYPE_CHECKING, Union
import meshroom
from meshroom.core import Version
from meshroom.core.node import Node
if TYPE_CHECKING:
from meshroom.core.graph import Graph
class GraphIO:
@ -29,7 +34,7 @@ class GraphIO:
NodesPositions = "nodesPositions"
@staticmethod
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]:
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
"""Return the list of supported features based on a file version.
Args:
@ -54,3 +59,110 @@ class GraphIO:
return tuple(features)
class GraphSerializer:
"""Standard Graph serializer."""
def __init__(self, graph: "Graph") -> None:
self._graph = graph
def serialize(self) -> dict:
"""
Serialize the Graph.
"""
return {
GraphIO.Keys.Header: self.serializeHeader(),
GraphIO.Keys.Graph: self.serializeContent(),
}
@property
def nodes(self) -> list[Node]:
return self._graph.nodes
def serializeHeader(self) -> dict:
"""Build and return the graph serialization header.
The header contains metadata about the graph, such as the:
- version of the software used to create it.
- version of the file format.
- version of the nodes types used in the graph.
- template flag.
Args:
nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified.
template: Whether the graph is going to be serialized as a template.
"""
header: dict[str, Any] = {}
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
return header
def _getNodeTypesVersions(self) -> dict[str, str]:
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
nodeTypesVersions = {
nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes
}
# Sort them by name (to avoid random order changing from one save to another).
return dict(sorted(nodeTypesVersions.items()))
def serializeContent(self) -> dict:
"""Graph content serialization logic."""
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
def serializeNode(self, node: Node) -> dict:
"""Node serialization logic."""
return node.toDict()
class TemplateGraphSerializer(GraphSerializer):
"""Serializer for serializing a graph as a template."""
def serializeHeader(self) -> dict:
header = super().serializeHeader()
header["template"] = True
return header
def serializeNode(self, node: Node) -> dict:
"""Adapt node serialization to template graphs.
Instead of getting all the inputs and internal attribute keys, only get the keys of
the attributes whose value is not the default one.
The output attributes, UIDs, parallelization parameters and internal folder are
not relevant for templates, so they are explicitly removed from the returned dictionary.
"""
# For now, implemented as a post-process to update the default serialization.
nodeData = super().serializeNode(node)
inputKeys = list(nodeData["inputs"].keys())
internalInputKeys = []
internalInputs = nodeData.get("internalInputs", None)
if internalInputs:
internalInputKeys = list(internalInputs.keys())
for attrName in inputKeys:
attribute = node.attribute(attrName)
# check that attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["inputs"][attrName]
for attrName in internalInputKeys:
attribute = node.internalAttribute(attrName)
# check that internal attribute is not a link for choice attributes
if attribute.isDefault and not attribute.isLink:
del nodeData["internalInputs"][attrName]
# If all the internal attributes are set to their default values, remove the entry
if len(nodeData["internalInputs"]) == 0:
del nodeData["internalInputs"]
del nodeData["outputs"]
del nodeData["uid"]
del nodeData["internalFolder"]
del nodeData["parallelization"]
return nodeData