Merge pull request #2586 from alicevision/fix/attributeValueChanged

Fix attribute value change propagation and callback handling
This commit is contained in:
Candice Bentéjac 2024-10-30 16:09:43 +00:00 committed by GitHub
commit 35914bdb0e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 495 additions and 47 deletions

View file

@ -25,11 +25,14 @@ def attributeFactory(description, value, isOutput, node, root=None, parent=None)
root: (optional) parent Attribute (must be ListAttribute or GroupAttribute) root: (optional) parent Attribute (must be ListAttribute or GroupAttribute)
parent (BaseObject): (optional) the parent BaseObject if any parent (BaseObject): (optional) the parent BaseObject if any
""" """
attr = description.instanceType(node, description, isOutput, root, parent) attr: Attribute = description.instanceType(node, description, isOutput, root, parent)
if value is not None: if value is not None:
attr._set_value(value, emitSignals=False) attr._set_value(value)
else: else:
attr.resetToDefaultValue(emitSignals=False) attr.resetToDefaultValue()
attr.valueChanged.connect(lambda attr=attr: node._onAttributeChanged(attr))
return attr return attr
@ -67,7 +70,6 @@ class Attribute(BaseObject):
self._value = None self._value = None
self.initValue() self.initValue()
self.valueChanged.connect(self.onChanged)
@property @property
def node(self): def node(self):
@ -180,22 +182,7 @@ class Attribute(BaseObject):
return self.getLinkParam().value return self.getLinkParam().value
return self._value return self._value
def onChanged(self): def _set_value(self, value):
""" Called when the attribute value has changed """
if self.node.isCompatibilityNode:
# We have no access to the node's implementation,
# so we cannot call the custom method.
return
if self.isOutput and not self.node.isInputNode:
# Ignore changes on output attributes for non-input nodes
# as they are updated during the node's computation.
# And we do not want notifications during the graph processing.
return
# notify the node that the attribute has changed
# this will call the node descriptor "onAttrNameChanged" method
self.node.onAttributeChanged(self)
def _set_value(self, value, emitSignals=True):
if self._value == value: if self._value == value:
return return
@ -211,9 +198,6 @@ class Attribute(BaseObject):
convertedValue = self.validateValue(value) convertedValue = self.validateValue(value)
self._value = convertedValue self._value = convertedValue
if not emitSignals:
return
# Request graph update when input parameter value is set # Request graph update when input parameter value is set
# and parent node belongs to a graph # and parent node belongs to a graph
# Output attributes value are set internally during the update process, # Output attributes value are set internally during the update process,
@ -251,8 +235,8 @@ class Attribute(BaseObject):
if self.desc._valueType is not None: if self.desc._valueType is not None:
self._value = self.desc._valueType() self._value = self.desc._valueType()
def resetToDefaultValue(self, emitSignals=True): def resetToDefaultValue(self):
self._set_value(copy.copy(self.defaultValue()), emitSignals=emitSignals) self._set_value(copy.copy(self.defaultValue()))
def requestGraphUpdate(self): def requestGraphUpdate(self):
if self.node.graph: if self.node.graph:
@ -538,14 +522,13 @@ class ListAttribute(Attribute):
return self._value.indexOf(item) return self._value.indexOf(item)
def initValue(self): def initValue(self):
self.resetToDefaultValue(emitSignals=False) self.resetToDefaultValue()
def resetToDefaultValue(self, emitSignals=True): def resetToDefaultValue(self):
self._value = ListModel(parent=self) self._value = ListModel(parent=self)
if emitSignals: self.valueChanged.emit()
self.valueChanged.emit()
def _set_value(self, value, emitSignals=True): def _set_value(self, value):
if self.node.graph: if self.node.graph:
self.remove(0, len(self)) self.remove(0, len(self))
# Link to another attribute # Link to another attribute
@ -558,8 +541,6 @@ class ListAttribute(Attribute):
self._value = ListModel(parent=self) self._value = ListModel(parent=self)
newValue = self.desc.validateValue(value) newValue = self.desc.validateValue(value)
self.extend(newValue) self.extend(newValue)
if not emitSignals:
return
self.requestGraphUpdate() self.requestGraphUpdate()
def upgradeValue(self, exportedValues): def upgradeValue(self, exportedValues):
@ -696,7 +677,7 @@ class GroupAttribute(Attribute):
except KeyError: except KeyError:
raise AttributeError(key) raise AttributeError(key)
def _set_value(self, exportedValue, emitSignals=True): def _set_value(self, exportedValue):
value = self.validateValue(exportedValue) value = self.validateValue(exportedValue)
if isinstance(value, dict): if isinstance(value, dict):
# set individual child attribute values # set individual child attribute values
@ -734,7 +715,7 @@ class GroupAttribute(Attribute):
childAttr.valueChanged.connect(self.valueChanged) childAttr.valueChanged.connect(self.valueChanged)
self._value.reset(subAttributes) self._value.reset(subAttributes)
def resetToDefaultValue(self, emitSignals=True): def resetToDefaultValue(self):
for attrDesc in self.desc._groupDesc: for attrDesc in self.desc._groupDesc:
self._value.get(attrDesc.name).resetToDefaultValue() self._value.get(attrDesc.name).resetToDefaultValue()

