@@ -919,16 +919,19 @@ class Trainer(TrainerEventTrigger): | |||||
_not_called_callback_fns.append(each_callback_fn) | _not_called_callback_fns.append(each_callback_fn) | ||||
if check_mode: | if check_mode: | ||||
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||||
if len(_not_called_callback_fns) != 0: | |||||
logger.rank_zero_warning("You have customized your 'batch_step_fn' in the 'train_batch_loop' and also use these " | |||||
f"callback_fns: {_not_called_callback_fns}, but it seems that" | f"callback_fns: {_not_called_callback_fns}, but it seems that" | ||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", | |||||
once=True) | |||||
# 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | # 对于 'batch_step_fn' 来讲,其只需要在第一次的 step 后进行检测即可,因此在第一次检测后将 check_batch_step_fn 置为 pass | ||||
# 函数; | # 函数; | ||||
self.check_batch_step_fn = lambda *args, **kwargs: ... | self.check_batch_step_fn = lambda *args, **kwargs: ... | ||||
else: | |||||
logger.warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " | |||||
elif len(_not_called_callback_fns)!=0: | |||||
logger.rank_zero_warning("You have customized your 'TrainBatchLoop' and also use these callback_fns: " | |||||
f"{_not_called_callback_fns}, but it seems that" | f"{_not_called_callback_fns}, but it seems that" | ||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.") | |||||
"you don't call the corresponding callback hook explicitly in your 'batch_step_fn'.", | |||||
once=True) | |||||
def _check_train_batch_loop_legality(self): | def _check_train_batch_loop_legality(self): | ||||
r""" | r""" | ||||
@@ -405,7 +405,7 @@ class DataSet: | |||||
if isinstance(item, str) and item in self.field_arrays: | if isinstance(item, str) and item in self.field_arrays: | ||||
return self.field_arrays[item] | return self.field_arrays[item] | ||||
else: | else: | ||||
raise AttributeError | |||||
raise AttributeError(f"Dataset has no attribute named:{item}.") | |||||
def __setstate__(self, state): | def __setstate__(self, state): | ||||
self.__dict__ = state | self.__dict__ = state | ||||
@@ -136,15 +136,13 @@ class Element: | |||||
if self.value is None: | if self.value is None: | ||||
prefix = f'Element:`{self.name}`' | prefix = f'Element:`{self.name}`' | ||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
"element, or use it after it being used by the `Metric.compute()` method.") | |||||
"element, or use it after it being used by the `Metric.update()` method.") | |||||
def __add__(self, other): | def __add__(self, other): | ||||
self._check_value_when_call() | self._check_value_when_call() | ||||
if isinstance(other, Element): | if isinstance(other, Element): | ||||
self.value += other.value | |||||
else: | |||||
self.value += other | |||||
return self | |||||
other = other.value | |||||
return self.value + other | |||||
def __radd__(self, other): | def __radd__(self, other): | ||||
self._check_value_when_call() | self._check_value_when_call() | ||||
@@ -314,7 +312,7 @@ class Element: | |||||
if self._value is None: | if self._value is None: | ||||
prefix = f'Element:`{self.name}`' | prefix = f'Element:`{self.name}`' | ||||
raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " | ||||
"element, or use it after it being used by the `Metric.compute()` method.") | |||||
"element, or use it after it being used by the `Metric.update()` method.") | |||||
return getattr(self._value, item) | return getattr(self._value, item) | ||||
except AttributeError as e: | except AttributeError as e: | ||||
logger.error(f"Element:{self.name} has no `{item}` attribute.") | logger.error(f"Element:{self.name} has no `{item}` attribute.") | ||||
@@ -99,8 +99,14 @@ class Metric: | |||||
def __setattr__(self, key, value): | def __setattr__(self, key, value): | ||||
if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | if hasattr(self, '_cannot_change_element') and self._cannot_change_element is True: | ||||
if key in self.elements and value is not self.elements[key]: | |||||
raise RuntimeError(f"self.`{key}` is an element, cannot assign to a new value:{value}") | |||||
if key in self.elements and isinstance(value, (float, int, bool)): | |||||
self.elements[key].fill_value(value) | |||||
return | |||||
elif key in self.elements: | |||||
raise TypeError(f"self.{key} is an Element, only float/int/bool type value can be assigned to it, " | |||||
f"instead of {type(value)}.") | |||||
if isinstance(value, Element) and key not in self.elements: | |||||
raise RuntimeError("Please use register_element() function to add Element.") | |||||
object.__setattr__(self, key, value) | object.__setattr__(self, key, value) | ||||
def _wrap_update(self, update): | def _wrap_update(self, update): | ||||
@@ -163,13 +169,6 @@ class Metric: | |||||
""" | """ | ||||
self.aggregate_when_get_metric = flag | self.aggregate_when_get_metric = flag | ||||
def __getattr__(self, name: str) -> Element: | |||||
if 'elements' in self.__dict__: | |||||
elements = self.__dict__['elements'] | |||||
if name in elements: | |||||
return elements[name] | |||||
raise AttributeError("`{}` object has no attribute `{}`".format(type(self).__name__, name)) | |||||
def tensor2numpy(self, tensor) -> np.array: | def tensor2numpy(self, tensor) -> np.array: | ||||
""" | """ | ||||
将tensor向量转为numpy类型变量 | 将tensor向量转为numpy类型变量 | ||||