[core] Introducing new graphIO module

Move Graph.IO internal class to its own module, and rename it to `GraphIO`.
This avoid nested classes within the core Graph class, and starts decoupling
the management of graph's IO from the logic of the graph itself.
This commit is contained in:
Yann Lanthony 2025-02-06 16:46:04 +01:00
parent 3064cb9b35
commit a665200c38
4 changed files with 76 additions and 63 deletions

View file

@ -17,6 +17,7 @@ from meshroom.common import BaseObject, DictModel, Slot, Signal, Property
from meshroom.core import Version
from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute
from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit
from meshroom.core.graphIO import GraphIO
from meshroom.core.node import Status, Node, CompatibilityNode
from meshroom.core.nodeFactory import nodeFactory
from meshroom.core.typing import PathLike
@ -183,52 +184,6 @@ class Graph(BaseObject):
"""
_cacheDir = ""
class IO(object):
""" Centralize Graph file keys and IO version. """
__version__ = "2.0"
class Keys(object):
""" File Keys. """
# Doesn't inherit enum to simplify usage (Graph.IO.Keys.XX, without .value)
Header = "header"
NodesVersions = "nodesVersions"
ReleaseVersion = "releaseVersion"
FileVersion = "fileVersion"
Graph = "graph"
class Features(Enum):
""" File Features. """
Graph = "graph"
Header = "header"
NodesVersions = "nodesVersions"
PrecomputedOutputs = "precomputedOutputs"
NodesPositions = "nodesPositions"
@staticmethod
def getFeaturesForVersion(fileVersion):
""" Return the list of supported features based on a file version.
Args:
fileVersion (str, Version): the file version
Returns:
tuple of Graph.IO.Features: the list of supported features
"""
if isinstance(fileVersion, str):
fileVersion = Version(fileVersion)
features = [Graph.IO.Features.Graph]
if fileVersion >= Version("1.0"):
features += [Graph.IO.Features.Header,
Graph.IO.Features.NodesVersions,
Graph.IO.Features.PrecomputedOutputs,
]
if fileVersion >= Version("1.1"):
features += [Graph.IO.Features.NodesPositions]
return tuple(features)
def __init__(self, name, parent=None):
super(Graph, self).__init__(parent)
self.name = name
@ -265,7 +220,7 @@ class Graph(BaseObject):
@property
def fileFeatures(self):
""" Get loaded file supported features based on its version. """
return Graph.IO.getFeaturesForVersion(self.header.get(Graph.IO.Keys.FileVersion, "0.0"))
return GraphIO.getFeaturesForVersion(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
@property
def isLoading(self):
@ -321,8 +276,8 @@ class Graph(BaseObject):
graphData: The serialized Graph.
"""
self.clear()
self.header = graphData.get(Graph.IO.Keys.Header, {})
fileVersion = Version(self.header.get(Graph.IO.Keys.FileVersion, "0.0"))
self.header = graphData.get(GraphIO.Keys.Header, {})
fileVersion = Version(self.header.get(GraphIO.Keys.FileVersion, "0.0"))
graphContent = self._normalizeGraphContent(graphData, fileVersion)
isTemplate = self.header.get("template", False)
@ -355,7 +310,7 @@ class Graph(BaseObject):
logging.warning(e)
def _normalizeGraphContent(self, graphData: dict, fileVersion: Version) -> dict:
graphContent = graphData.get(Graph.IO.Keys.Graph, graphData)
graphContent = graphData.get(GraphIO.Keys.Graph, graphData)
if fileVersion < Version("2.0"):
# For internal folders, all "{uid0}" keys should be replaced with "{uid}"
@ -388,7 +343,7 @@ class Graph(BaseObject):
return node
def _getNodeTypeVersionFromHeader(self, nodeType: str, default: Optional[str] = None) -> Optional[str]:
nodeVersions = self.header.get(Graph.IO.Keys.NodesVersions, {})
nodeVersions = self.header.get(GraphIO.Keys.NodesVersions, {})
return nodeVersions.get(nodeType, default)
def _evaluateUidConflicts(self, data):
@ -1453,8 +1408,8 @@ class Graph(BaseObject):
if not path:
raise ValueError("filepath must be specified for unsaved files.")
self.header[Graph.IO.Keys.ReleaseVersion] = meshroom.__version__
self.header[Graph.IO.Keys.FileVersion] = Graph.IO.__version__
self.header[GraphIO.Keys.ReleaseVersion] = meshroom.__version__
self.header[GraphIO.Keys.FileVersion] = GraphIO.__version__
# Store versions of node types present in the graph (excluding CompatibilityNode instances)
# and remove duplicates
@ -1467,19 +1422,19 @@ class Graph(BaseObject):
# Sort them by name (to avoid random order changing from one save to another)
nodesVersions = dict(sorted(nodesVersions.items()))
# Add it the header
self.header[Graph.IO.Keys.NodesVersions] = nodesVersions
self.header[GraphIO.Keys.NodesVersions] = nodesVersions
self.header["template"] = template
data = {}
if template:
data = {
Graph.IO.Keys.Header: self.header,
Graph.IO.Keys.Graph: self.getNonDefaultInputAttributes()
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.getNonDefaultInputAttributes()
}
else:
data = {
Graph.IO.Keys.Header: self.header,
Graph.IO.Keys.Graph: self.toDict()
GraphIO.Keys.Header: self.header,
GraphIO.Keys.Graph: self.toDict()
}
with open(path, 'w') as jsonFile:
@ -1734,7 +1689,7 @@ class Graph(BaseObject):
filepathChanged = Signal()
filepath = Property(str, lambda self: self._filepath, notify=filepathChanged)
isSaving = Property(bool, isSaving.fget, constant=True)
fileReleaseVersion = Property(str, lambda self: self.header.get(Graph.IO.Keys.ReleaseVersion, "0.0"),
fileReleaseVersion = Property(str, lambda self: self.header.get(GraphIO.Keys.ReleaseVersion, "0.0"),
notify=filepathChanged)
fileDateVersion = Property(float, fileDateVersion.fget, fileDateVersion.fset, notify=filepathChanged)
cacheDirChanged = Signal()

56
meshroom/core/graphIO.py Normal file
View file

@ -0,0 +1,56 @@
from enum import Enum
from typing import Union
from meshroom.core import Version
class GraphIO:
"""Centralize Graph file keys and IO version."""
__version__ = "2.0"
class Keys(object):
"""File Keys."""
# Doesn't inherit enum to simplify usage (GraphIO.Keys.XX, without .value)
Header = "header"
NodesVersions = "nodesVersions"
ReleaseVersion = "releaseVersion"
FileVersion = "fileVersion"
Graph = "graph"
class Features(Enum):
"""File Features."""
Graph = "graph"
Header = "header"
NodesVersions = "nodesVersions"
PrecomputedOutputs = "precomputedOutputs"
NodesPositions = "nodesPositions"
@staticmethod
def getFeaturesForVersion(fileVersion: Union[str, Version]) -> tuple["GraphIO.Features",...]:
"""Return the list of supported features based on a file version.
Args:
fileVersion (str, Version): the file version
Returns:
tuple of GraphIO.Features: the list of supported features
"""
if isinstance(fileVersion, str):
fileVersion = Version(fileVersion)
features = [GraphIO.Features.Graph]
if fileVersion >= Version("1.0"):
features += [
GraphIO.Features.Header,
GraphIO.Features.NodesVersions,
GraphIO.Features.PrecomputedOutputs,
]
if fileVersion >= Version("1.1"):
features += [GraphIO.Features.NodesPositions]
return tuple(features)

View file

@ -25,6 +25,7 @@ from meshroom.core import sessionUid
from meshroom.common.qt import QObjectListModel
from meshroom.core.attribute import Attribute, ListAttribute
from meshroom.core.graph import Graph, Edge
from meshroom.core.graphIO import GraphIO
from meshroom.core.taskManager import TaskManager
@ -396,7 +397,7 @@ class UIGraph(QObject):
self.updateChunks()
# perform auto-layout if graph does not provide nodes positions
if Graph.IO.Features.NodesPositions not in self._graph.fileFeatures:
if GraphIO.Features.NodesPositions not in self._graph.fileFeatures:
self._layout.reset()
# clear undo-stack after layout
self._undoStack.clear()

View file

@ -4,6 +4,7 @@
from meshroom.core.graph import Graph
from meshroom.core import pipelineTemplates, Version
from meshroom.core.node import CompatibilityIssue, CompatibilityNode
from meshroom.core.graphIO import GraphIO
import json
import meshroom
@ -24,13 +25,13 @@ def test_templateVersions():
with open(path) as jsonFile:
fileData = json.load(jsonFile)
graphData = fileData.get(Graph.IO.Keys.Graph, fileData)
graphData = fileData.get(GraphIO.Keys.Graph, fileData)
assert isinstance(graphData, dict)
header = fileData.get(Graph.IO.Keys.Header, {})
header = fileData.get(GraphIO.Keys.Header, {})
assert header.get("template", False)
nodesVersions = header.get(Graph.IO.Keys.NodesVersions, {})
nodesVersions = header.get(GraphIO.Keys.NodesVersions, {})
for _, nodeData in graphData.items():
nodeType = nodeData["nodeType"]