diff --git a/meshroom/core/attribute.py b/meshroom/core/attribute.py index 5274d3cf..34f8a266 100644 --- a/meshroom/core/attribute.py +++ b/meshroom/core/attribute.py @@ -3,6 +3,7 @@ import copy import os import re +from typing import Optional import weakref import types import logging @@ -11,6 +12,7 @@ from collections.abc import Iterable, Sequence from string import Template from meshroom.common import BaseObject, Property, Variant, Signal, ListModel, DictModel, Slot from meshroom.core import desc, hashValue +from meshroom.core.exception import InvalidEdgeError def attributeFactory(description, value, isOutput, node, root=None, parent=None): @@ -72,6 +74,7 @@ class Attribute(BaseObject): # invalidation value for output attributes self._invalidationValue = "" + self._linkExpression: Optional[str] = None self._value = None self.initValue() @@ -191,9 +194,9 @@ class Attribute(BaseObject): if self._value == value: return - if isinstance(value, Attribute) or Attribute.isLinkExpression(value): - # if we set a link to another attribute - self._value = value + if self._handleLinkValue(value): + return + elif isinstance(value, types.FunctionType): # evaluate the function self._value = value(self) @@ -218,6 +221,27 @@ class Attribute(BaseObject): self.valueChanged.emit() self.validValueChanged.emit() + def _handleLinkValue(self, value) -> bool: + """ + Handle assignment of a link if `value` is a serialized link expression or in-memory Attribute reference. + + Returns: Whether the value has been handled as a link, False otherwise. + """ + isAttribute = isinstance(value, Attribute) + isLinkExpression = Attribute.isLinkExpression(value) + + if not isAttribute and not isLinkExpression: + return False + + if isAttribute: + self._linkExpression = value.asLinkExpr() + # If the value is a direct reference to an attribute, it can be directly converted to an edge as + # the source attribute already exists in memory. + self._applyExpr() + elif isLinkExpression: + self._linkExpression = value + return True + @Slot() def _onValueChanged(self): self.node._onAttributeChanged(self) @@ -329,26 +353,30 @@ class Attribute(BaseObject): this function convert the expression into a real edge in the graph and clear the string value. """ - v = self._value - g = self.node.graph - if not g: + if not self.isInput or not self._linkExpression: return - if isinstance(v, Attribute): - g.addEdge(v, self) - self.resetToDefaultValue() - elif self.isInput and Attribute.isLinkExpression(v): - # value is a link to another attribute - link = v[1:-1] - linkNodeName, linkAttrName = link.split('.') - try: - node = g.node(linkNodeName) - if not node: - raise KeyError(f"Node '{linkNodeName}' not found") - g.addEdge(node.attribute(linkAttrName), self) - except KeyError as err: - logging.warning('Connect Attribute from Expression failed.') - logging.warning('Expression: "{exp}"\nError: "{err}".'.format(exp=v, err=err)) - self.resetToDefaultValue() + + if not (graph := self.node.graph): + return + + link = self._linkExpression[1:-1] + linkNodeName, linkAttrName = link.split(".") + try: + node = graph.node(linkNodeName) + if node is None: + raise InvalidEdgeError(self.fullNameToNode, link, "Source node does not exist") + attr = node.attribute(linkAttrName) + if attr is None: + raise InvalidEdgeError(self.fullNameToNode, link, "Source attribute does not exist") + graph.addEdge(attr, self) + except InvalidEdgeError as err: + logging.warning(err) + except Exception as err: + logging.warning("Unexpected error happened during edge creation") + logging.warning(f"Expression '{self._linkExpression}': {err}") + + self._linkExpression = None + self.resetToDefaultValue() def getExportValue(self): if self.isLink: @@ -543,9 +571,8 @@ class ListAttribute(Attribute): def _set_value(self, value): if self.node.graph: self.remove(0, len(self)) - # Link to another attribute - if isinstance(value, ListAttribute) or Attribute.isLinkExpression(value): - self._value = value + if self._handleLinkValue(value): + return # New value else: # During initialization self._value may not be set @@ -556,10 +583,10 @@ class ListAttribute(Attribute): self.requestGraphUpdate() def upgradeValue(self, exportedValues): + if self._handleLinkValue(exportedValues): + return + if not isinstance(exportedValues, list): - if isinstance(exportedValues, ListAttribute) or Attribute.isLinkExpression(exportedValues): - self._set_value(exportedValues) - return raise RuntimeError("ListAttribute.upgradeValue: the given value is of type " + str(type(exportedValues)) + " but a 'list' is expected.") @@ -620,10 +647,8 @@ class ListAttribute(Attribute): return super(ListAttribute, self).uid() def _applyExpr(self): - if not self.node.graph: - return - if isinstance(self._value, ListAttribute) or Attribute.isLinkExpression(self._value): - super(ListAttribute, self)._applyExpr() + if self._linkExpression: + super()._applyExpr() else: for value in self._value: value._applyExpr() diff --git a/meshroom/core/exception.py b/meshroom/core/exception.py index 2365f8e2..53cba7d5 100644 --- a/meshroom/core/exception.py +++ b/meshroom/core/exception.py @@ -12,6 +12,13 @@ class GraphException(MeshroomException): pass +class InvalidEdgeError(GraphException): + """Raised when an edge between two attributes cannot be created.""" + + def __init__(self, srcAttrName: str, dstAttrName: str, msg: str) -> None: + super().__init__(f"Failed to connect {srcAttrName}->{dstAttrName}: {msg}") + + class GraphCompatibilityError(GraphException): """ Raised when node compatibility issues occur when loading a graph. diff --git a/meshroom/core/graph.py b/meshroom/core/graph.py index ce39d0f3..dd6b94d8 100644 --- a/meshroom/core/graph.py +++ b/meshroom/core/graph.py @@ -16,7 +16,7 @@ import meshroom.core from meshroom.common import BaseObject, DictModel, Slot, Signal, Property from meshroom.core import Version from meshroom.core.attribute import Attribute, ListAttribute, GroupAttribute -from meshroom.core.exception import GraphCompatibilityError, StopGraphVisit, StopBranchVisit +from meshroom.core.exception import GraphCompatibilityError, InvalidEdgeError, StopGraphVisit, StopBranchVisit from meshroom.core.graphIO import GraphIO, GraphSerializer, TemplateGraphSerializer, PartialGraphSerializer from meshroom.core.node import BaseNode, Status, Node, CompatibilityNode from meshroom.core.nodeFactory import nodeFactory @@ -485,41 +485,38 @@ class Graph(BaseObject): node._applyExpr() return node - def copyNode(self, srcNode, withEdges=False): + def copyNode(self, srcNode: Node, withEdges: bool=False): """ Get a copy instance of a node outside the graph. Args: - srcNode (Node): the node to copy - withEdges (bool): whether to copy edges + srcNode: the node to copy + withEdges: whether to copy edges Returns: - Node, dict: the created node instance, - a dictionary of linked attributes with their original value (empty if withEdges is True) + The created node instance and the mapping of skipped edge per attribute (always empty if `withEdges` is True). """ + def _removeLinkExpressions(attribute: Attribute, removed: dict[Attribute, str]): + """Recursively remove link expressions from the given root `attribute`.""" + # Link expressions are only stored on input attributes. + if attribute.isOutput: + return + + if attribute._linkExpression: + removed[attribute] = attribute._linkExpression + attribute._linkExpression = None + elif isinstance(attribute, (ListAttribute, GroupAttribute)): + for child in attribute.value: + _removeLinkExpressions(child, removed) + with GraphModification(self): - # create a new node of the same type and with the same attributes values - # keep links as-is so that CompatibilityNodes attributes can be created with correct automatic description - # (File params for link expressions) - node = nodeFactory(srcNode.toDict(), srcNode.nodeType) # use nodeType as name - # skip edges: filter out attributes which are links by resetting default values + node = nodeFactory(srcNode.toDict(), name=srcNode.nodeType) + skippedEdges = {} if not withEdges: - for n, attr in node.attributes.items(): - if attr.isOutput: - # edges are declared in input with an expression linking - # to another param (which could be an output) - continue - # find top-level links - if Attribute.isLinkExpression(attr.value): - skippedEdges[attr] = attr.value - attr.resetToDefaultValue() - # find links in ListAttribute children - elif isinstance(attr, (ListAttribute, GroupAttribute)): - for child in attr.value: - if Attribute.isLinkExpression(child.value): - skippedEdges[child] = child.value - child.resetToDefaultValue() + for _, attr in node.attributes.items(): + _removeLinkExpressions(attr, skippedEdges) + return node, skippedEdges def duplicateNodes(self, srcNodes): @@ -850,13 +847,16 @@ class Graph(BaseObject): return set(self._nodes) - nodesWithInputLink @changeTopology - def addEdge(self, srcAttr, dstAttr): - assert isinstance(srcAttr, Attribute) - assert isinstance(dstAttr, Attribute) - if srcAttr.node.graph != self or dstAttr.node.graph != self: - raise RuntimeError('The attributes of the edge should be part of a common graph.') + def addEdge(self, srcAttr: Attribute, dstAttr: Attribute): + if not (srcAttr.node.graph == dstAttr.node.graph == self): + raise InvalidEdgeError( + srcAttr.fullNameToGraph, dstAttr.fullNameToGraph, "Attributes do not belong to this Graph" + ) if dstAttr in self.edges.keys(): - raise RuntimeError('Destination attribute "{}" is already connected.'.format(dstAttr.getFullNameToNode())) + raise InvalidEdgeError( + srcAttr.fullNameToNode, dstAttr.fullNameToNode, "Destination is already connected" + ) + edge = Edge(srcAttr, dstAttr) self.edges.add(edge) self.markNodesDirty(dstAttr.node)