View file

@ -14,6 +14,7 @@ import types
import uuid import uuid
from collections import namedtuple from collections import namedtuple
from enum import Enum from enum import Enum
from typing import Callable, Optional
import meshroom import meshroom
from meshroom.common import Signal, Variant, Property, BaseObject, Slot, ListModel, DictModel from meshroom.common import Signal, Variant, Property, BaseObject, Slot, ListModel, DictModel
@ -929,25 +930,50 @@ class BaseNode(BaseObject):
def _updateChunks(self): def _updateChunks(self):
pass pass
def onAttributeChanged(self, attr): def _getAttributeChangedCallback(self, attr: Attribute) -> Optional[Callable]:
""" When an attribute changed, a specific function can be defined in the descriptor and be called. """Get the node descriptor-defined value changed callback associated to `attr` if any."""
# Callbacks cannot be defined on nested attributes.
if attr.root is not None:
return None
attrCapitalizedName = attr.name[:1].upper() + attr.name[1:]
callbackName = f"on{attrCapitalizedName}Changed"
callback = getattr(self.nodeDesc, callbackName, None)
return callback if callback and callable(callback) else None
def _onAttributeChanged(self, attr: Attribute):
"""
When an attribute value has changed, a specific function can be defined in the descriptor and be called.
Args: Args:
attr (Attribute): attribute that has changed attr: The Attribute that has changed.
""" """
# Call the specific function if it exists in the node implementation
paramName = attr.name[:1].upper() + attr.name[1:] if self.isCompatibilityNode:
methodName = f'on{paramName}Changed' # Compatibility nodes are not meant to be updated.
if hasattr(self.nodeDesc, methodName): return
m = getattr(self.nodeDesc, methodName)
if callable(m): if attr.isOutput and not self.isInputNode:
m(self) # Ignore changes on output attributes for non-input nodes
# as they are updated during the node's computation.
# And we do not want notifications during the graph processing.
return
if attr.value is None:
# Discard dynamic values depending on the graph processing.
return
callback = self._getAttributeChangedCallback(attr)
if callback:
callback(self)
if self.graph: if self.graph:
# If we are in a graph, propagate the notification to the connected output attributes # If we are in a graph, propagate the notification to the connected output attributes
outEdges = self.graph.outEdges(attr) for edge in self.graph.outEdges(attr):
for edge in outEdges: edge.dst.node._onAttributeChanged(edge.dst)
edge.dst.onChanged()
def onAttributeClicked(self, attr): def onAttributeClicked(self, attr):
""" When an attribute is clicked, a specific function can be defined in the descriptor and be called. """ When an attribute is clicked, a specific function can be defined in the descriptor and be called.

32
tests/conftest.py Normal file
View file

@ -0,0 +1,32 @@
from pathlib import Path
import tempfile
import pytest
from meshroom.core.graph import Graph
@pytest.fixture
def graphWithIsolatedCache():
"""
Yield a Graph instance using a unique temporary cache directory.
Can be used for testing graph computation in isolation, without having to save the graph to disk.
"""
with tempfile.TemporaryDirectory() as cacheDir:
graph = Graph("")
graph.cacheDir = cacheDir
yield graph
@pytest.fixture
def graphSavedOnDisk():
"""
Yield a Graph instance saved in a unique temporary folder.
Can be used for testing graph IO and computation in isolation.
"""
with tempfile.TemporaryDirectory() as cacheDir:
graph = Graph("")
graph.save(Path(cacheDir) / "test_graph.mg")
yield graph

View file

