| @@ -919,16 +919,19 @@ class Trainer(TrainerEventTrigger): | |||||
| _not_called_callback_fns.append(each_callback_fn) | _not_called_callback_fns.append(each_callback_fn) | ||||
| if check_mode: | if check_mode: | ||||
| logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||||
| if len(_not_called_callback_fns) != 0: | |||||
| logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||||
| f"callback_fns: {_not_called_callback_fns}, but it seems that" | f"callback_fns: {_not_called_callback_fns}, but it seems that" | ||||
| "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
| "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", | |||||
| once=True) | |||||
| # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | ||||
| # 函数; | # 函数; | ||||
| self.check_batch_step_fn = lambda *args, **kwargs: ... | self.check_batch_step_fn = lambda *args, **kwargs: ... | ||||
| else: | |||||
| logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " | |||||
| elif len(_not_called_callback_fns)!=0: | |||||
| logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " | |||||
| f"{_not_called_callback_fns}, but it seems that" | f"{_not_called_callback_fns}, but it seems that" | ||||
| "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
| "you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", | |||||
| once=True) | |||||
| def _check_train_batch_loop_legality(self): | def _check_train_batch_loop_legality(self): | ||||
| r""" | r""" | ||||
| @@ -405,7 +405,7 @@ class DataSet: | |||||
| if isinstance(item, str) and item in self.field_arrays: | if isinstance(item, str) and item in self.field_arrays: | ||||
| return self.field_arrays[item] | return self.field_arrays[item] | ||||
| else: | else: | ||||
| raise AttributeError | |||||
| raise AttributeError(f"Dataset has no attribute named:{item}.") | |||||
| def __setstate__(self, state): | def __setstate__(self, state): | ||||
| self.__dict__ = state | self.__dict__ = state | ||||
| @@ -136,15 +136,13 @@ class Element: | |||||
| if self.value is None: | if self.value is None: | ||||
| prefix = f'Element:`{self.name}`' | prefix = f'Element:`{self.name}`' | ||||
| raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
| "element, or use it after it being used by the `Metric.compute()` method.") | |||||
| "element, or use it after it being used by the `Metric.update()` method.") | |||||
| def __add__(self, other): | def __add__(self, other): | ||||
| self._check_value_when_call() | self._check_value_when_call() | ||||
| if isinstance(other, Element): | if isinstance(other, Element): | ||||
| self.value += other.value | |||||
| else: | |||||
| self.value += other | |||||
| return self | |||||
| other = other.value | |||||
| return self.value + other | |||||
| def __radd__(self, other): | def __radd__(self, other): | ||||
| self._check_value_when_call() | self._check_value_when_call() | ||||
| @@ -314,7 +312,7 @@ class Element: | |||||
| if self._value is None: | if self._value is None: | ||||
| prefix = f'Element:`{self.name}`' | prefix = f'Element:`{self.name}`' | ||||
| raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
| "element, or use it after it being used by the `Metric.compute()` method.") | |||||
| "element, or use it after it being used by the `Metric.update()` method.") | |||||
| return getattr(self._value, item) | return getattr(self._value, item) | ||||
| except AttributeError as e: | except AttributeError as e: | ||||
| logger.error(f"Element:{self.name} has no `{item}` attribute.") | logger.error(f"Element:{self.name} has no `{item}` attribute.") | ||||
| @@ -99,8 +99,14 @@ class Metric: | |||||
| def __setattr__(self, key, value): | def __setattr__(self, key, value): | ||||
| if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | ||||
| if key in self.elements and value is not self.elements[key]: | |||||
| raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") | |||||
| if key in self.elements and isinstance(value, (float, int, bool)): | |||||
| self.elements[key].fill_value(value) | |||||
| return | |||||
| elif key in self.elements: | |||||
| raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, " | |||||
| f"instead of {type(value)}.") | |||||
| if isinstance(value, Element) and key not in self.elements: | |||||
| raise RuntimeError("Please use register_element() function to add Element.") | |||||
| object.__setattr__(self, key, value) | object.__setattr__(self, key, value) | ||||
| def _wrap_update(self, update): | def _wrap_update(self, update): | ||||
| @@ -163,13 +169,6 @@ class Metric: | |||||
| """ | """ | ||||
| self.aggregate_when_get_metric = flag | self.aggregate_when_get_metric = flag | ||||
| def __getattr__(self, name: str) -> Element: | |||||
| if 'elements' in self.__dict__: | |||||
| elements = self.__dict__['elements'] | |||||
| if name in elements: | |||||
| return elements[name] | |||||
| raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) | |||||
| def tensor2numpy(self, tensor) -> np.array: | def tensor2numpy(self, tensor) -> np.array: | ||||
| """ | """ | ||||
| 将tensor向量转为numpy类型变量 | 将tensor向量转为numpy类型变量 | ||||