Browse Source

修改element无法与element运算的问题

tags/v1.0.0alpha
yhcc 3 years ago
parent
commit
fecd82aadf
4 changed files with 21 additions and 21 deletions
  1. +8
    -5
      fastNLP/core/controllers/trainer.py
  2. +1
    -1
      fastNLP/core/dataset/dataset.py
  3. +4
    -6
      fastNLP/core/metrics/element.py
  4. +8
    -9
      fastNLP/core/metrics/metric.py

+ 8
- 5
fastNLP/core/controllers/trainer.py View File

@@ -919,16 +919,19 @@ class Trainer(TrainerEventTrigger):
_not_called_callback_fns.append(each_callback_fn) _not_called_callback_fns.append(each_callback_fn)


if check_mode: if check_mode:
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
if len(_not_called_callback_fns) != 0:
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these "
f"callback_fns: {_not_called_callback_fns}, but it seems that" f"callback_fns: {_not_called_callback_fns}, but it seems that"
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.")
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
once=True)
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass
# 函数; # 函数;
self.check_batch_step_fn = lambda *args, **kwargs: ... self.check_batch_step_fn = lambda *args, **kwargs: ...
else:
logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
elif len(_not_called_callback_fns)!=0:
logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: "
f"{_not_called_callback_fns}, but it seems that" f"{_not_called_callback_fns}, but it seems that"
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.")
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.",
once=True)


def _check_train_batch_loop_legality(self): def _check_train_batch_loop_legality(self):
r""" r"""


+ 1
- 1
fastNLP/core/dataset/dataset.py View File

@@ -405,7 +405,7 @@ class DataSet:
if isinstance(item, str) and item in self.field_arrays: if isinstance(item, str) and item in self.field_arrays:
return self.field_arrays[item] return self.field_arrays[item]
else: else:
raise AttributeError
raise AttributeError(f"Dataset has no attribute named:{item}.")


def __setstate__(self, state): def __setstate__(self, state):
self.__dict__ = state self.__dict__ = state


+ 4
- 6
fastNLP/core/metrics/element.py View File

@@ -136,15 +136,13 @@ class Element:
if self.value is None: if self.value is None:
prefix = f'Element:`{self.name}`' prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
"element, or use it after it being used by the `Metric.update()` method.")


def __add__(self, other): def __add__(self, other):
self._check_value_when_call() self._check_value_when_call()
if isinstance(other, Element): if isinstance(other, Element):
self.value += other.value
else:
self.value += other
return self
other = other.value
return self.value + other


def __radd__(self, other): def __radd__(self, other):
self._check_value_when_call() self._check_value_when_call()
@@ -314,7 +312,7 @@ class Element:
if self._value is None: if self._value is None:
prefix = f'Element:`{self.name}`' prefix = f'Element:`{self.name}`'
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this "
"element, or use it after it being used by the `Metric.compute()` method.")
"element, or use it after it being used by the `Metric.update()` method.")
return getattr(self._value, item) return getattr(self._value, item)
except AttributeError as e: except AttributeError as e:
logger.error(f"Element:{self.name} has no `{item}` attribute.") logger.error(f"Element:{self.name} has no `{item}` attribute.")


+ 8
- 9
fastNLP/core/metrics/metric.py View File

@@ -99,8 +99,14 @@ class Metric:


def __setattr__(self, key, value): def __setattr__(self, key, value):
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True:
if key in self.elements and value is not self.elements[key]:
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}")
if key in self.elements and isinstance(value, (float, int, bool)):
self.elements[key].fill_value(value)
return
elif key in self.elements:
raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, "
f"instead of {type(value)}.")
if isinstance(value, Element) and key not in self.elements:
raise RuntimeError("Please use register_element() function to add Element.")
object.__setattr__(self, key, value) object.__setattr__(self, key, value)


def _wrap_update(self, update): def _wrap_update(self, update):
@@ -163,13 +169,6 @@ class Metric:
""" """
self.aggregate_when_get_metric = flag self.aggregate_when_get_metric = flag


def __getattr__(self, name: str) -> Element:
if 'elements' in self.__dict__:
elements = self.__dict__['elements']
if name in elements:
return elements[name]
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name))

def tensor2numpy(self, tensor) -> np.array: def tensor2numpy(self, tensor) -> np.array:
""" """
将tensor向量转为numpy类型变量 将tensor向量转为numpy类型变量


Loading…
Cancel
Save