From 064d76de3de441ed74208f8b15f1779891a41b0b Mon Sep 17 00:00:00 2001 From: Guillaume Buisson Date: Fri, 15 Nov 2019 16:36:32 +0100 Subject: [PATCH] [attribute] refactor attribute/desc to customize command line arguments formatting --- meshroom/core/attribute.py | 19 +-------- meshroom/core/desc.py | 79 ++++++++++++++++++++++++++++---------- meshroom/core/node.py | 30 +++++++-------- 3 files changed, 74 insertions(+), 54 deletions(-) diff --git a/meshroom/core/attribute.py b/meshroom/core/attribute.py index 1bc105a2..496201c0 100644 --- a/meshroom/core/attribute.py +++ b/meshroom/core/attribute.py @@ -191,13 +191,8 @@ class Attribute(BaseObject): return self.desc.value return self._value - def getValueStr(self): - if isinstance(self.attributeDesc, desc.ChoiceParam) and not self.attributeDesc.exclusive: - assert(isinstance(self.value, collections.Sequence) and not isinstance(self.value, pyCompatibility.basestring)) - return self.attributeDesc.joinChar.join(self.value) - if isinstance(self.attributeDesc, (desc.StringParam, desc.File)): - return '"{}"'.format(self.value) - return str(self.value) + def format(self): + return self.desc.format(self.value) def defaultValue(self): return self.desc.value @@ -332,11 +327,6 @@ class ListAttribute(Attribute): else: return [attr.getPrimitiveValue(exportDefault=exportDefault) for attr in self._value if not attr.isDefault] - def getValueStr(self): - if isinstance(self.value, ListModel): - return self.attributeDesc.joinChar.join([v.getValueStr() for v in self.value]) - return super(ListAttribute, self).getValueStr() - # Override value property setter value = Property(Variant, Attribute._get_value, _set_value, notify=Attribute.valueChanged) isDefault = Property(bool, _isDefault, notify=Attribute.valueChanged) @@ -413,11 +403,6 @@ class GroupAttribute(Attribute): else: return {name: attr.getPrimitiveValue(exportDefault=exportDefault) for name, attr in self._value.items() if not attr.isDefault} - def getValueStr(self): - # sort values based on child attributes group description order - sortedSubValues = [self._value.get(attr.name).getValueStr() for attr in self.attributeDesc.groupDesc] - return self.attributeDesc.joinChar.join(sortedSubValues) - # Override value property value = Property(Variant, Attribute._get_value, _set_value, notify=Attribute.valueChanged) isDefault = Property(bool, _isDefault, notify=Attribute.valueChanged) diff --git a/meshroom/core/desc.py b/meshroom/core/desc.py index a9a96a10..9f44130a 100755 --- a/meshroom/core/desc.py +++ b/meshroom/core/desc.py @@ -1,4 +1,4 @@ -from meshroom.common import BaseObject, Property, Variant, VariantList +from meshroom.common import BaseObject, Property, Variant, VariantList, ListModel from meshroom.core import pyCompatibility from enum import Enum # available by default in python3. For python2: "pip install enum34" import collections @@ -11,7 +11,7 @@ class Attribute(BaseObject): """ """ - def __init__(self, name, label, description, value, advanced, uid, group): + def __init__(self, name, label, description, value, advanced, uid, group, formatter): super(Attribute, self).__init__() self._name = name self._label = label @@ -20,6 +20,7 @@ class Attribute(BaseObject): self._uid = uid self._group = group self._advanced = advanced + self._formatter = formatter or self._defaultFormatter name = Property(str, lambda self: self._name, constant=True) label = Property(str, lambda self: self._label, constant=True) @@ -45,17 +46,33 @@ class Attribute(BaseObject): except ValueError: return False return True + + def format(self, value): + """ Returns a list of (group, name, value) parameters """ + return self._formatter(self, value) + + @staticmethod + def _defaultFormatter(desc, value): + result_value = value + if isinstance(desc, ChoiceParam) and not desc.exclusive: + assert(isinstance(value, collections.Sequence) and not isinstance(value, pyCompatibility.basestring)) + result_value = desc.joinChar.join(value) + elif isinstance(desc, (StringParam, File)): + result_value = '"{}"'.format(value) + else: + result_value = str(value) + return ((desc.group, desc.name, result_value),) class ListAttribute(Attribute): """ A list of Attributes """ - def __init__(self, elementDesc, name, label, description, group='allParams', advanced=False, joinChar=' '): + def __init__(self, elementDesc, name, label, description, group='allParams', advanced=False, joinChar=' ', formatter=None): """ :param elementDesc: the Attribute description of elements to store in that list """ self._elementDesc = elementDesc self._joinChar = joinChar - super(ListAttribute, self).__init__(name=name, label=label, description=description, value=[], uid=(), group=group, advanced=advanced) + super(ListAttribute, self).__init__(name=name, label=label, description=description, value=[], uid=(), group=group, advanced=advanced, formatter=formatter) elementDesc = Property(Attribute, lambda self: self._elementDesc, constant=True) uid = Property(Variant, lambda self: self.elementDesc.uid, constant=True) @@ -75,16 +92,23 @@ class ListAttribute(Attribute): return self._elementDesc.matchDescription(value[0]) return True + @staticmethod + def _defaultFormatter(desc, value): + result_value = value + if isinstance(value, ListModel): + result_value = desc.joinChar.join([subv for v in value for _, _, subv in v.format()]) + return Attribute._defaultFormatter(desc, result_value) + class GroupAttribute(Attribute): """ A macro Attribute composed of several Attributes """ - def __init__(self, groupDesc, name, label, description, group='allParams', advanced=False, joinChar=' '): + def __init__(self, groupDesc, name, label, description, group='allParams', advanced=False, joinChar=' ', formatter=None): """ :param groupDesc: the description of the Attributes composing this group """ self._groupDesc = groupDesc self._joinChar = joinChar - super(GroupAttribute, self).__init__(name=name, label=label, description=description, value={}, uid=(), group=group, advanced=advanced) + super(GroupAttribute, self).__init__(name=name, label=label, description=description, value={}, uid=(), group=group, advanced=advanced, formatter=formatter) groupDesc = Property(Variant, lambda self: self._groupDesc, constant=True) @@ -120,6 +144,21 @@ class GroupAttribute(Attribute): allUids.extend(desc.uid) return allUids + @staticmethod + def _defaultFormatter(desc, value): + # sort values based on child attributes group description order + sortedSubValues = [subv for attr in desc.groupDesc for _, _, subv in value.get(attr.name).format()] + result_value = desc.joinChar.join(sortedSubValues) + return Attribute._defaultFormatter(desc, result_value) + + @staticmethod + def prefixFormatter(desc, value): + return [(group, desc.joinChar.join((desc.name, name)), v) for attr in desc.groupDesc for group, name, v in value.get(attr.name).format()] + + @staticmethod + def passthroughFormatter(desc, value): + return [item for attr in desc.groupDesc for item in value.get(attr.name).format()] + uid = Property(Variant, retrieveChildrenUids, constant=True) joinChar = Property(str, lambda self: self._joinChar, constant=True) @@ -127,15 +166,15 @@ class GroupAttribute(Attribute): class Param(Attribute): """ """ - def __init__(self, name, label, description, value, uid, group, advanced): - super(Param, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + def __init__(self, name, label, description, value, uid, group, advanced, formatter=None): + super(Param, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) class File(Attribute): """ """ - def __init__(self, name, label, description, value, uid, group='allParams', advanced=False): - super(File, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + def __init__(self, name, label, description, value, uid, group='allParams', advanced=False, formatter=None): + super(File, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def validateValue(self, value): if not isinstance(value, pyCompatibility.basestring): @@ -146,8 +185,8 @@ class File(Attribute): class BoolParam(Param): """ """ - def __init__(self, name, label, description, value, uid, group='allParams', advanced=False): - super(BoolParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + def __init__(self, name, label, description, value, uid, group='allParams', advanced=False, formatter=None): + super(BoolParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def validateValue(self, value): try: @@ -159,9 +198,9 @@ class BoolParam(Param): class IntParam(Param): """ """ - def __init__(self, name, label, description, value, range, uid, group='allParams', advanced=False): + def __init__(self, name, label, description, value, range, uid, group='allParams', advanced=False, formatter=None): self._range = range - super(IntParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + super(IntParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def validateValue(self, value): # handle unsigned int values that are translated to int by shiboken and may overflow @@ -178,9 +217,9 @@ class IntParam(Param): class FloatParam(Param): """ """ - def __init__(self, name, label, description, value, range, uid, group='allParams', advanced=False): + def __init__(self, name, label, description, value, range, uid, group='allParams', advanced=False, formatter=None): self._range = range - super(FloatParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + super(FloatParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def validateValue(self, value): try: @@ -194,13 +233,13 @@ class FloatParam(Param): class ChoiceParam(Param): """ """ - def __init__(self, name, label, description, value, values, exclusive, uid, group='allParams', joinChar=' ', advanced=False): + def __init__(self, name, label, description, value, values, exclusive, uid, group='allParams', joinChar=' ', advanced=False, formatter=None): assert values self._values = values self._exclusive = exclusive self._joinChar = joinChar self._valueType = type(self._values[0]) # cast to value type - super(ChoiceParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + super(ChoiceParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def conformValue(self, val): """ Conform 'val' to the correct type and check for its validity """ @@ -225,8 +264,8 @@ class ChoiceParam(Param): class StringParam(Param): """ """ - def __init__(self, name, label, description, value, uid, group='allParams', advanced=False): - super(StringParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced) + def __init__(self, name, label, description, value, uid, group='allParams', advanced=False, formatter=None): + super(StringParam, self).__init__(name=name, label=label, description=description, value=value, uid=uid, group=group, advanced=advanced, formatter=formatter) def validateValue(self, value): if not isinstance(value, pyCompatibility.basestring): diff --git a/meshroom/core/node.py b/meshroom/core/node.py index f2e98a2d..2a7b0f91 100644 --- a/meshroom/core/node.py +++ b/meshroom/core/node.py @@ -453,18 +453,19 @@ class BaseNode(BaseObject): for uidIndex, value in self._uids.items(): self._cmdVars['uid{}'.format(uidIndex)] = value + def populate(cmdVars, group, name, value): + cmdVars[name] = '--{name} {value}'.format(name=name, value=value) + cmdVars[name + 'Value'] = str(v) + if v: + cmdVars[group] = cmdVars.get(group, '') + ' ' + cmdVars[name] + # Evaluate input params - for name, attr in self._attributes.objects.items(): + for _, attr in self._attributes.objects.items(): if attr.isOutput: continue # skip outputs - v = attr.getValueStr() - - self._cmdVars[name] = '--{name} {value}'.format(name=name, value=v) - self._cmdVars[name + 'Value'] = str(v) - - if v: - self._cmdVars[attr.attributeDesc.group] = self._cmdVars.get(attr.attributeDesc.group, '') + \ - ' ' + self._cmdVars[name] + group_name_values = attr.format() + for group, name, v in group_name_values: + populate(self._cmdVars, group, name, v) # For updating output attributes invalidation values cmdVarsNoCache = self._cmdVars.copy() @@ -476,14 +477,9 @@ class BaseNode(BaseObject): continue # skip inputs attr.value = attr.attributeDesc.value.format(**self._cmdVars) attr._invalidationValue = attr.attributeDesc.value.format(**cmdVarsNoCache) - v = attr.getValueStr() - - self._cmdVars[name] = '--{name} {value}'.format(name=name, value=v) - self._cmdVars[name + 'Value'] = str(v) - - if v: - self._cmdVars[attr.attributeDesc.group] = self._cmdVars.get(attr.attributeDesc.group, '') + \ - ' ' + self._cmdVars[name] + group_name_values = attr.format() + for group, name, v in group_name_values: + populate(self._cmdVars, group, name, v) @property def isParallelized(self):