| @@ -32,100 +32,205 @@ class Callback: | |||
| def on_sanity_check_end(self, trainer, sanity_check_res): | |||
| r""" | |||
| 在 '预跑'检测 开始后会被触发; | |||
| :param trainer: | |||
| :param sanity_check_res: 预跑的 evaluate 结果 | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_train_begin(self, trainer): | |||
| r""" | |||
| 在训练开始前会被触发; | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_train_end(self, trainer): | |||
| r""" | |||
| 在训练完成后会被触发; | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_train_epoch_begin(self, trainer): | |||
| r""" | |||
| 在训练过程中的每一个 epoch 开始前会被触发; | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_train_epoch_end(self, trainer): | |||
| r""" | |||
| 在训练过程中的每一个 epoch 完成后会被触发; | |||
| 在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_fetch_data_begin(self, trainer): | |||
| r""" | |||
| 在训练过程中拿到当前的具体的一个 batch 前会被触发; | |||
| 在训练过程中准备取出下一个 batch 的数据时触发 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_fetch_data_end(self, trainer): | |||
| r""" | |||
| 在训练过程中拿到当前的具体的一个 batch 后会被触发; | |||
| 在训练过程中拿到当前的 batch 数据后会被触发; | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_train_batch_begin(self, trainer, batch, indices): | |||
| r""" | |||
| 在训练过程中开始具体的一个 batch 前会被触发; | |||
| 在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||
| 其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||
| 如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||
| :param trainer: `fastNLP.Trainer` | |||
| :param batch: 当前正在运行的一个 batch; | |||
| :param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据; | |||
| :param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
| :param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 | |||
| """ | |||
| pass | |||
| def on_train_batch_end(self, trainer): | |||
| """ | |||
| 完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||
| global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||
| 执行。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_exception(self, trainer, exception): | |||
| """ | |||
| 在训练过程遇到异常时调用。 | |||
| :param trainer: | |||
| :param exception: 遭遇的异常。 | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_save_model(self, trainer): | |||
| """ | |||
| 当将要保存模型时调用,此刻模型还未保存。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_load_model(self, trainer): | |||
| """ | |||
| 当将要加载模型时调用,此刻模型还未加载。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_save_checkpoint(self, trainer) -> Dict: | |||
| """ | |||
| 当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常 | |||
| 工作的状态;而不应该让该函数去判断两个 callback 是否一样; | |||
| 当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
| r""" | |||
| 如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None; | |||
| 当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||
| 的返回值。 | |||
| :param trainer: | |||
| :param states: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_before_backward(self, trainer, outputs): | |||
| """ | |||
| 在 backward 前执行。 | |||
| :param trainer: | |||
| :param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_after_backward(self, trainer): | |||
| """ | |||
| 在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||
| 因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_before_optimizer_step(self, trainer, optimizers): | |||
| """ | |||
| 在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
| :param trainer: | |||
| :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_before_zero_grad(self, trainer, optimizers): | |||
| """ | |||
| 在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
| :param trainer: | |||
| :param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_validate_begin(self, trainer): | |||
| """ | |||
| 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | |||
| 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||
| :param trainer: | |||
| :return: | |||
| """ | |||
| pass | |||
| def on_validate_end(self, trainer, results): | |||
| """ | |||
| 结束 validate 时调用,并把 validate 的结果传入。 | |||
| :param trainer: | |||
| :param results: | |||
| :return: | |||
| """ | |||
| pass | |||
| @property | |||
| def callback_name(self): | |||
| """ | |||
| callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 | |||
| :return: | |||
| """ | |||
| return self.__class__.__name__ | |||
| @@ -226,10 +331,21 @@ class HasMonitorCallback(Callback): | |||
| :param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||
| :return: | |||
| """ | |||
| better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||
| if keep_if_better and better: | |||
| self.monitor_value = monitor_value | |||
| return better | |||
| def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | |||
| """ | |||
| 传入的两个值中,是否monitor_value1的结果更好。 | |||
| :param monitor_value1: | |||
| :param monitor_value2: | |||
| :return: | |||
| """ | |||
| better = False | |||
| if (self.larger_better and monitor_value > self.monitor_value) or \ | |||
| (not self.larger_better and monitor_value < self.monitor_value): | |||
| if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
| (not self.larger_better and monitor_value1 < monitor_value2): | |||
| better = True | |||
| if keep_if_better: | |||
| self.monitor_value = monitor_value | |||
| return better | |||
| @@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
| from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
| from fastNLP.core.utils import apply_to_collection | |||
| class CheckpointCallback(HasMonitorCallback): | |||
| @@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
| else: | |||
| _least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||
| key=lambda x: self._topk_model[x]) | |||
| if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||
| (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||
| if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||
| self._topk_model[folder_name] = monitor_value | |||
| _should_save = True | |||
| self._topk_model.pop(_least_valuable_model) | |||
| @@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback): | |||
| **self.kwargs | |||
| ) | |||
| def _get_validate_metric(self, res: Dict): | |||
| """ | |||
| 该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor); | |||
| 如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值; | |||
| :param res: | |||
| :return: | |||
| """ | |||
| use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | |||
| if self._real_monitor != use_monitor: | |||
| logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||
| f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||
| self._real_monitor = use_monitor | |||
| return value | |||
| @property | |||
| def folder_prefix(self): | |||
| raise NotImplementedError("The `folder_prefix` is not specified") | |||
| @@ -197,7 +197,7 @@ class _MultiCollator: | |||
| collator.set_input(*field_names) | |||
| flag = False | |||
| if flag: | |||
| warnings.warn("AutoCollator is remove, set_input is unavailable!!") | |||
| warnings.warn("AutoCollator is removed, set_input is unavailable!!") | |||
| return self | |||
| @@ -223,7 +223,6 @@ class Evaluator: | |||
| def remove_progress_bar(self, dataloader_name): | |||
| if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
| f_rich_progress.destroy_task(self._rich_task_id) | |||
| f_rich_progress.refresh() # 使得最终的bar可以消失 | |||
| delattr(self, '_rich_task_id') | |||
| elif self.progress_bar == 'raw': | |||
| desc = 'Evaluation ends' | |||
| @@ -234,7 +233,6 @@ class Evaluator: | |||
| def finally_progress_bar(self): | |||
| if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
| f_rich_progress.destroy_task(self._rich_task_id) | |||
| f_rich_progress.refresh() | |||
| delattr(self, '_rich_task_id') | |||
| @property | |||
| @@ -359,20 +357,23 @@ class _MetricsWrapper: | |||
| if is_dataclass(outputs): | |||
| outputs = dataclass_to_dict(outputs) | |||
| for metric in self._metrics: | |||
| args = [] | |||
| if not isinstance(batch, dict): | |||
| raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly" | |||
| f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.") | |||
| logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||
| f"the output of model to update metric.") | |||
| else: | |||
| args.append(batch) | |||
| if not isinstance(outputs, dict): | |||
| raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly" | |||
| raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||
| f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
| if isinstance(metric, Metric): | |||
| auto_param_call(metric.update, batch, outputs) | |||
| auto_param_call(metric.update, batch, *args) | |||
| elif _is_torchmetrics_metric(metric): | |||
| auto_param_call(metric.update, batch, outputs) | |||
| auto_param_call(metric.update, batch, *args) | |||
| elif _is_allennlp_metric(metric): | |||
| auto_param_call(metric.__call__, batch, outputs) | |||
| auto_param_call(metric.__call__, batch, *args) | |||
| elif _is_paddle_metric(metric): | |||
| res = auto_param_call(metric.compute, batch, outputs) | |||
| res = auto_param_call(metric.compute, batch, *args) | |||
| metric.update(res) | |||
| def reset(self): | |||
| @@ -7,6 +7,7 @@ from typing import Optional, Callable | |||
| from .loop import Loop | |||
| from fastNLP.core.log import logger | |||
| from fastNLP.core.utils import match_and_substitute_params | |||
| from fastNLP.core.utils.exceptions import EarlyStopException | |||
| class TrainBatchLoop(Loop): | |||
| @@ -23,13 +24,15 @@ class TrainBatchLoop(Loop): | |||
| try: | |||
| trainer.on_fetch_data_begin() | |||
| batch = next(dataloader) | |||
| batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
| indices = get_batch_indices() | |||
| batch = trainer.move_data_to_device(batch) | |||
| trainer.on_fetch_data_end() | |||
| batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
| batch = trainer.move_data_to_device(batch) | |||
| except StopIteration: | |||
| break | |||
| except BaseException as e: # TODO 把这里的信息写入进去 | |||
| except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||
| break | |||
| except BaseException as e: | |||
| if indices: | |||
| logger.debug(f"The following exception happens when running on samples: {indices}") | |||
| raise e | |||
| @@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger): | |||
| self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
| # 5. 恢复所有 callback 的状态; | |||
| self.on_load_checkpoint(states["callback_states"]) | |||
| self.train_stepeckpoint(states["callback_states"]) | |||
| self.driver.barrier() | |||
| @@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader): | |||
| pin_memory: bool = False, drop_last: bool = False, | |||
| timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
| multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
| persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||
| persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: | |||
| """ | |||
| :param dataset: 实现了__getitem__和__len__的数据容器 | |||
| @@ -788,13 +788,14 @@ class DataSet: | |||
| def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
| """ | |||
| 设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 | |||
| 设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 | |||
| 当val=None时,意味着给定的field_names都不需要尝试padding | |||
| :param field_names: dataset存在的field_name | |||
| :param val: 默认为0 | |||
| :param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | |||
| :return: | |||
| """ | |||
| # TODO 需要去重复 | |||
| for field_name in field_names: | |||
| self.collate_fns.set_pad_val(field_name, val=val) | |||
| @@ -805,6 +806,7 @@ class DataSet: | |||
| :param field_names: | |||
| :return: | |||
| """ | |||
| # | |||
| self.collate_fns.set_input(*field_names) | |||
| def get_collator(self) -> _MultiCollator: | |||
| @@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| import torch.distributed as dist | |||
| from torch.nn.parallel import DistributedDataParallel | |||
| from torch.utils.data import BatchSampler | |||
| __all__ = [ | |||
| 'TorchDDPDriver' | |||
| @@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver): | |||
| num_replicas=self.world_size, | |||
| rank=self.global_rank | |||
| ) | |||
| return replace_sampler(dataloader, sampler) | |||
| batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||
| return replace_batch_sampler(dataloader, batch_sampler) | |||
| else: | |||
| raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
| @@ -3,28 +3,20 @@ import pickle | |||
| _pickler = pickle.Pickler | |||
| _unpickler = pickle.Unpickler | |||
| from typing import Any, List | |||
| from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
| from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
| from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| from torch import distributed as dist | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupMPI | |||
| except ImportError: | |||
| _MPI_AVAILABLE = False | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupNCCL | |||
| except ImportError: | |||
| _NCCL_AVAILABLE = False | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupGloo | |||
| from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
| except ImportError: | |||
| _GLOO_AVAILABLE = False | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| try: | |||
| from torch._C._distributed_c10d import ProcessGroupGloo | |||
| from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
| except ImportError: | |||
| pass | |||
| from fastNLP.core.utils import apply_to_collection | |||
| @@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
| ) | |||
| def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
| def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||
| """ | |||
| 从其它 rank gather 东西到 dst rank 。 | |||
| @@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) | |||
| >>> output | |||
| ['foo', 12, {1: 2}] | |||
| """ | |||
| if group is None: | |||
| group = DEFAULT_TORCH_GROUP | |||
| if dist.distributed_c10d._rank_not_in_group(group): | |||
| return | |||
| @@ -193,7 +188,7 @@ def _to_device(tensor, device): | |||
| return tensor.contiguous().to(device) | |||
| def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
| def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||
| """ | |||
| 实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||
| @@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
| :param group: | |||
| :return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||
| """ | |||
| # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||
| if group is None: | |||
| group = DEFAULT_TORCH_GROUP | |||
| if isinstance(obj, torch.Tensor): | |||
| objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||
| dist.all_gather(objs, obj, group=group) | |||
| @@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
| return objs | |||
| def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
| def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||
| """ | |||
| 将 src 上的 obj 对象广播到其它 rank 上。 | |||
| @@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
| :param group: | |||
| :return: | |||
| """ | |||
| if group is None: | |||
| group = DEFAULT_TORCH_GROUP | |||
| cur_rank = dist.get_rank(group) | |||
| if cur_rank == src: | |||
| # 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||
| @@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None): | |||
| return | |||
| input_tensor, local_size = _object_to_tensor(obj) | |||
| current_device = torch.device("cpu") | |||
| is_nccl_backend = _check_for_nccl_backend(group) | |||
| if is_nccl_backend: | |||
| # See note about using torch.cuda.current_device() here in docstring. | |||
| # We cannot simply use my_rank since rank == device is not necessarily | |||
| # true. | |||
| current_device = torch.device("cuda", torch.cuda.current_device()) | |||
| input_tensor = input_tensor.to(current_device) | |||
| local_size = local_size.to(current_device) | |||
| if _TORCH_GREATER_EQUAL_1_8: | |||
| current_device = torch.device("cpu") | |||
| is_nccl_backend = _check_for_nccl_backend(group) | |||
| if is_nccl_backend: | |||
| # See note about using torch.cuda.current_device() here in docstring. | |||
| # We cannot simply use my_rank since rank == device is not necessarily | |||
| # true. | |||
| current_device = torch.device("cuda", torch.cuda.current_device()) | |||
| input_tensor = input_tensor.to(current_device) | |||
| local_size = local_size.to(current_device) | |||
| else: | |||
| current_device = torch.cuda.current_device() | |||
| # Gather all local sizes. This is so that we can find the max size, and index | |||
| # until the correct size when deserializing the tensors. | |||
| group_size = dist.get_world_size(group=group) | |||
| @@ -8,6 +8,7 @@ import numpy as np | |||
| import inspect | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.core.samplers import re_instantiate_sampler | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| @@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
| "manually add the `DistributedSampler` as: " | |||
| f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||
| ) | |||
| return type(dataloader)(**reconstruct_args) | |||
| @@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||
| """ | |||
| batch_sampler = getattr(dataloader, "batch_sampler") | |||
| # checking the batch sampler type is different than PyTorch default. | |||
| if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||
| batch_sampler = type(batch_sampler)( | |||
| sampler, | |||
| batch_size=batch_sampler.batch_size, | |||
| drop_last=batch_sampler.drop_last, | |||
| ) | |||
| if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||
| batch_sampler = re_instantiate_sampler(batch_sampler) | |||
| return { | |||
| "sampler": None, | |||
| @@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||
| params = {k: getattr(dataloader, k) for k in params_keys} | |||
| params["batch_sampler"] = new_batch_sampler | |||
| return type(dataloader)(**params) | |||
| # TODO 这里是否可以auto_param_call一下 | |||
| # return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||
| # signature_fn=type(dataloader).__init__) | |||
| def optimizer_state_to_device(state, device): | |||
| @@ -51,6 +51,7 @@ class LoggerSingleton(type): | |||
| class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
| def __init__(self, name): | |||
| super().__init__(name) | |||
| self._warning_msgs = set() | |||
| def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | |||
| mode: str = "w"): | |||
| @@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
| kwargs = self._add_rank_info(kwargs) | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| def warning_once(self, msg, *args, **kwargs): | |||
| """ | |||
| 通过 warning 内容只会 warning 一次 | |||
| :param msg: | |||
| :param args: | |||
| :param kwargs: | |||
| :return: | |||
| """ | |||
| if msg not in self._warning_msgs: | |||
| if self.isEnabledFor(WARNING): | |||
| kwargs = self._add_rank_info(kwargs) | |||
| self._log(WARNING, msg, args, **kwargs) | |||
| self._warning_msgs.add(msg) | |||
| def warn(self, msg, *args, **kwargs): | |||
| warnings.warn("The 'warn' method is deprecated, " | |||
| "use 'warning' instead", DeprecationWarning, 2) | |||
| @@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| super().__init__() | |||
| if isinstance(dataset, DataSet): | |||
| length = dataset.get_field(length) | |||
| if isinstance(dataset, DataSet) and isinstance(length, str): | |||
| length = dataset.get_field(length).content | |||
| if not isinstance(length[0], int): | |||
| length = list(map(len, length)) | |||
| else: | |||
| @@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler): | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| super().__init__(dataset=dataset, **kwargs) | |||
| if isinstance(dataset, DataSet): | |||
| length = dataset.get_field(length) | |||
| if isinstance(dataset, DataSet) and isinstance(length, str): | |||
| length = dataset.get_field(length).content | |||
| if not isinstance(length[0], int): | |||
| length = list(map(len, length)) | |||
| else: | |||
| @@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
| :param kwargs: fastNLP 保留使用 | |||
| """ | |||
| super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
| if isinstance(dataset, DataSet): | |||
| length = dataset.get_field(length) | |||
| if isinstance(dataset, DataSet) and isinstance(length, str): | |||
| length = dataset.get_field(length).content | |||
| if not isinstance(length[0], int): | |||
| length = list(map(len, length)) | |||
| else: | |||
| @@ -6,7 +6,7 @@ | |||
| import sys | |||
| from typing import Any, Union, Optional | |||
| from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
| from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample | |||
| from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
| __all__ = [ | |||
| @@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
| if task_id in self._tasks: | |||
| super().stop_task(task_id) | |||
| super().remove_task(task_id) | |||
| self.refresh() # 使得bar不残留 | |||
| def start(self) -> None: | |||
| super().start() | |||
| self.console.show_cursor(show=True) | |||
| def update( | |||
| self, | |||
| task_id: TaskID, | |||
| *, | |||
| total: Optional[float] = None, | |||
| completed: Optional[float] = None, | |||
| advance: Optional[float] = None, | |||
| description: Optional[str] = None, | |||
| visible: Optional[bool] = None, | |||
| refresh: bool = False, | |||
| **fields: Any, | |||
| ) -> None: | |||
| """Update information associated with a task. | |||
| Args: | |||
| task_id (TaskID): Task id (returned by add_task). | |||
| total (float, optional): Updates task.total if not None. | |||
| completed (float, optional): Updates task.completed if not None. | |||
| advance (float, optional): Add a value to task.completed if not None. | |||
| description (str, optional): Change task description if not None. | |||
| visible (bool, optional): Set visible flag if not None. | |||
| refresh (bool): Force a refresh of progress information. Default is False. | |||
| **fields (Any): Additional data fields required for rendering. | |||
| """ | |||
| with self._lock: | |||
| task = self._tasks[task_id] | |||
| completed_start = task.completed | |||
| if total is not None and total != task.total: | |||
| task.total = total | |||
| task._reset() | |||
| if advance is not None: | |||
| task.completed += advance | |||
| if completed is not None: | |||
| task.completed = completed | |||
| if description is not None: | |||
| task.description = description | |||
| if visible is not None: | |||
| task.visible = visible | |||
| task.fields.update(fields) | |||
| update_completed = task.completed - completed_start | |||
| current_time = self.get_time() | |||
| old_sample_time = current_time - self.speed_estimate_period | |||
| _progress = task._progress | |||
| popleft = _progress.popleft | |||
| # 这里修改为至少保留一个,防止超长时间的迭代影响判断 | |||
| while len(_progress)>1 and _progress[0].timestamp < old_sample_time: | |||
| popleft() | |||
| if update_completed > 0: | |||
| _progress.append(ProgressSample(current_time, update_completed)) | |||
| if task.completed >= task.total and task.finished_time is None: | |||
| task.finished_time = task.elapsed | |||
| if refresh: | |||
| self.refresh() | |||
| class SpeedColumn(ProgressColumn): | |||
| """ | |||
| 显示 task 的速度。 | |||
| """ | |||
| def render(self, task: "Task"): | |||
| speed = task.speed | |||
| if speed is None: | |||
| return Text('-- it./s', style='progress.data.speed') | |||
| if speed > 0.1: | |||
| return Text(str(round(speed, 2))+' it./s', style='progress.data.speed') | |||
| else: | |||
| return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | |||
| if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
| f_rich_progress = FRichProgress().new_progess( | |||
| "[progress.description]{task.description}", | |||
| "[progress.percentage]{task.percentage:>3.0f}%", | |||
| BarColumn(), | |||
| SpeedColumn(), | |||
| TimeElapsedColumn(), | |||
| "/", | |||
| TimeRemainingColumn(), | |||
| TextColumn("{task.fields[post_desc]}", justify="right"), | |||
| transient=True, | |||
| disable=False, | |||
| speed_estimate_period=1 | |||
| speed_estimate_period=30 | |||
| ) | |||
| else: | |||
| f_rich_progress = DummyFRichProgress() | |||
| @@ -1,9 +1,11 @@ | |||
| from abc import ABC | |||
| from typing import Any, Union, Optional | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
| from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 | |||
| DEFAULT_TORCH_GROUP = None | |||
| if _NEED_IMPORT_TORCH: | |||
| import torch | |||
| if not _TORCH_GREATER_EQUAL_1_8: | |||
| DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | |||
| __all__ = [ | |||
| 'torch_move_data_to_device' | |||
| @@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional | |||
| def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
| mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
| r""" | |||
| 1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||
| 该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping | |||
| 参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 | |||
| 1.该函数用来提供给用户根据字符串匹配从而实现自动调用; | |||
| 2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
| 如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
| 3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
| @@ -6,13 +6,16 @@ import logging | |||
| import re | |||
| from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||
| from tests.helpers.utils import magic_argv_env_context | |||
| from fastNLP.core import synchronize_safe_rm | |||
| from fastNLP.core.log.logger import logger | |||
| from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
| # 测试 TorchDDPDriver; | |||
| @magic_argv_env_context | |||
| def test_add_file_ddp_1(): | |||
| @recover_logger | |||
| def test_add_file_ddp_1_torch(): | |||
| """ | |||
| 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||
| @@ -56,11 +59,11 @@ def test_add_file_ddp_1(): | |||
| synchronize_safe_rm(filepath) | |||
| dist.barrier() | |||
| dist.destroy_process_group() | |||
| logger.removeHandler(handler) | |||
| @magic_argv_env_context | |||
| def test_add_file_ddp_2(): | |||
| @recover_logger | |||
| def test_add_file_ddp_2_torch(): | |||
| """ | |||
| 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||
| """ | |||
| @@ -103,14 +106,14 @@ def test_add_file_ddp_2(): | |||
| assert len(pattern.findall(line)) == 1 | |||
| finally: | |||
| synchronize_safe_rm(path) | |||
| logger.removeHandler(handler) | |||
| dist.barrier() | |||
| dist.destroy_process_group() | |||
| @magic_argv_env_context | |||
| def test_add_file_ddp_3(): | |||
| @recover_logger | |||
| def test_add_file_ddp_3_torch(): | |||
| """ | |||
| path = None; | |||
| @@ -155,10 +158,10 @@ def test_add_file_ddp_3(): | |||
| synchronize_safe_rm(file) | |||
| dist.barrier() | |||
| dist.destroy_process_group() | |||
| logger.removeHandler(handler) | |||
| @magic_argv_env_context | |||
| def test_add_file_ddp_4(): | |||
| @recover_logger | |||
| def test_add_file_ddp_4_torch(): | |||
| """ | |||
| 测试 path 是文件夹; | |||
| """ | |||
| @@ -200,7 +203,6 @@ def test_add_file_ddp_4(): | |||
| assert len(pattern.findall(line)) == 1 | |||
| finally: | |||
| synchronize_safe_rm(path) | |||
| logger.removeHandler(handler) | |||
| dist.barrier() | |||
| dist.destroy_process_group() | |||
| @@ -209,12 +211,11 @@ def test_add_file_ddp_4(): | |||
| class TestLogger: | |||
| msg = 'some test log msg' | |||
| @recover_logger | |||
| def test_add_file_1(self): | |||
| """ | |||
| 测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||
| """ | |||
| from fastNLP.core.log.logger import logger | |||
| path = Path(tempfile.mkdtemp()) | |||
| try: | |||
| filepath = path.joinpath('log.txt') | |||
| @@ -225,14 +226,12 @@ class TestLogger: | |||
| assert self.msg in line | |||
| finally: | |||
| synchronize_safe_rm(path) | |||
| logger.removeHandler(handler) | |||
| @recover_logger | |||
| def test_add_file_2(self): | |||
| """ | |||
| 测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||
| """ | |||
| from fastNLP.core.log.logger import logger | |||
| origin_path = Path(tempfile.mkdtemp()) | |||
| try: | |||
| @@ -245,14 +244,12 @@ class TestLogger: | |||
| assert self.msg in line | |||
| finally: | |||
| synchronize_safe_rm(origin_path) | |||
| logger.removeHandler(handler) | |||
| @recover_logger | |||
| def test_add_file_3(self): | |||
| """ | |||
| 测试 path 是 None; | |||
| """ | |||
| from fastNLP.core.log.logger import logger | |||
| handler = logger.add_file() | |||
| logger.info(self.msg) | |||
| @@ -264,14 +261,12 @@ class TestLogger: | |||
| line = ''.join([l for l in f]) | |||
| assert self.msg in line | |||
| file.unlink() | |||
| logger.removeHandler(handler) | |||
| @recover_logger | |||
| def test_add_file_4(self): | |||
| """ | |||
| 测试 path 是文件夹; | |||
| """ | |||
| from fastNLP.core.log.logger import logger | |||
| path = Path(tempfile.mkdtemp()) | |||
| try: | |||
| handler = logger.add_file(path) | |||
| @@ -285,16 +280,21 @@ class TestLogger: | |||
| assert self.msg in line | |||
| finally: | |||
| synchronize_safe_rm(path) | |||
| logger.removeHandler(handler) | |||
| @recover_logger | |||
| def test_stdout(self, capsys): | |||
| from fastNLP.core.log.logger import logger | |||
| handler = logger.set_stdout(stdout="raw") | |||
| logger.info(self.msg) | |||
| logger.debug('aabbc') | |||
| captured = capsys.readouterr() | |||
| assert "some test log msg\n" == captured.out | |||
| logger.removeHandler(handler) | |||
| @recover_logger | |||
| def test_warning_once(self, capsys): | |||
| logger.warning_once('#') | |||
| logger.warning_once('#') | |||
| logger.warning_once('@') | |||
| captured = capsys.readouterr() | |||
| assert captured.out.count('#') == 1 | |||
| assert captured.out.count('@') == 1 | |||
| @@ -13,6 +13,7 @@ import numpy as np | |||
| from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
| from fastNLP.core.drivers.utils import distributed_open_proc | |||
| from fastNLP.core.log import logger | |||
| def get_class_that_defined_method(meth): | |||
| @@ -32,6 +33,20 @@ def get_class_that_defined_method(meth): | |||
| return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||
| def recover_logger(fn): | |||
| @wraps(fn) | |||
| def wrapper(*args, **kwargs): | |||
| # 保存logger的状态 | |||
| handlers = [handler for handler in logger.handlers] | |||
| level = logger.level | |||
| res = fn(*args, **kwargs) | |||
| logger.handlers = handlers | |||
| logger.setLevel(level) | |||
| return res | |||
| return wrapper | |||
| def magic_argv_env_context(fn): | |||
| @wraps(fn) | |||