From e8d11cd5a9ec53bd2c6da16e1c14cb025e489415 Mon Sep 17 00:00:00 2001
From: yh_cc
Date: Wed, 13 Apr 2022 12:55:28 +0800
Subject: [PATCH] 1. Fix the torch distributed helpers whose `group` parameter
 has a different default value across torch versions; 2. Fix the torch
 multi-GPU bug hit when evaluating a DataLoader that only has a
 batch_sampler; 3. Add a warning_once interface to the logger; 4. Add
 callback-related documentation
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fastNLP/core/callbacks/callback.py            | 142 ++++++++++++++++--
 fastNLP/core/callbacks/checkpoint_callback.py |  19 +--
 fastNLP/core/collators/collator.py            |   2 +-
 fastNLP/core/controllers/evaluator.py         |  19 +--
 .../controllers/loops/train_batch_loop.py     |   9 +-
 fastNLP/core/controllers/trainer.py           |   2 +-
 .../core/dataloaders/torch_dataloader/fdl.py  |   2 +-
 fastNLP/core/dataset/dataset.py               |   6 +-
 fastNLP/core/drivers/torch_driver/ddp.py      |   4 +-
 .../core/drivers/torch_driver/dist_utils.py   |  61 ++++----
 fastNLP/core/drivers/torch_driver/utils.py    |  13 +-
 fastNLP/core/log/logger.py                    |  16 ++
 .../samplers/reproducible_batch_sampler.py    |   4 +-
 fastNLP/core/samplers/reproducible_sampler.py |   4 +-
 fastNLP/core/samplers/unrepeated_sampler.py   |   4 +-
 fastNLP/core/utils/rich_progress.py           |  79 +++++++++-
 fastNLP/core/utils/torch_utils.py             |   6 +-
 fastNLP/core/utils/utils.py                   |   5 +-
 tests/core/log/test_logger.py                 |  48 +++---
 tests/helpers/utils.py                        |  15 ++
 20 files changed, 339 insertions(+), 121 deletions(-)

diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py
index 99e47dfe..96e4372b 100644
--- a/fastNLP/core/callbacks/callback.py
+++ b/fastNLP/core/callbacks/callback.py
@@ -32,100 +32,205 @@ class Callback:
     def on_sanity_check_end(self, trainer, sanity_check_res):
         r"""
         在 '预跑'检测 开始后会被触发;
+
+        :param trainer:
+        :param sanity_check_res: 预跑的 evaluate 结果
+        :return:
         """
         pass
 
     def on_train_begin(self, trainer):
         r"""
         在训练开始前会被触发;
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_train_end(self, trainer):
         r"""
         在训练完成后会被触发;
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_train_epoch_begin(self, trainer):
         r"""
         在训练过程中的每一个 epoch 开始前会被触发;
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_train_epoch_end(self, trainer):
         r"""
-        在训练过程中的每一个 epoch 完成后会被触发;
+        在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_fetch_data_begin(self, trainer):
         r"""
-        在训练过程中拿到当前的具体的一个 batch 前会被触发;
+        在训练过程中准备取出下一个 batch 的数据时会被触发;
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_fetch_data_end(self, trainer):
         r"""
-        在训练过程中拿到当前的具体的一个 batch 后会被触发;
+        在训练过程中拿到当前的 batch 数据后会被触发;
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_train_batch_begin(self, trainer, batch, indices):
         r"""
-        在训练过程中开始具体的一个 batch 前会被触发;
+        在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且将 batch 中的 tensor 移动到指定设备之后会被触发。
+        其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。
+        如果 batch 是 dict 类型,直接增删其中的 key 或修改其中的 value 会影响到输入到 model 中的 batch 数据。
 
         :param trainer: `fastNLP.Trainer`
-        :param batch: 当前正在运行的一个 batch;
-        :param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据;
+        :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及移动到指定设备。
+        :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据
         """
         pass
 
     def on_train_batch_end(self, trainer):
+        """
+        完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零,以及 batch_idx_in_epoch 与
+        global_forward_batches 累计加 1 操作。其中梯度更新、梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会
+        执行。
+
+        :param trainer:
+        :return:
+        """
         pass
 
     def on_exception(self, trainer, exception):
+        """
+        在训练过程遇到异常时调用。
+
+        :param trainer:
+        :param exception: 遭遇的异常。
+        :return:
+        """
        pass
 
     def on_save_model(self, trainer):
+        """
+        当将要保存模型时调用,此刻模型还未保存。
+
+        :param trainer:
+        :return:
+        """
        pass
 
     def on_load_model(self, trainer):
+        """
+        当将要加载模型时调用,此刻模型还未加载。
+
+        :param trainer:
+        :return:
+        """
        pass
 
     def on_save_checkpoint(self, trainer) -> Dict:
         """
-        当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常
-        工作的状态;而不应该让该函数去判断两个 callback 是否一样;
+        当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复时需要的相关数据。
+
+        :param trainer:
+        :return:
         """
         pass
 
     def on_load_checkpoint(self, trainer, states: Optional[Dict]):
         r"""
-        如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None;
+        当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint()
+        的返回值。
+
+        :param trainer:
+        :param states:
+        :return:
         """
         pass
 
     def on_before_backward(self, trainer, outputs):
+        """
+        在 backward 前执行。
+
+        :param trainer:
+        :param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。
+        :return:
+        """
        pass
 
     def on_after_backward(self, trainer):
+        """
+        在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步,
+        因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。
+
+        :param trainer:
+        :return:
+        """
        pass
 
     def on_before_optimizer_step(self, trainer, optimizers):
+        """
+        在进行 optimizer 优化前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。
+
+        :param trainer:
+        :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。
+        :return:
+        """
        pass
 
     def on_before_zero_grad(self, trainer, optimizers):
+        """
+        在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。
+
+        :param trainer:
+        :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。
+        :return:
+        """
        pass
 
     def on_validate_begin(self, trainer):
+        """
+        在将要进行 validate 时调用。如果是以 step 数量或自定义方式决定 validate 的频率,该接口是在 on_train_batch_end 之后
+        进行调用;如果是以 epoch 数量决定调用频率,该接口是在 on_train_epoch_end 之后调用。
+
+        :param trainer:
+        :return:
+        """
        pass
 
     def on_validate_end(self, trainer, results):
+        """
+        结束 validate 时调用,并把 validate 的结果传入。
+
+        :param trainer:
+        :param results:
+        :return:
+        """
        pass
 
     @property
     def callback_name(self):
+        """
+        callback 的名称,我们会使用该名称从 checkpoint 中读取相应的 state 并传递给 on_load_checkpoint() 函数。
+
+        :return:
+        """
         return self.__class__.__name__
 
 
@@ -226,10 +331,21 @@ class HasMonitorCallback(Callback):
         :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。
         :return:
         """
+        better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
+        if keep_if_better and better:
+            self.monitor_value = monitor_value
+        return better
+
+    def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
+        """
+        传入的两个值中,是否 monitor_value1 的结果更好。
+
+        :param monitor_value1:
+        :param monitor_value2:
+        :return:
+        """
         better = False
-        if (self.larger_better and monitor_value > self.monitor_value) or \
-                (not self.larger_better and monitor_value < self.monitor_value):
+        if (self.larger_better and monitor_value1 > monitor_value2) or \
+                (not self.larger_better and monitor_value1 < monitor_value2):
             better = True
-        if keep_if_better:
-            self.monitor_value = monitor_value
         return better
\ No newline at end of file
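Note on the HasMonitorCallback change above: the comparison is factored out of the "compare and keep" method so that subclasses (e.g. CheckpointCallback in the next file) can rank arbitrary value pairs without touching self.monitor_value. A minimal standalone sketch of that contract — the class and method name get_and_keep are stubs for illustration, not fastNLP's actual base class:

    class MonitorSketch:
        """Stub mirroring the two methods added in the hunk above."""
        def __init__(self, larger_better=True):
            self.larger_better = larger_better
            self.monitor_value = float('-inf') if larger_better else float('inf')

        def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
            # pure comparison, no side effects on the stored best value
            return (self.larger_better and monitor_value1 > monitor_value2) or \
                   (not self.larger_better and monitor_value1 < monitor_value2)

        def get_and_keep(self, monitor_value, keep_if_better=True):
            better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
            if keep_if_better and better:
                self.monitor_value = monitor_value
            return better

    m = MonitorSketch(larger_better=True)
    assert m.get_and_keep(0.8) is True    # first value always improves on -inf
    assert m.get_and_keep(0.7) is False   # worse value is reported and not kept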
diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py
index 839a9522..82bfe404 100644
--- a/fastNLP/core/callbacks/checkpoint_callback.py
+++ b/fastNLP/core/callbacks/checkpoint_callback.py
@@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
 from fastNLP.core.log import logger
 from fastNLP.envs import FASTNLP_LAUNCH_TIME
 from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
-from fastNLP.core.utils import apply_to_collection
 
 
 class CheckpointCallback(HasMonitorCallback):
@@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback):
             else:
                 _least_valuable_model = (min if self.larger_better else max)(self._topk_model,
                                                                              key=lambda x: self._topk_model[x])
-                if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
-                        (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
+                if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
                     self._topk_model[folder_name] = monitor_value
                     _should_save = True
                     self._topk_model.pop(_least_valuable_model)
@@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback):
             **self.kwargs
         )
 
-    def _get_validate_metric(self, res: Dict):
-        """
-        该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor);
-        如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值;
-        :param res:
-        :return:
-        """
-        use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
-        if self._real_monitor != use_monitor:
-            logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
-                           f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
-        self._real_monitor = use_monitor
-
-        return value
-
     @property
     def folder_prefix(self):
         raise NotImplementedError("The `folder_prefix` is not specified")
diff --git a/fastNLP/core/collators/collator.py b/fastNLP/core/collators/collator.py
index 78b07751..f468dd4c 100644
--- a/fastNLP/core/collators/collator.py
+++ b/fastNLP/core/collators/collator.py
@@ -197,7 +197,7 @@ class _MultiCollator:
                 collator.set_input(*field_names)
                 flag = False
         if flag:
-            warnings.warn("AutoCollator is remove, set_input is unavailable!!")
+            warnings.warn("AutoCollator is removed, set_input is unavailable!!")
         return self
 
 
diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py
index b193f877..479686e1 100644
--- a/fastNLP/core/controllers/evaluator.py
+++ b/fastNLP/core/controllers/evaluator.py
@@ -223,7 +223,6 @@ class Evaluator:
     def remove_progress_bar(self, dataloader_name):
         if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
             f_rich_progress.destroy_task(self._rich_task_id)
-            f_rich_progress.refresh()  # 使得最终的bar可以消失
             delattr(self, '_rich_task_id')
         elif self.progress_bar == 'raw':
             desc = 'Evaluation ends'
@@ -234,7 +233,6 @@ class Evaluator:
     def finally_progress_bar(self):
         if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
             f_rich_progress.destroy_task(self._rich_task_id)
-            f_rich_progress.refresh()
             delattr(self, '_rich_task_id')
 
     @property
@@ -359,20 +357,23 @@ class _MetricsWrapper:
         if is_dataclass(outputs):
             outputs = dataclass_to_dict(outputs)
         for metric in self._metrics:
+            args = []
             if not isinstance(batch, dict):
-                raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly"
-                                   f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.")
+                logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
+                                    f"the output of the model to update metrics.")
+            else:
+                args.append(batch)
             if not isinstance(outputs, dict):
-                raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly"
+                raise RuntimeError(f"The output of your model is of type:`{type(outputs)}`, please either directly"
                                    f" return a dict from your model or use `output_mapping` to convert it into dict type.")
             if isinstance(metric, Metric):
-                auto_param_call(metric.update, batch, outputs)
+                auto_param_call(metric.update, outputs, *args)
             elif _is_torchmetrics_metric(metric):
-                auto_param_call(metric.update, batch, outputs)
+                auto_param_call(metric.update, outputs, *args)
             elif _is_allennlp_metric(metric):
-                auto_param_call(metric.__call__, batch, outputs)
+                auto_param_call(metric.__call__, outputs, *args)
             elif _is_paddle_metric(metric):
-                res = auto_param_call(metric.compute, batch, outputs)
+                res = auto_param_call(metric.compute, outputs, *args)
                 metric.update(res)
 
     def reset(self):
diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py
index 5d127359..a3219e6d 100644
--- a/fastNLP/core/controllers/loops/train_batch_loop.py
+++ b/fastNLP/core/controllers/loops/train_batch_loop.py
@@ -7,6 +7,7 @@ from typing import Optional, Callable
 from .loop import Loop
 from fastNLP.core.log import logger
 from fastNLP.core.utils import match_and_substitute_params
+from fastNLP.core.utils.exceptions import EarlyStopException
 
 
 class TrainBatchLoop(Loop):
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop):
             try:
                 trainer.on_fetch_data_begin()
                 batch = next(dataloader)
-                batch = match_and_substitute_params(trainer.input_mapping, batch)
                 indices = get_batch_indices()
-                batch = trainer.move_data_to_device(batch)
                 trainer.on_fetch_data_end()
+                batch = match_and_substitute_params(trainer.input_mapping, batch)
+                batch = trainer.move_data_to_device(batch)
             except StopIteration:
                 break
-            except BaseException as e:  # TODO 把这里的信息写入进去
+            except EarlyStopException:  # 在 Trainer 处理 earlystop 的 exception
+                break
+            except BaseException as e:
                 if indices:
                     logger.debug(f"The following exception happens when running on samples: {indices}")
                 raise e
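Note: the new `except EarlyStopException` branch above lets a callback end training cleanly instead of crashing the batch loop. A minimal sketch of a callback that triggers it — PatienceCallback and its threshold logic are hypothetical; only the EarlyStopException import path comes from this patch:

    from fastNLP.core.utils.exceptions import EarlyStopException
    from fastNLP.core.callbacks.callback import Callback

    class PatienceCallback(Callback):  # hypothetical example
        def __init__(self, patience):
            self.patience = patience
            self.bad_validates = 0

        def on_validate_end(self, trainer, results):
            # a real callback would compare a monitored metric here before
            # deciding that this validate round was "bad"
            self.bad_validates += 1
            if self.bad_validates > self.patience:
                # caught by TrainBatchLoop, which breaks out of the epoch
                raise EarlyStopException('patience exhausted')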
diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py
index 6d154770..5daee856 100644
--- a/fastNLP/core/controllers/trainer.py
+++ b/fastNLP/core/controllers/trainer.py
@@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger):
         self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')
 
         # 5. 恢复所有 callback 的状态;
-        self.on_load_checkpoint(states["callback_states"])
+        self.on_load_checkpoint(states["callback_states"])
 
         self.driver.barrier()
 
diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
index d56dbac9..13eae93c 100644
--- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py
+++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py
@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader):
                  pin_memory: bool = False, drop_last: bool = False,
                  timeout: float = 0, worker_init_fn: Optional[Callable] = None,
                  multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
-                 persistent_workers: bool = False, as_numpy: bool = False) -> None:
+                 persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
         """
 
         :param dataset: 实现了__getitem__和__len__的数据容器
diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py
index 9630a3a0..5b8ec635 100644
--- a/fastNLP/core/dataset/dataset.py
+++ b/fastNLP/core/dataset/dataset.py
@@ -788,13 +788,14 @@ class DataSet:
 
     def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
         """
-        设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效
+        设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效
         当val=None时,意味着给定的field_names都不需要尝试padding
 
         :param field_names: dataset存在的field_name
-        :param val: 默认为0
+        :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。
         :return:
         """
+        # TODO 需要去重复
         for field_name in field_names:
             self.collate_fns.set_pad_val(field_name, val=val)
 
@@ -805,6 +806,7 @@ class DataSet:
         :param field_names:
         :return:
         """
+        #
         self.collate_fns.set_input(*field_names)
 
     def get_collator(self) -> _MultiCollator:
diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py
index 4cf207cd..3537d0b3 100644
--- a/fastNLP/core/drivers/torch_driver/ddp.py
+++ b/fastNLP/core/drivers/torch_driver/ddp.py
@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH:
     import torch
     import torch.distributed as dist
     from torch.nn.parallel import DistributedDataParallel
+    from torch.utils.data import BatchSampler
 
 __all__ = [
     'TorchDDPDriver'
@@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver):
                 num_replicas=self.world_size,
                 rank=self.global_rank
             )
-            return replace_sampler(dataloader, sampler)
+            batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
+            return replace_batch_sampler(dataloader, batch_sampler)
         else:
             raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
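Note on the ddp.py hunk above (item 2 of the commit message): when dist_sampler == 'unrepeatdist', replacing only the sampler loses the batch_size/drop_last settings of a DataLoader that was built around an explicit batch_sampler, which broke multi-GPU evaluate. Wrapping the new sampler in a fresh BatchSampler and swapping the whole batch_sampler keeps those settings. A standalone torch-only sketch of the idea, with SequentialSampler standing in for fastNLP's UnrepeatedSampler:

    import torch
    from torch.utils.data import (BatchSampler, DataLoader, SequentialSampler,
                                  TensorDataset)

    dataset = TensorDataset(torch.arange(10).float())
    # the problematic case: a DataLoader driven by an explicit batch_sampler
    loader = DataLoader(dataset, batch_sampler=BatchSampler(
        SequentialSampler(dataset), batch_size=4, drop_last=False))

    # the fix: build a new BatchSampler around the replacement sampler so
    # batch_size/drop_last survive, then swap the batch_sampler wholesale
    new_batch_sampler = BatchSampler(SequentialSampler(dataset),
                                     batch_size=4, drop_last=False)
    new_loader = DataLoader(dataset, batch_sampler=new_batch_sampler)
    assert len(list(new_loader)) == 3  # 10 samples -> batches of 4, 4, 2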
diff --git a/fastNLP/core/drivers/torch_driver/dist_utils.py b/fastNLP/core/drivers/torch_driver/dist_utils.py
index 5e3819e7..ad9e6794 100644
--- a/fastNLP/core/drivers/torch_driver/dist_utils.py
+++ b/fastNLP/core/drivers/torch_driver/dist_utils.py
@@ -3,28 +3,20 @@ import pickle
 _pickler = pickle.Pickler
 _unpickler = pickle.Unpickler
 from typing import Any, List
-from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
-
+from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
+from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
 if _NEED_IMPORT_TORCH:
     import torch
     from torch import distributed as dist
-    try:
-        from torch._C._distributed_c10d import ProcessGroupMPI
-    except ImportError:
-        _MPI_AVAILABLE = False
-
-    try:
-        from torch._C._distributed_c10d import ProcessGroupNCCL
-    except ImportError:
-        _NCCL_AVAILABLE = False
-
-    try:
-        from torch._C._distributed_c10d import ProcessGroupGloo
-        from torch._C._distributed_c10d import _ProcessGroupWrapper
-    except ImportError:
-        _GLOO_AVAILABLE = False
+    if _TORCH_GREATER_EQUAL_1_8:
+        try:
+            from torch._C._distributed_c10d import ProcessGroupGloo
+            from torch._C._distributed_c10d import _ProcessGroupWrapper
+        except ImportError:
+            pass
 
 from fastNLP.core.utils import apply_to_collection
 
@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
         )
 
 
-def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None):
+def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
     """
     从其它 rank gather 东西到 dst rank 。
 
@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None)
         >>> output
         ['foo', 12, {1: 2}]
     """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
+
     if dist.distributed_c10d._rank_not_in_group(group):
         return
 
@@ -193,7 +188,7 @@ def _to_device(tensor, device):
     return tensor.contiguous().to(device)
 
 
-def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
+def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
     """
     实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。
 
@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
     :param group:
     :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。
     """
-    # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
     if isinstance(obj, torch.Tensor):
         objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
         dist.all_gather(objs, obj, group=group)
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
     return objs
 
 
-def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
+def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
     """
     将 src 上的 obj 对象广播到其它 rank 上。
 
@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
     :param group:
     :return:
     """
+    if group is None:
+        group = DEFAULT_TORCH_GROUP
     cur_rank = dist.get_rank(group)
     if cur_rank == src:
         # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里
@@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None):
         return
 
     input_tensor, local_size = _object_to_tensor(obj)
-    current_device = torch.device("cpu")
-    is_nccl_backend = _check_for_nccl_backend(group)
-    if is_nccl_backend:
-        # See note about using torch.cuda.current_device() here in docstring.
-        # We cannot simply use my_rank since rank == device is not necessarily
-        # true.
-        current_device = torch.device("cuda", torch.cuda.current_device())
-        input_tensor = input_tensor.to(current_device)
-        local_size = local_size.to(current_device)
+    if _TORCH_GREATER_EQUAL_1_8:
+        current_device = torch.device("cpu")
+        is_nccl_backend = _check_for_nccl_backend(group)
+        if is_nccl_backend:
+            # See note about using torch.cuda.current_device() here in docstring.
+            # We cannot simply use my_rank since rank == device is not necessarily
+            # true.
+            current_device = torch.device("cuda", torch.cuda.current_device())
+            input_tensor = input_tensor.to(current_device)
+            local_size = local_size.to(current_device)
+    else:
+        current_device = torch.cuda.current_device()
     # Gather all local sizes. This is so that we can find the max size, and index
     # until the correct size when deserializing the tensors.
     group_size = dist.get_world_size(group=group)
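Note on the dist_utils.py changes above (item 1 of the commit message): torch >= 1.8 treats group=None as "use the default process group", while older torch defaulted to group.WORLD, so a bare None default behaves differently across versions. The patch pins a version-dependent DEFAULT_TORCH_GROUP and normalizes None late inside each helper. A minimal sketch of that shim — the inline version parse here stands in for fastNLP's _TORCH_GREATER_EQUAL_1_8 flag and is an assumption, not the library's code:

    import torch

    _TORCH_GE_1_8 = tuple(int(v) for v in torch.__version__.split('.')[:2]) >= (1, 8)

    # newer torch accepts None; older torch needs the explicit WORLD group
    DEFAULT_TORCH_GROUP = None
    if not _TORCH_GE_1_8:
        DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD

    def normalized_group(group=DEFAULT_TORCH_GROUP):
        # mirrors the `if group is None: group = DEFAULT_TORCH_GROUP` guard
        # at the top of each helper in the diff
        return DEFAULT_TORCH_GROUP if group is None else group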
diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py
index 4210dac5..cdc6cea9 100644
--- a/fastNLP/core/drivers/torch_driver/utils.py
+++ b/fastNLP/core/drivers/torch_driver/utils.py
@@ -8,6 +8,7 @@ import numpy as np
 import inspect
 
 from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+from fastNLP.core.samplers import re_instantiate_sampler
 
 if _NEED_IMPORT_TORCH:
     import torch
@@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler):
             "manually add the `DistributedSampler` as: "
             f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
         )
-
     return type(dataloader)(**reconstruct_args)
 
 
@@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler(
     """
     batch_sampler = getattr(dataloader, "batch_sampler")
     # checking the batch sampler type is different than PyTorch default.
-    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
-        batch_sampler = type(batch_sampler)(
-            sampler,
-            batch_size=batch_sampler.batch_size,
-            drop_last=batch_sampler.drop_last,
-        )
+    if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
+        batch_sampler = re_instantiate_sampler(batch_sampler)
 
         return {
             "sampler": None,
@@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
     params = {k: getattr(dataloader, k) for k in params_keys}
     params["batch_sampler"] = new_batch_sampler
     return type(dataloader)(**params)
+    # TODO 这里是否可以auto_param_call一下
+    # return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()},
+    #                        signature_fn=type(dataloader).__init__)
 
 
 def optimizer_state_to_device(state, device):
diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py
index ae89ad3f..9763ab4a 100644
--- a/fastNLP/core/log/logger.py
+++ b/fastNLP/core/log/logger.py
@@ -51,6 +51,7 @@ class LoggerSingleton(type):
 class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
     def __init__(self, name):
         super().__init__(name)
+        self._warning_msgs = set()
 
     def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False,
                  mode: str = "w"):
@@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
             kwargs = self._add_rank_info(kwargs)
             self._log(WARNING, msg, args, **kwargs)
 
+    def warning_once(self, msg, *args, **kwargs):
+        """
+        相同的 warning 内容只会 warning 一次
+
+        :param msg:
+        :param args:
+        :param kwargs:
+        :return:
+        """
+        if msg not in self._warning_msgs:
+            if self.isEnabledFor(WARNING):
+                kwargs = self._add_rank_info(kwargs)
+                self._log(WARNING, msg, args, **kwargs)
+            self._warning_msgs.add(msg)
+
     def warn(self, msg, *args, **kwargs):
         warnings.warn("The 'warn' method is deprecated, "
                       "use 'warning' instead", DeprecationWarning, 2)
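Note: warning_once (item 3 of the commit message) deduplicates on the exact message string, remembered in the _warning_msgs set; messages that interpolate changing values (e.g. f-strings containing a step counter) therefore bypass the deduplication. Usage is the same as logger.warning:

    from fastNLP.core.log import logger

    for step in range(100):
        # printed a single time across all 100 iterations
        logger.warning_once('batch is not a dict; metrics will only see the model output')
        # NOT deduplicated: each step would produce a new, distinct string
        # logger.warning_once(f'step {step} is slow')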
diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py
index c4116e24..d1041f08 100644
--- a/fastNLP/core/samplers/reproducible_batch_sampler.py
+++ b/fastNLP/core/samplers/reproducible_batch_sampler.py
@@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
         :param kwargs: fastNLP 保留使用
         """
         super().__init__()
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py
index 1dc226a5..f48e2fc6 100644
--- a/fastNLP/core/samplers/reproducible_sampler.py
+++ b/fastNLP/core/samplers/reproducible_sampler.py
@@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler):
         :param kwargs: fastNLP 保留使用
         """
         super().__init__(dataset=dataset, **kwargs)
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
diff --git a/fastNLP/core/samplers/unrepeated_sampler.py b/fastNLP/core/samplers/unrepeated_sampler.py
index d7913d20..02ec1162 100644
--- a/fastNLP/core/samplers/unrepeated_sampler.py
+++ b/fastNLP/core/samplers/unrepeated_sampler.py
@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
         :param kwargs: fastNLP 保留使用
         """
         super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
-        if isinstance(dataset, DataSet):
-            length = dataset.get_field(length)
+        if isinstance(dataset, DataSet) and isinstance(length, str):
+            length = dataset.get_field(length).content
             if not isinstance(length[0], int):
                 length = list(map(len, length))
         else:
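Note on the three sampler hunks above: `length` may be either the name of a DataSet field (a str) or a precomputed list of lengths; the old code called dataset.get_field() unconditionally and iterated the FieldArray itself rather than its .content list. A sketch of the two call styles — it assumes SortedSampler is importable from fastNLP.core.samplers with the (dataset, length) signature shown in the hunk:

    from fastNLP.core.dataset import DataSet
    from fastNLP.core.samplers import SortedSampler

    ds = DataSet({'words': [['a'], ['a', 'b', 'c'], ['a', 'b']]})

    # 1) field name: lengths come from ds.get_field('words').content, and the
    #    non-int token lists are mapped through len()
    by_field = SortedSampler(ds, length='words')

    # 2) explicit lengths: used as-is, no DataSet lookup
    by_list = SortedSampler(ds, length=[1, 3, 2])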
+ """ + with self._lock: + task = self._tasks[task_id] + completed_start = task.completed + + if total is not None and total != task.total: + task.total = total + task._reset() + if advance is not None: + task.completed += advance + if completed is not None: + task.completed = completed + if description is not None: + task.description = description + if visible is not None: + task.visible = visible + task.fields.update(fields) + update_completed = task.completed - completed_start + + current_time = self.get_time() + old_sample_time = current_time - self.speed_estimate_period + _progress = task._progress + + popleft = _progress.popleft + # 这里修改为至少保留一个,防止超长时间的迭代影响判断 + while len(_progress)>1 and _progress[0].timestamp < old_sample_time: + popleft() + if update_completed > 0: + _progress.append(ProgressSample(current_time, update_completed)) + if task.completed >= task.total and task.finished_time is None: + task.finished_time = task.elapsed + + if refresh: + self.refresh() + + +class SpeedColumn(ProgressColumn): + """ + 显示 task 的速度。 + + """ + def render(self, task: "Task"): + speed = task.speed + if speed is None: + return Text('-- it./s', style='progress.data.speed') + if speed > 0.1: + return Text(str(round(speed, 2))+' it./s', style='progress.data.speed') + else: + return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') + if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", "[progress.percentage]{task.percentage:>3.0f}%", BarColumn(), + SpeedColumn(), TimeElapsedColumn(), "/", TimeRemainingColumn(), TextColumn("{task.fields[post_desc]}", justify="right"), transient=True, disable=False, - speed_estimate_period=1 + speed_estimate_period=30 ) else: f_rich_progress = DummyFRichProgress() diff --git a/fastNLP/core/utils/torch_utils.py b/fastNLP/core/utils/torch_utils.py index 9dea93dd..2dfc0802 100644 --- a/fastNLP/core/utils/torch_utils.py +++ b/fastNLP/core/utils/torch_utils.py @@ -1,9 +1,11 @@ from abc import ABC from typing import Any, Union, Optional -from fastNLP.envs.imports import _NEED_IMPORT_TORCH - +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 +DEFAULT_TORCH_GROUP = None if _NEED_IMPORT_TORCH: import torch + if not _TORCH_GREATER_EQUAL_1_8: + DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD __all__ = [ 'torch_move_data_to_device' diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 46211581..c402fe11 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: r""" - 1.该函数用来提供给用户根据字符串匹配从而实现自动计算; + 该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping + 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 + + 1.该函数用来提供给用户根据字符串匹配从而实现自动调用; 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger.py index da9b7b6b..4fe49bef 100644 --- a/tests/core/log/test_logger.py +++ b/tests/core/log/test_logger.py @@ -6,13 +6,16 @@ import logging import re from fastNLP.envs.env import 
diff --git a/tests/core/log/test_logger.py b/tests/core/log/test_logger.py
index da9b7b6b..4fe49bef 100644
--- a/tests/core/log/test_logger.py
+++ b/tests/core/log/test_logger.py
@@ -6,13 +6,16 @@ import logging
 import re
 
 from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
-from tests.helpers.utils import magic_argv_env_context
 from fastNLP.core import synchronize_safe_rm
+from fastNLP.core.log.logger import logger
+
+from tests.helpers.utils import magic_argv_env_context, recover_logger
 
 
 # 测试 TorchDDPDriver;
 @magic_argv_env_context
-def test_add_file_ddp_1():
+@recover_logger
+def test_add_file_ddp_1_torch():
     """
     测试 path 是一个文件的地址,但是这个文件所在的文件夹存在;
 
@@ -56,11 +59,11 @@ def test_add_file_ddp_1():
             synchronize_safe_rm(filepath)
             dist.barrier()
             dist.destroy_process_group()
-            logger.removeHandler(handler)
 
 
 @magic_argv_env_context
-def test_add_file_ddp_2():
+@recover_logger
+def test_add_file_ddp_2_torch():
     """
     测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在;
     """
@@ -103,14 +106,14 @@ def test_add_file_ddp_2():
                     assert len(pattern.findall(line)) == 1
     finally:
         synchronize_safe_rm(path)
-        logger.removeHandler(handler)
         dist.barrier()
         dist.destroy_process_group()
 
 
 @magic_argv_env_context
-def test_add_file_ddp_3():
+@recover_logger
+def test_add_file_ddp_3_torch():
     """
     path = None;
 
@@ -155,10 +158,10 @@ def test_add_file_ddp_3():
         synchronize_safe_rm(file)
         dist.barrier()
         dist.destroy_process_group()
-        logger.removeHandler(handler)
 
 @magic_argv_env_context
-def test_add_file_ddp_4():
+@recover_logger
+def test_add_file_ddp_4_torch():
     """
     测试 path 是文件夹;
     """
@@ -200,7 +203,6 @@ def test_add_file_ddp_4():
                     assert len(pattern.findall(line)) == 1
     finally:
         synchronize_safe_rm(path)
-        logger.removeHandler(handler)
         dist.barrier()
         dist.destroy_process_group()
 
@@ -209,12 +211,11 @@ def test_add_file_ddp_4():
 class TestLogger:
     msg = 'some test log msg'
 
+    @recover_logger
     def test_add_file_1(self):
         """
         测试 path 是一个文件的地址,但是这个文件所在的文件夹存在;
         """
-        from fastNLP.core.log.logger import logger
-
         path = Path(tempfile.mkdtemp())
         try:
             filepath = path.joinpath('log.txt')
@@ -225,14 +226,12 @@ class TestLogger:
                 assert self.msg in line
         finally:
             synchronize_safe_rm(path)
-            logger.removeHandler(handler)
 
+    @recover_logger
     def test_add_file_2(self):
         """
         测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在;
         """
-        from fastNLP.core.log.logger import logger
-
         origin_path = Path(tempfile.mkdtemp())
 
         try:
@@ -245,14 +244,12 @@ class TestLogger:
                 assert self.msg in line
         finally:
             synchronize_safe_rm(origin_path)
-            logger.removeHandler(handler)
 
+    @recover_logger
     def test_add_file_3(self):
         """
         测试 path 是 None;
         """
-        from fastNLP.core.log.logger import logger
-
         handler = logger.add_file()
         logger.info(self.msg)
 
@@ -264,14 +261,12 @@ class TestLogger:
             line = ''.join([l for l in f])
             assert self.msg in line
         file.unlink()
-        logger.removeHandler(handler)
 
+    @recover_logger
     def test_add_file_4(self):
         """
         测试 path 是文件夹;
         """
-        from fastNLP.core.log.logger import logger
-
         path = Path(tempfile.mkdtemp())
         try:
             handler = logger.add_file(path)
@@ -285,16 +280,21 @@ class TestLogger:
                 assert self.msg in line
         finally:
             synchronize_safe_rm(path)
-            logger.removeHandler(handler)
 
+    @recover_logger
     def test_stdout(self, capsys):
-        from fastNLP.core.log.logger import logger
-
         handler = logger.set_stdout(stdout="raw")
         logger.info(self.msg)
         logger.debug('aabbc')
         captured = capsys.readouterr()
         assert "some test log msg\n" == captured.out
-        logger.removeHandler(handler)
 
+    @recover_logger
+    def test_warning_once(self, capsys):
+        logger.warning_once('#')
+        logger.warning_once('#')
+        logger.warning_once('@')
+        captured = capsys.readouterr()
+        assert captured.out.count('#') == 1
+        assert captured.out.count('@') == 1
diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py
index f4effc1f..b876c289 100644
--- a/tests/helpers/utils.py
+++ b/tests/helpers/utils.py
@@ -13,6 +13,7 @@ import numpy as np
 
 from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
 from fastNLP.core.drivers.utils import distributed_open_proc
+from fastNLP.core.log import logger
 
 
 def get_class_that_defined_method(meth):
@@ -32,6 +33,20 @@ def get_class_that_defined_method(meth):
     return getattr(meth, '__objclass__', None)  # handle special descriptor objects
 
 
+def recover_logger(fn):
+    @wraps(fn)
+    def wrapper(*args, **kwargs):
+        # 保存logger的状态
+        handlers = [handler for handler in logger.handlers]
+        level = logger.level
+        res = fn(*args, **kwargs)
+        logger.handlers = handlers
+        logger.setLevel(level)
+        return res
+
+    return wrapper
+
+
 def magic_argv_env_context(fn):
     @wraps(fn)
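Note: recover_logger snapshots the singleton logger's handlers and level before a test and restores them afterwards, which replaces the per-test logger.removeHandler(handler) bookkeeping removed from test_logger.py above. A usage sketch (the test body is illustrative):

    from fastNLP.core.log import logger
    from tests.helpers.utils import recover_logger

    @recover_logger
    def test_something(tmp_path):
        logger.add_file(tmp_path / 'log.txt')  # handler lives only in this test
        logger.info('hello')
    # afterwards, logger.handlers and logger.level are back to their old values

One caveat of the design as written: the wrapper restores state only on a normal return, so a test that raises skips the restore; wrapping the fn(*args, **kwargs) call in try/finally would also cover failing tests.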