@@ -32,100 +32,205 @@ class Callback: | |||||
def on_sanity_check_end(self, trainer, sanity_check_res): | def on_sanity_check_end(self, trainer, sanity_check_res): | ||||
r""" | r""" | ||||
在 '预跑'检测 开始后会被触发; | 在 '预跑'检测 开始后会被触发; | ||||
:param trainer: | |||||
:param sanity_check_res: 预跑的 evaluate 结果 | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_begin(self, trainer): | def on_train_begin(self, trainer): | ||||
r""" | r""" | ||||
在训练开始前会被触发; | 在训练开始前会被触发; | ||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_end(self, trainer): | def on_train_end(self, trainer): | ||||
r""" | r""" | ||||
在训练完成后会被触发; | 在训练完成后会被触发; | ||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_epoch_begin(self, trainer): | def on_train_epoch_begin(self, trainer): | ||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 开始前会被触发; | 在训练过程中的每一个 epoch 开始前会被触发; | ||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_epoch_end(self, trainer): | def on_train_epoch_end(self, trainer): | ||||
r""" | r""" | ||||
在训练过程中的每一个 epoch 完成后会被触发; | |||||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | |||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_fetch_data_begin(self, trainer): | def on_fetch_data_begin(self, trainer): | ||||
r""" | r""" | ||||
在训练过程中拿到当前的具体的一个 batch 前会被触发; | |||||
在训练过程中准备取出下一个 batch 的数据时触发 | |||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_fetch_data_end(self, trainer): | def on_fetch_data_end(self, trainer): | ||||
r""" | r""" | ||||
在训练过程中拿到当前的具体的一个 batch 后会被触发; | |||||
在训练过程中拿到当前的 batch 数据后会被触发; | |||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_begin(self, trainer, batch, indices): | def on_train_batch_begin(self, trainer, batch, indices): | ||||
r""" | r""" | ||||
在训练过程中开始具体的一个 batch 前会被触发; | |||||
在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||||
:param trainer: `fastNLP.Trainer` | :param trainer: `fastNLP.Trainer` | ||||
:param batch: 当前正在运行的一个 batch; | |||||
:param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据; | |||||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 | |||||
""" | """ | ||||
pass | pass | ||||
def on_train_batch_end(self, trainer): | def on_train_batch_end(self, trainer): | ||||
""" | |||||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||||
执行。 | |||||
:param trainer: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_exception(self, trainer, exception): | def on_exception(self, trainer, exception): | ||||
""" | |||||
在训练过程遇到异常时调用。 | |||||
:param trainer: | |||||
:param exception: 遭遇的异常。 | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_save_model(self, trainer): | def on_save_model(self, trainer): | ||||
""" | |||||
当将要保存模型时调用,此刻模型还未保存。 | |||||
:param trainer: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_load_model(self, trainer): | def on_load_model(self, trainer): | ||||
""" | |||||
当将要加载模型时调用,此刻模型还未加载。 | |||||
:param trainer: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_save_checkpoint(self, trainer) -> Dict: | def on_save_checkpoint(self, trainer) -> Dict: | ||||
""" | """ | ||||
当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常 | |||||
工作的状态;而不应该让该函数去判断两个 callback 是否一样; | |||||
当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 | |||||
:param trainer: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | def on_load_checkpoint(self, trainer, states: Optional[Dict]): | ||||
r""" | r""" | ||||
如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None; | |||||
当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||||
的返回值。 | |||||
:param trainer: | |||||
:param states: | |||||
:return: | |||||
""" | """ | ||||
pass | pass | ||||
def on_before_backward(self, trainer, outputs): | def on_before_backward(self, trainer, outputs): | ||||
""" | |||||
在 backward 前执行。 | |||||
:param trainer: | |||||
:param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_after_backward(self, trainer): | def on_after_backward(self, trainer): | ||||
""" | |||||
在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||||
:param trainer: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_before_optimizer_step(self, trainer, optimizers): | def on_before_optimizer_step(self, trainer, optimizers): | ||||
""" | |||||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_before_zero_grad(self, trainer, optimizers): | def on_before_zero_grad(self, trainer, optimizers): | ||||
""" | |||||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||||
:param trainer: | |||||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_validate_begin(self, trainer): | def on_validate_begin(self, trainer): | ||||
""" | |||||
在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | |||||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||||
:param trainer: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
def on_validate_end(self, trainer, results): | def on_validate_end(self, trainer, results): | ||||
""" | |||||
结束 validate 时调用,并把 validate 的结果传入。 | |||||
:param trainer: | |||||
:param results: | |||||
:return: | |||||
""" | |||||
pass | pass | ||||
@property | @property | ||||
def callback_name(self): | def callback_name(self): | ||||
""" | |||||
callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 | |||||
:return: | |||||
""" | |||||
return self.__class__.__name__ | return self.__class__.__name__ | ||||
@@ -226,10 +331,21 @@ class HasMonitorCallback(Callback): | |||||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | ||||
:return: | :return: | ||||
""" | """ | ||||
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||||
if keep_if_better and better: | |||||
self.monitor_value = monitor_value | |||||
return better | |||||
def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | |||||
""" | |||||
传入的两个值中,是否monitor_value1的结果更好。 | |||||
:param monitor_value1: | |||||
:param monitor_value2: | |||||
:return: | |||||
""" | |||||
better = False | better = False | ||||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||||
(not self.larger_better and monitor_value < self.monitor_value): | |||||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||||
(not self.larger_better and monitor_value1 < monitor_value2): | |||||
better = True | better = True | ||||
if keep_if_better: | |||||
self.monitor_value = monitor_value | |||||
return better | return better |
@@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | from fastNLP.envs import FASTNLP_LAUNCH_TIME | ||||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | ||||
from fastNLP.core.utils import apply_to_collection | |||||
class CheckpointCallback(HasMonitorCallback): | class CheckpointCallback(HasMonitorCallback): | ||||
@@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback): | |||||
else: | else: | ||||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | _least_valuable_model = (min if self.larger_better else max)(self._topk_model, | ||||
key=lambda x: self._topk_model[x]) | key=lambda x: self._topk_model[x]) | ||||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||||
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||||
self._topk_model[folder_name] = monitor_value | self._topk_model[folder_name] = monitor_value | ||||
_should_save = True | _should_save = True | ||||
self._topk_model.pop(_least_valuable_model) | self._topk_model.pop(_least_valuable_model) | ||||
@@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback): | |||||
**self.kwargs | **self.kwargs | ||||
) | ) | ||||
def _get_validate_metric(self, res: Dict): | |||||
""" | |||||
该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor); | |||||
如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值; | |||||
:param res: | |||||
:return: | |||||
""" | |||||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | |||||
if self._real_monitor != use_monitor: | |||||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||||
self._real_monitor = use_monitor | |||||
return value | |||||
@property | @property | ||||
def folder_prefix(self): | def folder_prefix(self): | ||||
raise NotImplementedError("The `folder_prefix` is not specified") | raise NotImplementedError("The `folder_prefix` is not specified") | ||||
@@ -197,7 +197,7 @@ class _MultiCollator: | |||||
collator.set_input(*field_names) | collator.set_input(*field_names) | ||||
flag = False | flag = False | ||||
if flag: | if flag: | ||||
warnings.warn("AutoCollator is remove, set_input is unavailable!!") | |||||
warnings.warn("AutoCollator is removed, set_input is unavailable!!") | |||||
return self | return self | ||||
@@ -223,7 +223,6 @@ class Evaluator: | |||||
def remove_progress_bar(self, dataloader_name): | def remove_progress_bar(self, dataloader_name): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | ||||
f_rich_progress.destroy_task(self._rich_task_id) | f_rich_progress.destroy_task(self._rich_task_id) | ||||
f_rich_progress.refresh() # 使得最终的bar可以消失 | |||||
delattr(self, '_rich_task_id') | delattr(self, '_rich_task_id') | ||||
elif self.progress_bar == 'raw': | elif self.progress_bar == 'raw': | ||||
desc = 'Evaluation ends' | desc = 'Evaluation ends' | ||||
@@ -234,7 +233,6 @@ class Evaluator: | |||||
def finally_progress_bar(self): | def finally_progress_bar(self): | ||||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | ||||
f_rich_progress.destroy_task(self._rich_task_id) | f_rich_progress.destroy_task(self._rich_task_id) | ||||
f_rich_progress.refresh() | |||||
delattr(self, '_rich_task_id') | delattr(self, '_rich_task_id') | ||||
@property | @property | ||||
@@ -359,20 +357,23 @@ class _MetricsWrapper: | |||||
if is_dataclass(outputs): | if is_dataclass(outputs): | ||||
outputs = dataclass_to_dict(outputs) | outputs = dataclass_to_dict(outputs) | ||||
for metric in self._metrics: | for metric in self._metrics: | ||||
args = [] | |||||
if not isinstance(batch, dict): | if not isinstance(batch, dict): | ||||
raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly" | |||||
f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.") | |||||
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||||
f"the output of model to update metric.") | |||||
else: | |||||
args.append(batch) | |||||
if not isinstance(outputs, dict): | if not isinstance(outputs, dict): | ||||
raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly" | |||||
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | f" return a dict from your model or use `output_mapping` to convert it into dict type.") | ||||
if isinstance(metric, Metric): | if isinstance(metric, Metric): | ||||
auto_param_call(metric.update, batch, outputs) | |||||
auto_param_call(metric.update, batch, *args) | |||||
elif _is_torchmetrics_metric(metric): | elif _is_torchmetrics_metric(metric): | ||||
auto_param_call(metric.update, batch, outputs) | |||||
auto_param_call(metric.update, batch, *args) | |||||
elif _is_allennlp_metric(metric): | elif _is_allennlp_metric(metric): | ||||
auto_param_call(metric.__call__, batch, outputs) | |||||
auto_param_call(metric.__call__, batch, *args) | |||||
elif _is_paddle_metric(metric): | elif _is_paddle_metric(metric): | ||||
res = auto_param_call(metric.compute, batch, outputs) | |||||
res = auto_param_call(metric.compute, batch, *args) | |||||
metric.update(res) | metric.update(res) | ||||
def reset(self): | def reset(self): | ||||
@@ -7,6 +7,7 @@ from typing import Optional, Callable | |||||
from .loop import Loop | from .loop import Loop | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.utils import match_and_substitute_params | from fastNLP.core.utils import match_and_substitute_params | ||||
from fastNLP.core.utils.exceptions import EarlyStopException | |||||
class TrainBatchLoop(Loop): | class TrainBatchLoop(Loop): | ||||
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop): | |||||
try: | try: | ||||
trainer.on_fetch_data_begin() | trainer.on_fetch_data_begin() | ||||
batch = next(dataloader) | batch = next(dataloader) | ||||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||||
indices = get_batch_indices() | indices = get_batch_indices() | ||||
batch = trainer.move_data_to_device(batch) | |||||
trainer.on_fetch_data_end() | trainer.on_fetch_data_end() | ||||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||||
batch = trainer.move_data_to_device(batch) | |||||
except StopIteration: | except StopIteration: | ||||
break | break | ||||
except BaseException as e: # TODO 把这里的信息写入进去 | |||||
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||||
break | |||||
except BaseException as e: | |||||
if indices: | if indices: | ||||
logger.debug(f"The following exception happens when running on samples: {indices}") | logger.debug(f"The following exception happens when running on samples: {indices}") | ||||
raise e | raise e | ||||
@@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger): | |||||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | ||||
# 5. 恢复所有 callback 的状态; | # 5. 恢复所有 callback 的状态; | ||||
self.on_load_checkpoint(states["callback_states"]) | |||||
self.train_stepeckpoint(states["callback_states"]) | |||||
self.driver.barrier() | self.driver.barrier() | ||||
@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader): | |||||
pin_memory: bool = False, drop_last: bool = False, | pin_memory: bool = False, drop_last: bool = False, | ||||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | timeout: float = 0, worker_init_fn: Optional[Callable] = None, | ||||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | ||||
persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||||
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: | |||||
""" | """ | ||||
:param dataset: 实现了__getitem__和__len__的数据容器 | :param dataset: 实现了__getitem__和__len__的数据容器 | ||||
@@ -788,13 +788,14 @@ class DataSet: | |||||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | ||||
""" | """ | ||||
设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 | |||||
设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 | |||||
当val=None时,意味着给定的field_names都不需要尝试padding | 当val=None时,意味着给定的field_names都不需要尝试padding | ||||
:param field_names: dataset存在的field_name | :param field_names: dataset存在的field_name | ||||
:param val: 默认为0 | |||||
:param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | |||||
:return: | :return: | ||||
""" | """ | ||||
# TODO 需要去重复 | |||||
for field_name in field_names: | for field_name in field_names: | ||||
self.collate_fns.set_pad_val(field_name, val=val) | self.collate_fns.set_pad_val(field_name, val=val) | ||||
@@ -805,6 +806,7 @@ class DataSet: | |||||
:param field_names: | :param field_names: | ||||
:return: | :return: | ||||
""" | """ | ||||
# | |||||
self.collate_fns.set_input(*field_names) | self.collate_fns.set_input(*field_names) | ||||
def get_collator(self) -> _MultiCollator: | def get_collator(self) -> _MultiCollator: | ||||
@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: | |||||
import torch | import torch | ||||
import torch.distributed as dist | import torch.distributed as dist | ||||
from torch.nn.parallel import DistributedDataParallel | from torch.nn.parallel import DistributedDataParallel | ||||
from torch.utils.data import BatchSampler | |||||
__all__ = [ | __all__ = [ | ||||
'TorchDDPDriver' | 'TorchDDPDriver' | ||||
@@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver): | |||||
num_replicas=self.world_size, | num_replicas=self.world_size, | ||||
rank=self.global_rank | rank=self.global_rank | ||||
) | ) | ||||
return replace_sampler(dataloader, sampler) | |||||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||||
return replace_batch_sampler(dataloader, batch_sampler) | |||||
else: | else: | ||||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | ||||
@@ -3,28 +3,20 @@ import pickle | |||||
_pickler = pickle.Pickler | _pickler = pickle.Pickler | ||||
_unpickler = pickle.Unpickler | _unpickler = pickle.Unpickler | ||||
from typing import Any, List | from typing import Any, List | ||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
from torch import distributed as dist | from torch import distributed as dist | ||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupMPI | |||||
except ImportError: | |||||
_MPI_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||||
except ImportError: | |||||
_NCCL_AVAILABLE = False | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupGloo | |||||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
except ImportError: | |||||
_GLOO_AVAILABLE = False | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
try: | |||||
from torch._C._distributed_c10d import ProcessGroupGloo | |||||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||||
except ImportError: | |||||
pass | |||||
from fastNLP.core.utils import apply_to_collection | from fastNLP.core.utils import apply_to_collection | ||||
@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||||
) | ) | ||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||||
""" | """ | ||||
从其它 rank gather 东西到 dst rank 。 | 从其它 rank gather 东西到 dst rank 。 | ||||
@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) | |||||
>>> output | >>> output | ||||
['foo', 12, {1: 2}] | ['foo', 12, {1: 2}] | ||||
""" | """ | ||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
if dist.distributed_c10d._rank_not_in_group(group): | if dist.distributed_c10d._rank_not_in_group(group): | ||||
return | return | ||||
@@ -193,7 +188,7 @@ def _to_device(tensor, device): | |||||
return tensor.contiguous().to(device) | return tensor.contiguous().to(device) | ||||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||||
""" | """ | ||||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | ||||
@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
:param group: | :param group: | ||||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | ||||
""" | """ | ||||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
if isinstance(obj, torch.Tensor): | if isinstance(obj, torch.Tensor): | ||||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | ||||
dist.all_gather(objs, obj, group=group) | dist.all_gather(objs, obj, group=group) | ||||
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||||
return objs | return objs | ||||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||||
""" | """ | ||||
将 src 上的 obj 对象广播到其它 rank 上。 | 将 src 上的 obj 对象广播到其它 rank 上。 | ||||
@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||||
:param group: | :param group: | ||||
:return: | :return: | ||||
""" | """ | ||||
if group is None: | |||||
group = DEFAULT_TORCH_GROUP | |||||
cur_rank = dist.get_rank(group) | cur_rank = dist.get_rank(group) | ||||
if cur_rank == src: | if cur_rank == src: | ||||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | ||||
@@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None): | |||||
return | return | ||||
input_tensor, local_size = _object_to_tensor(obj) | input_tensor, local_size = _object_to_tensor(obj) | ||||
current_device = torch.device("cpu") | |||||
is_nccl_backend = _check_for_nccl_backend(group) | |||||
if is_nccl_backend: | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
if _TORCH_GREATER_EQUAL_1_8: | |||||
current_device = torch.device("cpu") | |||||
is_nccl_backend = _check_for_nccl_backend(group) | |||||
if is_nccl_backend: | |||||
# See note about using torch.cuda.current_device() here in docstring. | |||||
# We cannot simply use my_rank since rank == device is not necessarily | |||||
# true. | |||||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||||
input_tensor = input_tensor.to(current_device) | |||||
local_size = local_size.to(current_device) | |||||
else: | |||||
current_device = torch.cuda.current_device() | |||||
# Gather all local sizes. This is so that we can find the max size, and index | # Gather all local sizes. This is so that we can find the max size, and index | ||||
# until the correct size when deserializing the tensors. | # until the correct size when deserializing the tensors. | ||||
group_size = dist.get_world_size(group=group) | group_size = dist.get_world_size(group=group) | ||||
@@ -8,6 +8,7 @@ import numpy as np | |||||
import inspect | import inspect | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | from fastNLP.envs.imports import _NEED_IMPORT_TORCH | ||||
from fastNLP.core.samplers import re_instantiate_sampler | |||||
if _NEED_IMPORT_TORCH: | if _NEED_IMPORT_TORCH: | ||||
import torch | import torch | ||||
@@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||||
"manually add the `DistributedSampler` as: " | "manually add the `DistributedSampler` as: " | ||||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | ||||
) | ) | ||||
return type(dataloader)(**reconstruct_args) | return type(dataloader)(**reconstruct_args) | ||||
@@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||||
""" | """ | ||||
batch_sampler = getattr(dataloader, "batch_sampler") | batch_sampler = getattr(dataloader, "batch_sampler") | ||||
# checking the batch sampler type is different than PyTorch default. | # checking the batch sampler type is different than PyTorch default. | ||||
if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||||
batch_sampler = type(batch_sampler)( | |||||
sampler, | |||||
batch_size=batch_sampler.batch_size, | |||||
drop_last=batch_sampler.drop_last, | |||||
) | |||||
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||||
batch_sampler = re_instantiate_sampler(batch_sampler) | |||||
return { | return { | ||||
"sampler": None, | "sampler": None, | ||||
@@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||||
params = {k: getattr(dataloader, k) for k in params_keys} | params = {k: getattr(dataloader, k) for k in params_keys} | ||||
params["batch_sampler"] = new_batch_sampler | params["batch_sampler"] = new_batch_sampler | ||||
return type(dataloader)(**params) | return type(dataloader)(**params) | ||||
# TODO 这里是否可以auto_param_call一下 | |||||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||||
# signature_fn=type(dataloader).__init__) | |||||
def optimizer_state_to_device(state, device): | def optimizer_state_to_device(state, device): | ||||
@@ -51,6 +51,7 @@ class LoggerSingleton(type): | |||||
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | ||||
def __init__(self, name): | def __init__(self, name): | ||||
super().__init__(name) | super().__init__(name) | ||||
self._warning_msgs = set() | |||||
def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | ||||
mode: str = "w"): | mode: str = "w"): | ||||
@@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||||
kwargs = self._add_rank_info(kwargs) | kwargs = self._add_rank_info(kwargs) | ||||
self._log(WARNING, msg, args, **kwargs) | self._log(WARNING, msg, args, **kwargs) | ||||
def warning_once(self, msg, *args, **kwargs): | |||||
""" | |||||
通过 warning 内容只会 warning 一次 | |||||
:param msg: | |||||
:param args: | |||||
:param kwargs: | |||||
:return: | |||||
""" | |||||
if msg not in self._warning_msgs: | |||||
if self.isEnabledFor(WARNING): | |||||
kwargs = self._add_rank_info(kwargs) | |||||
self._log(WARNING, msg, args, **kwargs) | |||||
self._warning_msgs.add(msg) | |||||
def warn(self, msg, *args, **kwargs): | def warn(self, msg, *args, **kwargs): | ||||
warnings.warn("The 'warn' method is deprecated, " | warnings.warn("The 'warn' method is deprecated, " | ||||
"use 'warning' instead", DeprecationWarning, 2) | "use 'warning' instead", DeprecationWarning, 2) | ||||
@@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__() | super().__init__() | ||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | |||||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||||
length = dataset.get_field(length).content | |||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
@@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__(dataset=dataset, **kwargs) | super().__init__(dataset=dataset, **kwargs) | ||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | |||||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||||
length = dataset.get_field(length).content | |||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||||
:param kwargs: fastNLP 保留使用 | :param kwargs: fastNLP 保留使用 | ||||
""" | """ | ||||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | ||||
if isinstance(dataset, DataSet): | |||||
length = dataset.get_field(length) | |||||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||||
length = dataset.get_field(length).content | |||||
if not isinstance(length[0], int): | if not isinstance(length[0], int): | ||||
length = list(map(len, length)) | length = list(map(len, length)) | ||||
else: | else: | ||||
@@ -6,7 +6,7 @@ | |||||
import sys | import sys | ||||
from typing import Any, Union, Optional | from typing import Any, Union, Optional | ||||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample | |||||
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | ||||
__all__ = [ | __all__ = [ | ||||
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton): | |||||
if task_id in self._tasks: | if task_id in self._tasks: | ||||
super().stop_task(task_id) | super().stop_task(task_id) | ||||
super().remove_task(task_id) | super().remove_task(task_id) | ||||
self.refresh() # 使得bar不残留 | |||||
def start(self) -> None: | def start(self) -> None: | ||||
super().start() | super().start() | ||||
self.console.show_cursor(show=True) | self.console.show_cursor(show=True) | ||||
def update(
    self,
    task_id: TaskID,
    *,
    total: Optional[float] = None,
    completed: Optional[float] = None,
    advance: Optional[float] = None,
    description: Optional[str] = None,
    visible: Optional[bool] = None,
    refresh: bool = False,
    **fields: Any,
) -> None:
    """Update information associated with a task.

    Override of ``rich.progress.Progress.update``. The single behavioral
    deviation from upstream rich is marked below: at least one speed sample
    is always kept, so one very long iteration cannot wipe out the speed
    estimate.

    Args:
        task_id (TaskID): Task id (returned by add_task).
        total (float, optional): Updates task.total if not None.
        completed (float, optional): Updates task.completed if not None.
        advance (float, optional): Add a value to task.completed if not None.
        description (str, optional): Change task description if not None.
        visible (bool, optional): Set visible flag if not None.
        refresh (bool): Force a refresh of progress information. Default is False.
        **fields (Any): Additional data fields required for rendering.
    """
    # All task state is mutated under the shared progress lock.
    with self._lock:
        task = self._tasks[task_id]
        completed_start = task.completed
        if total is not None and total != task.total:
            task.total = total
            # Changing the total invalidates accumulated speed samples.
            task._reset()
        if advance is not None:
            task.completed += advance
        if completed is not None:
            # An explicit `completed` value wins over `advance` when both are given.
            task.completed = completed
        if description is not None:
            task.description = description
        if visible is not None:
            task.visible = visible
        task.fields.update(fields)
        update_completed = task.completed - completed_start

        current_time = self.get_time()
        old_sample_time = current_time - self.speed_estimate_period
        _progress = task._progress

        # Local binding of the bound method (upstream rich micro-optimization).
        popleft = _progress.popleft
        # Deviation from upstream: keep at least one sample when pruning stale
        # ones, so an iteration longer than `speed_estimate_period` does not
        # empty the deque and destroy the speed estimate.
        while len(_progress)>1 and _progress[0].timestamp < old_sample_time:
            popleft()
        if update_completed > 0:
            _progress.append(ProgressSample(current_time, update_completed))
        # NOTE(review): assumes task.total is not None here (rich defaults
        # total to 100) — confirm no caller creates a task with total=None.
        if task.completed >= task.total and task.finished_time is None:
            task.finished_time = task.elapsed

    if refresh:
        self.refresh()
class SpeedColumn(ProgressColumn):
    """Render the speed of a task.

    Fast tasks (more than 0.1 iterations per second) are shown as
    ``x it./s``; slow ones are shown inverted as ``y s/it.``; when no
    speed estimate is available the placeholder ``-- it./s`` is shown.
    """

    def render(self, task: "Task"):
        speed = task.speed
        # rich reports speed as None until enough samples exist.
        if speed is None:
            return Text('-- it./s', style='progress.data.speed')
        if speed > 0.1:
            return Text(str(round(speed, 2)) + ' it./s', style='progress.data.speed')
        # Guard against a zero estimate before inverting it — `1 / speed`
        # would raise ZeroDivisionError.
        if speed == 0:
            return Text('-- it./s', style='progress.data.speed')
        return Text(str(round(1 / speed, 2)) + ' s/it.', style='progress.data.speed')
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | ||||
f_rich_progress = FRichProgress().new_progess( | f_rich_progress = FRichProgress().new_progess( | ||||
"[progress.description]{task.description}", | "[progress.description]{task.description}", | ||||
"[progress.percentage]{task.percentage:>3.0f}%", | "[progress.percentage]{task.percentage:>3.0f}%", | ||||
BarColumn(), | BarColumn(), | ||||
SpeedColumn(), | |||||
TimeElapsedColumn(), | TimeElapsedColumn(), | ||||
"/", | "/", | ||||
TimeRemainingColumn(), | TimeRemainingColumn(), | ||||
TextColumn("{task.fields[post_desc]}", justify="right"), | TextColumn("{task.fields[post_desc]}", justify="right"), | ||||
transient=True, | transient=True, | ||||
disable=False, | disable=False, | ||||
speed_estimate_period=1 | |||||
speed_estimate_period=30 | |||||
) | ) | ||||
else: | else: | ||||
f_rich_progress = DummyFRichProgress() | f_rich_progress = DummyFRichProgress() | ||||
@@ -1,9 +1,11 @@ | |||||
from abc import ABC | from abc import ABC | ||||
from typing import Any, Union, Optional | from typing import Any, Union, Optional | ||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 | |||||
# Default process group handed to torch.distributed collectives.
# Presumably torch >= 1.8 accepts ``group=None`` and resolves it internally,
# which is why None is the default there — TODO confirm against torch docs.
DEFAULT_TORCH_GROUP = None
if _NEED_IMPORT_TORCH:
    import torch
    if not _TORCH_GREATER_EQUAL_1_8:
        # Older torch needs an explicit group object: fall back to the global
        # WORLD group (private c10d attribute, tied to the pre-1.8 API).
        DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD
__all__ = [ | __all__ = [ | ||||
'torch_move_data_to_device' | 'torch_move_data_to_device' | ||||
@@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional | |||||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | ||||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | ||||
r""" | r""" | ||||
1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||||
该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping | |||||
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 | |||||
1.该函数用来提供给用户根据字符串匹配从而实现自动调用; | |||||
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | ||||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | ||||
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | ||||
@@ -6,13 +6,16 @@ import logging | |||||
import re | import re | ||||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | ||||
from tests.helpers.utils import magic_argv_env_context | |||||
from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||
from fastNLP.core.log.logger import logger | |||||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||||
# 测试 TorchDDPDriver; | # 测试 TorchDDPDriver; | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_add_file_ddp_1(): | |||||
@recover_logger | |||||
def test_add_file_ddp_1_torch(): | |||||
""" | """ | ||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | ||||
@@ -56,11 +59,11 @@ def test_add_file_ddp_1(): | |||||
synchronize_safe_rm(filepath) | synchronize_safe_rm(filepath) | ||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_add_file_ddp_2(): | |||||
@recover_logger | |||||
def test_add_file_ddp_2_torch(): | |||||
""" | """ | ||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | ||||
""" | """ | ||||
@@ -103,14 +106,14 @@ def test_add_file_ddp_2(): | |||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
logger.removeHandler(handler) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_add_file_ddp_3(): | |||||
@recover_logger | |||||
def test_add_file_ddp_3_torch(): | |||||
""" | """ | ||||
path = None; | path = None; | ||||
@@ -155,10 +158,10 @@ def test_add_file_ddp_3(): | |||||
synchronize_safe_rm(file) | synchronize_safe_rm(file) | ||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
logger.removeHandler(handler) | |||||
@magic_argv_env_context | @magic_argv_env_context | ||||
def test_add_file_ddp_4(): | |||||
@recover_logger | |||||
def test_add_file_ddp_4_torch(): | |||||
""" | """ | ||||
测试 path 是文件夹; | 测试 path 是文件夹; | ||||
""" | """ | ||||
@@ -200,7 +203,6 @@ def test_add_file_ddp_4(): | |||||
assert len(pattern.findall(line)) == 1 | assert len(pattern.findall(line)) == 1 | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
logger.removeHandler(handler) | |||||
dist.barrier() | dist.barrier() | ||||
dist.destroy_process_group() | dist.destroy_process_group() | ||||
@@ -209,12 +211,11 @@ def test_add_file_ddp_4(): | |||||
class TestLogger: | class TestLogger: | ||||
msg = 'some test log msg' | msg = 'some test log msg' | ||||
@recover_logger | |||||
def test_add_file_1(self): | def test_add_file_1(self): | ||||
""" | """ | ||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | ||||
""" | """ | ||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | path = Path(tempfile.mkdtemp()) | ||||
try: | try: | ||||
filepath = path.joinpath('log.txt') | filepath = path.joinpath('log.txt') | ||||
@@ -225,14 +226,12 @@ class TestLogger: | |||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
logger.removeHandler(handler) | |||||
@recover_logger | |||||
def test_add_file_2(self): | def test_add_file_2(self): | ||||
""" | """ | ||||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | ||||
""" | """ | ||||
from fastNLP.core.log.logger import logger | |||||
origin_path = Path(tempfile.mkdtemp()) | origin_path = Path(tempfile.mkdtemp()) | ||||
try: | try: | ||||
@@ -245,14 +244,12 @@ class TestLogger: | |||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(origin_path) | synchronize_safe_rm(origin_path) | ||||
logger.removeHandler(handler) | |||||
@recover_logger | |||||
def test_add_file_3(self): | def test_add_file_3(self): | ||||
""" | """ | ||||
测试 path 是 None; | 测试 path 是 None; | ||||
""" | """ | ||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.add_file() | handler = logger.add_file() | ||||
logger.info(self.msg) | logger.info(self.msg) | ||||
@@ -264,14 +261,12 @@ class TestLogger: | |||||
line = ''.join([l for l in f]) | line = ''.join([l for l in f]) | ||||
assert self.msg in line | assert self.msg in line | ||||
file.unlink() | file.unlink() | ||||
logger.removeHandler(handler) | |||||
@recover_logger | |||||
def test_add_file_4(self): | def test_add_file_4(self): | ||||
""" | """ | ||||
测试 path 是文件夹; | 测试 path 是文件夹; | ||||
""" | """ | ||||
from fastNLP.core.log.logger import logger | |||||
path = Path(tempfile.mkdtemp()) | path = Path(tempfile.mkdtemp()) | ||||
try: | try: | ||||
handler = logger.add_file(path) | handler = logger.add_file(path) | ||||
@@ -285,16 +280,21 @@ class TestLogger: | |||||
assert self.msg in line | assert self.msg in line | ||||
finally: | finally: | ||||
synchronize_safe_rm(path) | synchronize_safe_rm(path) | ||||
logger.removeHandler(handler) | |||||
@recover_logger | |||||
def test_stdout(self, capsys): | def test_stdout(self, capsys): | ||||
from fastNLP.core.log.logger import logger | |||||
handler = logger.set_stdout(stdout="raw") | handler = logger.set_stdout(stdout="raw") | ||||
logger.info(self.msg) | logger.info(self.msg) | ||||
logger.debug('aabbc') | logger.debug('aabbc') | ||||
captured = capsys.readouterr() | captured = capsys.readouterr() | ||||
assert "some test log msg\n" == captured.out | assert "some test log msg\n" == captured.out | ||||
logger.removeHandler(handler) | |||||
@recover_logger
def test_warning_once(self, capsys):
    # ``warning_once`` must emit each distinct message exactly once,
    # regardless of how many times it is called with the same text.
    logger.warning_once('#')
    logger.warning_once('#')  # duplicate message: must be suppressed
    logger.warning_once('@')  # different message: must still appear
    captured = capsys.readouterr()
    assert captured.out.count('#') == 1
    assert captured.out.count('@') == 1
@@ -13,6 +13,7 @@ import numpy as np | |||||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | ||||
from fastNLP.core.drivers.utils import distributed_open_proc | from fastNLP.core.drivers.utils import distributed_open_proc | ||||
from fastNLP.core.log import logger | |||||
def get_class_that_defined_method(meth): | def get_class_that_defined_method(meth): | ||||
@@ -32,6 +33,20 @@ def get_class_that_defined_method(meth): | |||||
return getattr(meth, '__objclass__', None) # handle special descriptor objects | return getattr(meth, '__objclass__', None) # handle special descriptor objects | ||||
def recover_logger(fn):
    """Test decorator: snapshot the global ``logger``'s handlers and level
    before running *fn* and restore them afterwards.

    The restore happens in a ``finally`` block so that logger state does not
    leak between tests even when the wrapped test raises — the original
    implementation skipped the restore on exceptions, which is exactly the
    cross-test pollution this decorator exists to prevent.

    :param fn: test function to wrap.
    :return: wrapped function with identical signature and return value.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        # Snapshot the logger state (copy the handler list, not the reference).
        handlers = list(logger.handlers)
        level = logger.level
        try:
            return fn(*args, **kwargs)
        finally:
            logger.handlers = handlers
            logger.setLevel(level)
    return wrapper
def magic_argv_env_context(fn): | def magic_argv_env_context(fn): | ||||
@wraps(fn) | @wraps(fn) | ||||