mirror of
https://github.com/alicevision/Meshroom.git
synced 2025-04-30 10:47:34 +02:00
[graphIO] Introduce graph serializer classes
Move the serialization logic to dedicated serializer classes. Implement both `GraphSerializer` and `TemplateGraphSerializer` to cover for the existing serialization use-cases.
This commit is contained in:
parent
a665200c38
commit
01d67eb33d
2 changed files with 128 additions and 76 deletions
|
@ -17,7 +17,7 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
|
||||||
from meshroom.core import Version
|
from meshroom.core import Version
|
||||||
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
|
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
|
||||||
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
|
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
|
||||||
from meshroom.core.graphIO import GraphIO
|
from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer
|
||||||
from meshroom.core.node import Status, Node, CompatibilityNode
|
from meshroom.core.node import Status, Node, CompatibilityNode
|
||||||
from meshroom.core.nodeFactory import nodeFactory
|
from meshroom.core.nodeFactory import nodeFactory
|
||||||
from meshroom.core.typing import PathLike
|
from meshroom.core.typing import PathLike
|
||||||
|
@ -1386,6 +1386,18 @@ class Graph(BaseObject):
|
||||||
def asString(self):
|
def asString(self):
|
||||||
return str(self.toDict())
|
return str(self.toDict())
|
||||||
|
|
||||||
|
def serialize(self, asTemplate: bool = False) -> dict:
|
||||||
|
"""Serialize this Graph instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
asTemplate: Whether to use the template serialization.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The serialized graph data.
|
||||||
|
"""
|
||||||
|
SerializerClass = TemplateGraphSerializer if asTemplate else GraphSerializer
|
||||||
|
return SerializerClass(self).serialize()
|
||||||
|
|
||||||
def save(self, filepath=None, setupProjectFile=True, template=False):
|
def save(self, filepath=None, setupProjectFile=True, template=False):
|
||||||
"""
|
"""
|
||||||
Save the current Meshroom graph as a serialized ".mg" file.
|
Save the current Meshroom graph as a serialized ".mg" file.
|
||||||
|
@ -1408,34 +1420,7 @@ class Graph(BaseObject):
|
||||||
if not path:
|
if not path:
|
||||||
raise ValueError("filepath must be specified for unsaved files.")
|
raise ValueError("filepath must be specified for unsaved files.")
|
||||||
|
|
||||||
self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
|
data = self.serialize(template)
|
||||||
self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__
|
|
||||||
|
|
||||||
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
|
|
||||||
# and remove duplicates
|
|
||||||
usedNodeTypes = set([n.nodeDesc.__class__ for n in self._nodes if isinstance(n, Node)])
|
|
||||||
# Convert to node types to "name: version"
|
|
||||||
nodesVersions = {
|
|
||||||
"{}".format(p.__name__): meshroom.core.nodeVersion(p, "0.0")
|
|
||||||
for p in usedNodeTypes
|
|
||||||
}
|
|
||||||
# Sort them by name (to avoid random order changing from one save to another)
|
|
||||||
nodesVersions = dict(sorted(nodesVersions.items()))
|
|
||||||
# Add it the header
|
|
||||||
self.header[GraphIO.Keys.NodesVersions] = nodesVersions
|
|
||||||
self.header["template"] = template
|
|
||||||
|
|
||||||
data = {}
|
|
||||||
if template:
|
|
||||||
data = {
|
|
||||||
GraphIO.Keys.Header: self.header,
|
|
||||||
GraphIO.Keys.Graph: self.getNonDefaultInputAttributes()
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
data = {
|
|
||||||
GraphIO.Keys.Header: self.header,
|
|
||||||
GraphIO.Keys.Graph: self.toDict()
|
|
||||||
}
|
|
||||||
|
|
||||||
with open(path, 'w') as jsonFile:
|
with open(path, 'w') as jsonFile:
|
||||||
json.dump(data, jsonFile, indent=4)
|
json.dump(data, jsonFile, indent=4)
|
||||||
|
@ -1446,51 +1431,6 @@ class Graph(BaseObject):
|
||||||
# update the file date version
|
# update the file date version
|
||||||
self._fileDateVersion = os.path.getmtime(path)
|
self._fileDateVersion = os.path.getmtime(path)
|
||||||
|
|
||||||
def getNonDefaultInputAttributes(self):
|
|
||||||
"""
|
|
||||||
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
|
||||||
the attributes whose value is not the default one.
|
|
||||||
The output attributes, UIDs, parallelization parameters and internal folder are
|
|
||||||
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
dict: self.toDict() with the output attributes, UIDs, parallelization parameters, internal folder
|
|
||||||
and input/internal attributes with default values removed
|
|
||||||
"""
|
|
||||||
graph = self.toDict()
|
|
||||||
for nodeName in graph.keys():
|
|
||||||
node = self.node(nodeName)
|
|
||||||
|
|
||||||
inputKeys = list(graph[nodeName]["inputs"].keys())
|
|
||||||
|
|
||||||
internalInputKeys = []
|
|
||||||
internalInputs = graph[nodeName].get("internalInputs", None)
|
|
||||||
if internalInputs:
|
|
||||||
internalInputKeys = list(internalInputs.keys())
|
|
||||||
|
|
||||||
for attrName in inputKeys:
|
|
||||||
attribute = node.attribute(attrName)
|
|
||||||
# check that attribute is not a link for choice attributes
|
|
||||||
if attribute.isDefault and not attribute.isLink:
|
|
||||||
del graph[nodeName]["inputs"][attrName]
|
|
||||||
|
|
||||||
for attrName in internalInputKeys:
|
|
||||||
attribute = node.internalAttribute(attrName)
|
|
||||||
# check that internal attribute is not a link for choice attributes
|
|
||||||
if attribute.isDefault and not attribute.isLink:
|
|
||||||
del graph[nodeName]["internalInputs"][attrName]
|
|
||||||
|
|
||||||
# If all the internal attributes are set to their default values, remove the entry
|
|
||||||
if len(graph[nodeName]["internalInputs"]) == 0:
|
|
||||||
del graph[nodeName]["internalInputs"]
|
|
||||||
|
|
||||||
del graph[nodeName]["outputs"]
|
|
||||||
del graph[nodeName]["uid"]
|
|
||||||
del graph[nodeName]["internalFolder"]
|
|
||||||
del graph[nodeName]["parallelization"]
|
|
||||||
|
|
||||||
return graph
|
|
||||||
|
|
||||||
def _setFilepath(self, filepath):
|
def _setFilepath(self, filepath):
|
||||||
"""
|
"""
|
||||||
Set the internal filepath of this Graph.
|
Set the internal filepath of this Graph.
|
||||||
|
|
|
@ -1,7 +1,12 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Union
|
from typing import Any, TYPE_CHECKING, Union
|
||||||
|
|
||||||
|
import meshroom
|
||||||
from meshroom.core import Version
|
from meshroom.core import Version
|
||||||
|
from meshroom.core.node import Node
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from meshroom.core.graph import Graph
|
||||||
|
|
||||||
|
|
||||||
class GraphIO:
|
class GraphIO:
|
||||||
|
@ -54,3 +59,110 @@ class GraphIO:
|
||||||
|
|
||||||
return tuple(features)
|
return tuple(features)
|
||||||
|
|
||||||
|
|
||||||
|
class GraphSerializer:
|
||||||
|
"""Standard Graph serializer."""
|
||||||
|
|
||||||
|
def __init__(self, graph: "Graph") -> None:
|
||||||
|
self._graph = graph
|
||||||
|
|
||||||
|
def serialize(self) -> dict:
|
||||||
|
"""
|
||||||
|
Serialize the Graph.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
GraphIO.Keys.Header: self.serializeHeader(),
|
||||||
|
GraphIO.Keys.Graph: self.serializeContent(),
|
||||||
|
}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def nodes(self) -> list[Node]:
|
||||||
|
return self._graph.nodes
|
||||||
|
|
||||||
|
def serializeHeader(self) -> dict:
|
||||||
|
"""Build and return the graph serialization header.
|
||||||
|
|
||||||
|
The header contains metadata about the graph, such as the:
|
||||||
|
- version of the software used to create it.
|
||||||
|
- version of the file format.
|
||||||
|
- version of the nodes types used in the graph.
|
||||||
|
- template flag.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
nodes: (optional) The list of nodes to consider for node types versions - use all nodes if not specified.
|
||||||
|
template: Whether the graph is going to be serialized as a template.
|
||||||
|
"""
|
||||||
|
header: dict[str, Any] = {}
|
||||||
|
header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
|
||||||
|
header[GraphIO.Keys.FileVersion] = GraphIO.__version__
|
||||||
|
header[GraphIO.Keys.NodesVersions] = self._getNodeTypesVersions()
|
||||||
|
return header
|
||||||
|
|
||||||
|
def _getNodeTypesVersions(self) -> dict[str, str]:
|
||||||
|
"""Get registered versions of each node types in `nodes`, excluding CompatibilityNode instances."""
|
||||||
|
nodeTypes = set([node.nodeDesc.__class__ for node in self.nodes if isinstance(node, Node)])
|
||||||
|
nodeTypesVersions = {
|
||||||
|
nodeType.__name__: meshroom.core.nodeVersion(nodeType, "0.0") for nodeType in nodeTypes
|
||||||
|
}
|
||||||
|
# Sort them by name (to avoid random order changing from one save to another).
|
||||||
|
return dict(sorted(nodeTypesVersions.items()))
|
||||||
|
|
||||||
|
def serializeContent(self) -> dict:
|
||||||
|
"""Graph content serialization logic."""
|
||||||
|
return {node.name: self.serializeNode(node) for node in sorted(self.nodes, key=lambda n: n.name)}
|
||||||
|
|
||||||
|
def serializeNode(self, node: Node) -> dict:
|
||||||
|
"""Node serialization logic."""
|
||||||
|
return node.toDict()
|
||||||
|
|
||||||
|
|
||||||
|
class TemplateGraphSerializer(GraphSerializer):
|
||||||
|
"""Serializer for serializing a graph as a template."""
|
||||||
|
|
||||||
|
def serializeHeader(self) -> dict:
|
||||||
|
header = super().serializeHeader()
|
||||||
|
header["template"] = True
|
||||||
|
return header
|
||||||
|
|
||||||
|
def serializeNode(self, node: Node) -> dict:
|
||||||
|
"""Adapt node serialization to template graphs.
|
||||||
|
|
||||||
|
Instead of getting all the inputs and internal attribute keys, only get the keys of
|
||||||
|
the attributes whose value is not the default one.
|
||||||
|
The output attributes, UIDs, parallelization parameters and internal folder are
|
||||||
|
not relevant for templates, so they are explicitly removed from the returned dictionary.
|
||||||
|
"""
|
||||||
|
# For now, implemented as a post-process to update the default serialization.
|
||||||
|
nodeData = super().serializeNode(node)
|
||||||
|
|
||||||
|
inputKeys = list(nodeData["inputs"].keys())
|
||||||
|
|
||||||
|
internalInputKeys = []
|
||||||
|
internalInputs = nodeData.get("internalInputs", None)
|
||||||
|
if internalInputs:
|
||||||
|
internalInputKeys = list(internalInputs.keys())
|
||||||
|
|
||||||
|
for attrName in inputKeys:
|
||||||
|
attribute = node.attribute(attrName)
|
||||||
|
# check that attribute is not a link for choice attributes
|
||||||
|
if attribute.isDefault and not attribute.isLink:
|
||||||
|
del nodeData["inputs"][attrName]
|
||||||
|
|
||||||
|
for attrName in internalInputKeys:
|
||||||
|
attribute = node.internalAttribute(attrName)
|
||||||
|
# check that internal attribute is not a link for choice attributes
|
||||||
|
if attribute.isDefault and not attribute.isLink:
|
||||||
|
del nodeData["internalInputs"][attrName]
|
||||||
|
|
||||||
|
# If all the internal attributes are set to their default values, remove the entry
|
||||||
|
if len(nodeData["internalInputs"]) == 0:
|
||||||
|
del nodeData["internalInputs"]
|
||||||
|
|
||||||
|
del nodeData["outputs"]
|
||||||
|
del nodeData["uid"]
|
||||||
|
del nodeData["internalFolder"]
|
||||||
|
del nodeData["parallelization"]
|
||||||
|
|
||||||
|
return nodeData
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Reference in a new issue