mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-04-28 17:57:16 +02:00
[graphIO] Introduce graph serializer classes
Move the serialization logic to dedicated serializer classes. Implement both `GraphSerializer` and `TemplateGraphSerializer` to cover for the existing serialization use-cases.
This commit is contained in:
parent
a665200c38
commit
01d67eb33d
2 changed files with 128 additions and 76 deletions
|
@ -17,7 +17,7 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
|
|||
from meshroom.core import Version
|
||||
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
|
||||
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
|
||||
from meshroom.core.graphIO import GraphIO
|
||||
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer
|
||||
from meshroom.core.node import Status, Node, CompatibilityNode
|
||||
from meshroom.core.nodeFactory import nodeFactory
|
||||
from meshroom.core.typing import PathLike
|
||||
|
@ -1386,6 +1386,18 @@ class Graph(BaseObject):
|
|||
def asString(self):
|
||||
return str(self.toDict())
|
||||
|
||||
def serialize(self, asTemplate: bool = False) -> dict:
|
||||
"""Serialize this Graph instance.
|
||||
|
||||
Args:
|
||||
asTemplate: Whether to use the template serialization.
|
||||
|
||||
Returns:
|
||||
The serialized graph data.
|
||||
"""
|
||||
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
|
||||
return SerializerClass(self).serialize()
|
||||
|
||||
def save(self, filepath=None, setupProjectFile=True, template=False):
|
||||
"""
|
||||
Save the current Meshroom graph as a serialized ".mg" file.
|
||||
|
@ -1408,34 +1420,7 @@ class Graph(BaseObject):
|
|||
if not path:
|
||||
raise ValueError("filepath must be specified for unsaved files.")
|
||||
|
||||
self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
|
||||
self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__
|
||||
|
||||
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
|
||||
# and remove duplicates
|
||||
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
|
||||
# Convert to node types to "name: version"
|
||||
nodesVersions = {
|
||||
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
|
||||
for p in usedNodeTypes
|
||||
}
|
||||
# Sort them by name (to avoid random order changing from one save to another)
|
||||
nodesVersions = dict(sorted(nodesVersions.items()))
|
||||
# Add it the header
|
||||
self.header[GraphIO.Keys.NodesVersions] = nodesVersions
|
||||
self.header["template"] = template
|
||||
|
||||
data = {}
|
||||
if template:
|
||||
data = {
|
||||
GraphIO.Keys.Header: self.header,
|
||||
GraphIO.Keys.Graph: self.getNonDefaultInputAttributes()
|
||||
}
|
||||
else:
|
||||
data = {
|
||||
GraphIO.Keys.Header: self.header,
|
||||
GraphIO.Keys.Graph: self.toDict()
|
||||
}
|
||||
data = self.serialize(template)
|
||||
|
||||
with open(path, 'w') as jsonFile:
|
||||
json.dump(data, jsonFile, indent=4)
|
||||
|
@ -1446,51 +1431,6 @@ class Graph(BaseObject):
|
|||
# update the file date version
|
||||
self._fileDateVersion = os.path.getmtime(path)
|
||||
|
||||
def getNonDefaultInputAttributes(self):
|
||||
"""
|
||||
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
||||
the attributes whose value is not the default one.
|
||||
The output attributes, UIDs, parallelization parameters and internal folder are
|
||||
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
||||
|
||||
Returns:
|
||||
dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder
|
||||
and input/internal attributes with default values removed
|
||||
"""
|
||||
graph = self.toDict()
|
||||
for nodeName in graph.keys():
|
||||
node = self.node(nodeName)
|
||||
|
||||
inputKeys = list(graph[nodeName]["inputs"].keys())
|
||||
|
||||
internalInputKeys = []
|
||||
internalInputs = graph[nodeName].get("internalInputs", None)
|
||||
if internalInputs:
|
||||
internalInputKeys = list(internalInputs.keys())
|
||||
|
||||
for attrName in inputKeys:
|
||||
attribute = node.attribute(attrName)
|
||||
# check that attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del graph[nodeName]["inputs"][attrName]
|
||||
|
||||
for attrName in internalInputKeys:
|
||||
attribute = node.internalAttribute(attrName)
|
||||
# check that internal attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del graph[nodeName]["internalInputs"][attrName]
|
||||
|
||||
# If all the internal attributes are set to their default values, remove the entry
|
||||
if len(graph[nodeName]["internalInputs"]) == 0:
|
||||
del graph[nodeName]["internalInputs"]
|
||||
|
||||
del graph[nodeName]["outputs"]
|
||||
del graph[nodeName]["uid"]
|
||||
del graph[nodeName]["internalFolder"]
|
||||
del graph[nodeName]["parallelization"]
|
||||
|
||||
return graph
|
||||
|
||||
def _setFilepath(self, filepath):
|
||||
"""
|
||||
Set the internal filepath of this Graph.
|
||||
|
|
|
@ -1,7 +1,12 @@
|
|||
from enum import Enum
|
||||
from typing import Union
|
||||
from typing import Any, TYPE_CHECKING, Union
|
||||
|
||||
import meshroom
|
||||
from meshroom.core import Version
|
||||
from meshroom.core.node import Node
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from meshroom.core.graph import Graph
|
||||
|
||||
|
||||
class GraphIO:
|
||||
|
@ -29,7 +34,7 @@ class GraphIO:
|
|||
NodesPositions = "nodesPositions"
|
||||
|
||||
@staticmethod
|
||||
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]:
|
||||
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features", ...]:
|
||||
"""Return the list of supported features based on a file version.
|
||||
|
||||
Args:
|
||||
|
@ -54,3 +59,110 @@ class GraphIO:
|
|||
|
||||
return tuple(features)
|
||||
|
||||
|
||||
class GraphSerializer:
|
||||
"""Standard Graph serializer."""
|
||||
|
||||
def __init__(self, graph: "Graph") -> None:
|
||||
self._graph = graph
|
||||
|
||||
def serialize(self) -> dict:
|
||||
"""
|
||||
Serialize the Graph.
|
||||
"""
|
||||
return {
|
||||
GraphIO.Keys.Header: self.serializeHeader(),
|
||||
GraphIO.Keys.Graph: self.serializeContent(),
|
||||
}
|
||||
|
||||
@property
|
||||
def nodes(self) -> list[Node]:
|
||||
return self._graph.nodes
|
||||
|
||||
def serializeHeader(self) -> dict:
|
||||
"""Build and return the graph serialization header.
|
||||
|
||||
The header contains metadata about the graph, such as the:
|
||||
- version of the software used to create it.
|
||||
- version of the file format.
|
||||
- version of the nodes types used in the graph.
|
||||
- template flag.
|
||||
|
||||
Args:
|
||||
nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified.
|
||||
template: Whether the graph is going to be serialized as a template.
|
||||
"""
|
||||
header: dict[str, Any] = {}
|
||||
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
|
||||
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
|
||||
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
|
||||
return header
|
||||
|
||||
def _getNodeTypesVersions(self) -> dict[str, str]:
|
||||
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
|
||||
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
|
||||
nodeTypesVersions = {
|
||||
nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes
|
||||
}
|
||||
# Sort them by name (to avoid random order changing from one save to another).
|
||||
return dict(sorted(nodeTypesVersions.items()))
|
||||
|
||||
def serializeContent(self) -> dict:
|
||||
"""Graph content serialization logic."""
|
||||
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
|
||||
|
||||
def serializeNode(self, node: Node) -> dict:
|
||||
"""Node serialization logic."""
|
||||
return node.toDict()
|
||||
|
||||
|
||||
class TemplateGraphSerializer(GraphSerializer):
|
||||
"""Serializer for serializing a graph as a template."""
|
||||
|
||||
def serializeHeader(self) -> dict:
|
||||
header = super().serializeHeader()
|
||||
header["template"] = True
|
||||
return header
|
||||
|
||||
def serializeNode(self, node: Node) -> dict:
|
||||
"""Adapt node serialization to template graphs.
|
||||
|
||||
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
||||
the attributes whose value is not the default one.
|
||||
The output attributes, UIDs, parallelization parameters and internal folder are
|
||||
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
||||
"""
|
||||
# For now, implemented as a post-process to update the default serialization.
|
||||
nodeData = super().serializeNode(node)
|
||||
|
||||
inputKeys = list(nodeData["inputs"].keys())
|
||||
|
||||
internalInputKeys = []
|
||||
internalInputs = nodeData.get("internalInputs", None)
|
||||
if internalInputs:
|
||||
internalInputKeys = list(internalInputs.keys())
|
||||
|
||||
for attrName in inputKeys:
|
||||
attribute = node.attribute(attrName)
|
||||
# check that attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del nodeData["inputs"][attrName]
|
||||
|
||||
for attrName in internalInputKeys:
|
||||
attribute = node.internalAttribute(attrName)
|
||||
# check that internal attribute is not a link for choice attributes
|
||||
if attribute.isDefault and not attribute.isLink:
|
||||
del nodeData["internalInputs"][attrName]
|
||||
|
||||
# If all the internal attributes are set to their default values, remove the entry
|
||||
if len(nodeData["internalInputs"]) == 0:
|
||||
del nodeData["internalInputs"]
|
||||
|
||||
del nodeData["outputs"]
|
||||
del nodeData["uid"]
|
||||
del nodeData["internalFolder"]
|
||||
del nodeData["parallelization"]
|
||||
|
||||
return nodeData
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Reference in a new issue