@ -0,0 +1,409 @@
# coding:utf-8
from meshroom.core.graph import Graph, loadGraph, executeGraph
from meshroom.core import desc, registerNodeType, unregisterNodeType
from meshroom.core.node import Node
class NodeWithAttributeChangedCallback(desc.Node):
"""
A Node containing an input Attribute with an 'on{Attribute}Changed' method,
called whenever the value of this attribute is changed explicitly.
"""
inputs = [
desc.IntParam(
name="input",
label="Input",
description="Attribute with a value changed callback (onInputChanged)",
value=0,
range=None,
),
desc.IntParam(
name="affectedInput",
label="Affected Input",
description="Updated to input.value * 2 whenever 'input' is explicitly modified",
value=0,
range=None,
),
]
def onInputChanged(self, instance: Node):
instance.affectedInput.value = instance.input.value * 2
def processChunk(self, chunk):
pass # No-op.
class TestNodeWithAttributeChangedCallback:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
def test_assignValueTriggersCallback(self):
node = Node(NodeWithAttributeChangedCallback.__name__)
assert node.affectedInput.value == 0
node.input.value = 10
assert node.affectedInput.value == 20
def test_specifyDefaultValueDoesNotTriggerCallback(self):
node = Node(NodeWithAttributeChangedCallback.__name__, input=10)
assert node.affectedInput.value == 0
def test_assignDefaultValueDoesNotTriggerCallback(self):
node = Node(NodeWithAttributeChangedCallback.__name__, input=10)
node.input.value = 10
assert node.affectedInput.value == 0
def test_assignNonDefaultValueTriggersCallback(self):
node = Node(NodeWithAttributeChangedCallback.__name__, input=10)
node.input.value = 2
assert node.affectedInput.value == 4
class TestAttributeCallbackTriggerInGraph:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
def test_connectionTriggersCallback(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
assert nodeA.affectedInput.value == nodeB.affectedInput.value == 0
nodeA.input.value = 1
graph.addEdge(nodeA.input, nodeB.input)
assert nodeA.affectedInput.value == nodeB.affectedInput.value == 2
def test_connectedValueChangeTriggersCallback(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
assert nodeA.affectedInput.value == nodeB.affectedInput.value == 0
graph.addEdge(nodeA.input, nodeB.input)
nodeA.input.value = 1
assert nodeA.affectedInput.value == 2
assert nodeB.affectedInput.value == 2
def test_defaultValueOnlyTriggersCallbackDownstream(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__, input=1)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
assert nodeA.affectedInput.value == 0
assert nodeB.affectedInput.value == 0
graph.addEdge(nodeA.input, nodeB.input)
assert nodeA.affectedInput.value == 0
assert nodeB.affectedInput.value == 2
def test_valueChangeIsPropagatedAlongNodeChain(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeC = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeD = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdges(
(nodeA.affectedInput, nodeB.input),
(nodeB.affectedInput, nodeC.input),
(nodeC.affectedInput, nodeD.input),
)
nodeA.input.value = 5
assert nodeA.affectedInput.value == nodeB.input.value == 10
assert nodeB.affectedInput.value == nodeC.input.value == 20
assert nodeC.affectedInput.value == nodeD.input.value == 40
assert nodeD.affectedInput.value == 80
def test_disconnectionTriggersCallback(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.input, nodeB.input)
nodeA.input.value = 5
assert nodeB.affectedInput.value == 10
graph.removeEdge(nodeB.input)
assert nodeB.input.value == 0
assert nodeB.affectedInput.value == 0
def test_loadingGraphDoesNotTriggerCallback(self, graphSavedOnDisk):
graph: Graph = graphSavedOnDisk
node = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
node.input.value = 5
node.affectedInput.value = 2
graph.save()
loadedGraph = loadGraph(graph.filepath)
loadedNode = loadedGraph.node(node.name)
assert loadedNode
assert loadedNode.affectedInput.value == 2
def test_loadingGraphDoesNotTriggerCallbackForConnectedAttributes(
self, graphSavedOnDisk
):
graph: Graph = graphSavedOnDisk
nodeA = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.input, nodeB.input)
nodeA.input.value = 5
nodeB.affectedInput.value = 2
graph.save()
loadedGraph = loadGraph(graph.filepath)
loadedNodeB = loadedGraph.node(nodeB.name)
assert loadedNodeB
assert loadedNodeB.affectedInput.value == 2
class NodeWithCompoundAttributes(desc.Node):
"""
A Node containing a variation of compound attributes (List/Groups),
called whenever the value of this attribute is changed explicitly.
"""
inputs = [
desc.ListAttribute(
name="listInput",
label="List Input",
description="ListAttribute of IntParams.",
elementDesc=desc.IntParam(
name="int", label="Int", description="", value=0, range=None
),
),
desc.GroupAttribute(
name="groupInput",
label="Group Input",
description="GroupAttribute with a single 'IntParam' element.",
groupDesc=[
desc.IntParam(
name="int", label="Int", description="", value=0, range=None
)
],
),
desc.ListAttribute(
name="listOfGroupsInput",
label="List of Groups input",
description="ListAttribute of GroupAttribute with a single 'IntParam' element.",
elementDesc=desc.GroupAttribute(
name="subGroup",
label="SubGroup",
description="",
groupDesc=[
desc.IntParam(
name="int", label="Int", description="", value=0, range=None
)
],
)
),
desc.GroupAttribute(
name="groupWithListInput",
label="Group with List",
description="GroupAttribute with a single 'ListAttribute of IntParam' element.",
groupDesc=[
desc.ListAttribute(
name="subList",
label="SubList",
description="",
elementDesc=desc.IntParam(
name="int", label="Int", description="", value=0, range=None
)
)
]
)
]
class TestAttributeCallbackBehaviorWithUpstreamCompoundAttributes:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
registerNodeType(NodeWithCompoundAttributes)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
unregisterNodeType(NodeWithCompoundAttributes)
def test_connectionToListElement(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithCompoundAttributes.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.listInput.append(0)
attr = nodeA.listInput.at(0)
graph.addEdge(attr, nodeB.input)
attr.value = 10
assert nodeB.input.value == 10
assert nodeB.affectedInput.value == 20
def test_connectionToGroupElement(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithCompoundAttributes.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.groupInput.int, nodeB.input)
nodeA.groupInput.int.value = 10
assert nodeB.input.value == 10
assert nodeB.affectedInput.value == 20
def test_connectionToGroupElementInList(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithCompoundAttributes.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.listOfGroupsInput.append({})
attr = nodeA.listOfGroupsInput.at(0)
graph.addEdge(attr.int, nodeB.input)
attr.int.value = 10
assert nodeB.input.value == 10
assert nodeB.affectedInput.value == 20
def test_connectionToListElementInGroup(self):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithCompoundAttributes.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.groupWithListInput.subList.append(0)
attr = nodeA.groupWithListInput.subList.at(0)
graph.addEdge(attr, nodeB.input)
attr.value = 10
assert nodeB.input.value == 10
assert nodeB.affectedInput.value == 20
class NodeWithDynamicOutputValue(desc.Node):
"""
A Node containing an output attribute which value is computed dynamically during graph execution.
"""
inputs = [
desc.IntParam(
name="input",
label="Input",
description="Input used in the computation of 'output'",
value=0,
),
]
outputs = [
desc.IntParam(
name="output",
label="Output",
description="Dynamically computed output (input * 2)",
# Setting value to None makes the attribute dynamic.
value=None,
),
]
def processChunk(self, chunk):
chunk.node.output.value = chunk.node.input.value * 2
class TestAttributeCallbackBehaviorWithUpstreamDynamicOutputs:
@classmethod
def setup_class(cls):
registerNodeType(NodeWithAttributeChangedCallback)
registerNodeType(NodeWithDynamicOutputValue)
@classmethod
def teardown_class(cls):
unregisterNodeType(NodeWithAttributeChangedCallback)
unregisterNodeType(NodeWithDynamicOutputValue)
def test_connectingUncomputedDynamicOutputDoesNotTriggerDownstreamAttributeChangedCallback(
self,
):
graph = Graph("")
nodeA = graph.addNewNode(NodeWithDynamicOutputValue.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.input.value = 10
graph.addEdge(nodeA.output, nodeB.input)
assert nodeB.affectedInput.value == 0
def test_connectingComputedDynamicOutputTriggersDownstreamAttributeChangedCallback(
self, graphWithIsolatedCache
):
graph: Graph = graphWithIsolatedCache
nodeA = graph.addNewNode(NodeWithDynamicOutputValue.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.input.value = 10
executeGraph(graph)
graph.addEdge(nodeA.output, nodeB.input)
assert nodeA.output.value == nodeB.input.value == 20
assert nodeB.affectedInput.value == 40
def test_dynamicOutputValueComputeDoesNotTriggerDownstreamAttributeChangedCallback(
self, graphWithIsolatedCache
):
graph: Graph = graphWithIsolatedCache
nodeA = graph.addNewNode(NodeWithDynamicOutputValue.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
graph.addEdge(nodeA.output, nodeB.input)
nodeA.input.value = 10
executeGraph(graph)
assert nodeB.input.value == 20
assert nodeB.affectedInput.value == 0
def test_clearingDynamicOutputValueDoesNotTriggerDownstreamAttributeChangedCallback(
self, graphWithIsolatedCache
):
graph: Graph = graphWithIsolatedCache
nodeA = graph.addNewNode(NodeWithDynamicOutputValue.__name__)
nodeB = graph.addNewNode(NodeWithAttributeChangedCallback.__name__)
nodeA.input.value = 10
executeGraph(graph)
graph.addEdge(nodeA.output, nodeB.input)
expectedPreClearValue = nodeA.input.value * 2 * 2
assert nodeB.affectedInput.value == expectedPreClearValue
nodeA.clearData()
assert nodeA.output.value == nodeB.input.value is None
assert nodeB.affectedInput.value == expectedPreClearValue