@@ -32,100 +32,205 @@ class Callback: | |||
def on_sanity_check_end(self, trainer, sanity_check_res): | |||
r""" | |||
在 '预跑'检测 开始后会被触发; | |||
:param trainer: | |||
:param sanity_check_res: 预跑的 evaluate 结果 | |||
:return: | |||
""" | |||
pass | |||
def on_train_begin(self, trainer): | |||
r""" | |||
在训练开始前会被触发; | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_train_end(self, trainer): | |||
r""" | |||
在训练完成后会被触发; | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_train_epoch_begin(self, trainer): | |||
r""" | |||
在训练过程中的每一个 epoch 开始前会被触发; | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_train_epoch_end(self, trainer): | |||
r""" | |||
在训练过程中的每一个 epoch 完成后会被触发; | |||
在训练过程中的每一个 epoch 完成后会被触发;此时 trainer.cur_epoch_idx 已经完成加 1 操作。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_fetch_data_begin(self, trainer): | |||
r""" | |||
在训练过程中拿到当前的具体的一个 batch 前会被触发; | |||
在训练过程中准备取出下一个 batch 的数据时触发 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_fetch_data_end(self, trainer): | |||
r""" | |||
在训练过程中拿到当前的具体的一个 batch 后会被触发; | |||
在训练过程中拿到当前的 batch 数据后会被触发; | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_train_batch_begin(self, trainer, batch, indices): | |||
r""" | |||
在训练过程中开始具体的一个 batch 前会被触发; | |||
在取得数据,执行完 input_mapping (如果 Trainer 传有该参数),并且移动 batch 中的 tensor 到了指定设备。 | |||
其中 batch 中的数据格式要么是 Dataloader 返回的每个 batch 的格式;要么是 input_mapping 之后的内容。 | |||
如果 batch 是 dict 类型,直接增删其中的 key 或 修改其中的 value 会影响到输入到 model 的中的 batch 数据。 | |||
:param trainer: `fastNLP.Trainer` | |||
:param batch: 当前正在运行的一个 batch; | |||
:param indices: 当前的 batch 在一个 epoch 中的位置,用于用户方便地通过该 callback 函数定位具体的数据; | |||
:param batch: batch 的数据,已经经过 input_mapping (如果有) 以及 移动到指定设备 。 | |||
:param list[int] indices: 当前的 batch 是 dataset 中的哪些数据 | |||
""" | |||
pass | |||
def on_train_batch_end(self, trainer): | |||
""" | |||
完成一个 batch 的训练(forward)、梯度回传(backward)、梯度更新(step)、梯度置零、batch_idx_in_epoch与 | |||
global_forward_batches累计加1操作。其中梯度更新】梯度置零操作会考虑 accumulation_steps ,所以不一定在当前 batch 会 | |||
执行。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_exception(self, trainer, exception): | |||
""" | |||
在训练过程遇到异常时调用。 | |||
:param trainer: | |||
:param exception: 遭遇的异常。 | |||
:return: | |||
""" | |||
pass | |||
def on_save_model(self, trainer): | |||
""" | |||
当将要保存模型时调用,此刻模型还未保存。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_load_model(self, trainer): | |||
""" | |||
当将要加载模型时调用,此刻模型还未加载。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_save_checkpoint(self, trainer) -> Dict: | |||
""" | |||
当确定前后两个 callback 是一样的(callback_name 相同,意味着它们所起的职能相同)时,它们在该函数中则应当保存使该 callback 正常 | |||
工作的状态;而不应该让该函数去判断两个 callback 是否一样; | |||
当 Trainer 将要保存 checkpoint 的时候触发,该函数用于保存当前 callback 在恢复需要的相关数据。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_load_checkpoint(self, trainer, states: Optional[Dict]): | |||
r""" | |||
如果一个 callback 在断点重训前没有保存状态,或者其 `callback_name` 与其余的 callback 重名时,`states` 为 None; | |||
当 Trainer 要恢复 checkpoint 的时候触发( Trainer 与 Driver 已经加载好自身的状态),参数 states 为 on_save_checkpoint() | |||
的返回值。 | |||
:param trainer: | |||
:param states: | |||
:return: | |||
""" | |||
pass | |||
def on_before_backward(self, trainer, outputs): | |||
""" | |||
在 backward 前执行。 | |||
:param trainer: | |||
:param outputs: model 的返回内容。如果有 output_mapping ,则 outputs 中的内容为已经执行了 output_mapping 后的结果。 | |||
:return: | |||
""" | |||
pass | |||
def on_after_backward(self, trainer): | |||
""" | |||
在 backward 后执行。在多卡场景下,由于 accumulation_steps 的影响,仅在需要真正 update 参数那次梯度回传才会触发梯度同步, | |||
因此在多卡且使用 accumulation_steps 时,可能存在某些 step 各卡上梯度不一致的问题。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_before_optimizer_step(self, trainer, optimizers): | |||
""" | |||
在进行 optimizer 优化进行前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
""" | |||
pass | |||
def on_before_zero_grad(self, trainer, optimizers): | |||
""" | |||
在进行模型梯度置零前调用。该接口不一定每次前向计算都会触发,实际调用会受到 accumulation_steps 的影响。 | |||
:param trainer: | |||
:param optimizers: 优化器,内容为在 Trainer 初始化时传入的值。 | |||
:return: | |||
""" | |||
pass | |||
def on_validate_begin(self, trainer): | |||
""" | |||
在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 | |||
进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 | |||
:param trainer: | |||
:return: | |||
""" | |||
pass | |||
def on_validate_end(self, trainer, results): | |||
""" | |||
结束 validate 时调用,并把 validate 的结果传入。 | |||
:param trainer: | |||
:param results: | |||
:return: | |||
""" | |||
pass | |||
@property | |||
def callback_name(self): | |||
""" | |||
callback 的名称,我们会使用该名称从 checkpoint 中读取的相应的 state 并传递给 on_load_checkpoint() 函数。 | |||
:return: | |||
""" | |||
return self.__class__.__name__ | |||
@@ -226,10 +331,21 @@ class HasMonitorCallback(Callback): | |||
:param keep_if_better: 如果传入的 monitor_value 值更好,则将其保存下来。 | |||
:return: | |||
""" | |||
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value) | |||
if keep_if_better and better: | |||
self.monitor_value = monitor_value | |||
return better | |||
def is_former_monitor_value_better(self, monitor_value1, monitor_value2): | |||
""" | |||
传入的两个值中,是否monitor_value1的结果更好。 | |||
:param monitor_value1: | |||
:param monitor_value2: | |||
:return: | |||
""" | |||
better = False | |||
if (self.larger_better and monitor_value > self.monitor_value) or \ | |||
(not self.larger_better and monitor_value < self.monitor_value): | |||
if (self.larger_better and monitor_value1 > monitor_value2) or \ | |||
(not self.larger_better and monitor_value1 < monitor_value2): | |||
better = True | |||
if keep_if_better: | |||
self.monitor_value = monitor_value | |||
return better |
@@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value | |||
from fastNLP.core.log import logger | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME | |||
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir | |||
from fastNLP.core.utils import apply_to_collection | |||
class CheckpointCallback(HasMonitorCallback): | |||
@@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback): | |||
else: | |||
_least_valuable_model = (min if self.larger_better else max)(self._topk_model, | |||
key=lambda x: self._topk_model[x]) | |||
if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \ | |||
(self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]): | |||
if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]): | |||
self._topk_model[folder_name] = monitor_value | |||
_should_save = True | |||
self._topk_model.pop(_least_valuable_model) | |||
@@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback): | |||
**self.kwargs | |||
) | |||
def _get_validate_metric(self, res: Dict): | |||
""" | |||
该函数用于从 `Evaluator` 的结果中找到属于当前 CheckpointCallback 的 metric result(根据 monitor); | |||
如果用户输入在 res 中没有找到,我们会查询所有的 validate 结果字典的键值,根据 最长公共字符串 匹配,使用最长匹配的结果值; | |||
:param res: | |||
:return: | |||
""" | |||
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res) | |||
if self._real_monitor != use_monitor: | |||
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), " | |||
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.") | |||
self._real_monitor = use_monitor | |||
return value | |||
@property | |||
def folder_prefix(self): | |||
raise NotImplementedError("The `folder_prefix` is not specified") | |||
@@ -197,7 +197,7 @@ class _MultiCollator: | |||
collator.set_input(*field_names) | |||
flag = False | |||
if flag: | |||
warnings.warn("AutoCollator is remove, set_input is unavailable!!") | |||
warnings.warn("AutoCollator is removed, set_input is unavailable!!") | |||
return self | |||
@@ -223,7 +223,6 @@ class Evaluator: | |||
def remove_progress_bar(self, dataloader_name): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
f_rich_progress.refresh() # 使得最终的bar可以消失 | |||
delattr(self, '_rich_task_id') | |||
elif self.progress_bar == 'raw': | |||
desc = 'Evaluation ends' | |||
@@ -234,7 +233,6 @@ class Evaluator: | |||
def finally_progress_bar(self): | |||
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'): | |||
f_rich_progress.destroy_task(self._rich_task_id) | |||
f_rich_progress.refresh() | |||
delattr(self, '_rich_task_id') | |||
@property | |||
@@ -359,20 +357,23 @@ class _MetricsWrapper: | |||
if is_dataclass(outputs): | |||
outputs = dataclass_to_dict(outputs) | |||
for metric in self._metrics: | |||
args = [] | |||
if not isinstance(batch, dict): | |||
raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly" | |||
f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.") | |||
logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on " | |||
f"the output of model to update metric.") | |||
else: | |||
args.append(batch) | |||
if not isinstance(outputs, dict): | |||
raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly" | |||
raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly" | |||
f" return a dict from your model or use `output_mapping` to convert it into dict type.") | |||
if isinstance(metric, Metric): | |||
auto_param_call(metric.update, batch, outputs) | |||
auto_param_call(metric.update, batch, *args) | |||
elif _is_torchmetrics_metric(metric): | |||
auto_param_call(metric.update, batch, outputs) | |||
auto_param_call(metric.update, batch, *args) | |||
elif _is_allennlp_metric(metric): | |||
auto_param_call(metric.__call__, batch, outputs) | |||
auto_param_call(metric.__call__, batch, *args) | |||
elif _is_paddle_metric(metric): | |||
res = auto_param_call(metric.compute, batch, outputs) | |||
res = auto_param_call(metric.compute, batch, *args) | |||
metric.update(res) | |||
def reset(self): | |||
@@ -7,6 +7,7 @@ from typing import Optional, Callable | |||
from .loop import Loop | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.utils import match_and_substitute_params | |||
from fastNLP.core.utils.exceptions import EarlyStopException | |||
class TrainBatchLoop(Loop): | |||
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop): | |||
try: | |||
trainer.on_fetch_data_begin() | |||
batch = next(dataloader) | |||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
indices = get_batch_indices() | |||
batch = trainer.move_data_to_device(batch) | |||
trainer.on_fetch_data_end() | |||
batch = match_and_substitute_params(trainer.input_mapping, batch) | |||
batch = trainer.move_data_to_device(batch) | |||
except StopIteration: | |||
break | |||
except BaseException as e: # TODO 把这里的信息写入进去 | |||
except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception | |||
break | |||
except BaseException as e: | |||
if indices: | |||
logger.debug(f"The following exception happens when running on samples: {indices}") | |||
raise e | |||
@@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger): | |||
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') | |||
# 5. 恢复所有 callback 的状态; | |||
self.on_load_checkpoint(states["callback_states"]) | |||
self.train_stepeckpoint(states["callback_states"]) | |||
self.driver.barrier() | |||
@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader): | |||
pin_memory: bool = False, drop_last: bool = False, | |||
timeout: float = 0, worker_init_fn: Optional[Callable] = None, | |||
multiprocessing_context=None, generator=None, prefetch_factor: int = 2, | |||
persistent_workers: bool = False, as_numpy: bool = False) -> None: | |||
persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None: | |||
""" | |||
:param dataset: 实现了__getitem__和__len__的数据容器 | |||
@@ -788,13 +788,14 @@ class DataSet: | |||
def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None: | |||
""" | |||
设置每个field_name的padding值,默认为0,只有当Auto_collate存在时该方法有效 | |||
设置每个field_name的padding值,默认为0,只有当AutoCollator存在时该方法有效 | |||
当val=None时,意味着给定的field_names都不需要尝试padding | |||
:param field_names: dataset存在的field_name | |||
:param val: 默认为0 | |||
:param val: 默认为0。如果为 None ,则为不对 field 进行 padding 。 | |||
:return: | |||
""" | |||
# TODO 需要去重复 | |||
for field_name in field_names: | |||
self.collate_fns.set_pad_val(field_name, val=val) | |||
@@ -805,6 +806,7 @@ class DataSet: | |||
:param field_names: | |||
:return: | |||
""" | |||
# | |||
self.collate_fns.set_input(*field_names) | |||
def get_collator(self) -> _MultiCollator: | |||
@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH: | |||
import torch | |||
import torch.distributed as dist | |||
from torch.nn.parallel import DistributedDataParallel | |||
from torch.utils.data import BatchSampler | |||
__all__ = [ | |||
'TorchDDPDriver' | |||
@@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver): | |||
num_replicas=self.world_size, | |||
rank=self.global_rank | |||
) | |||
return replace_sampler(dataloader, sampler) | |||
batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False) | |||
return replace_batch_sampler(dataloader, batch_sampler) | |||
else: | |||
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).") | |||
@@ -3,28 +3,20 @@ import pickle | |||
_pickler = pickle.Pickler | |||
_unpickler = pickle.Unpickler | |||
from typing import Any, List | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8 | |||
from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
from torch import distributed as dist | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupMPI | |||
except ImportError: | |||
_MPI_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupNCCL | |||
except ImportError: | |||
_NCCL_AVAILABLE = False | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupGloo | |||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
except ImportError: | |||
_GLOO_AVAILABLE = False | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
try: | |||
from torch._C._distributed_c10d import ProcessGroupGloo | |||
from torch._C._distributed_c10d import _ProcessGroupWrapper | |||
except ImportError: | |||
pass | |||
from fastNLP.core.utils import apply_to_collection | |||
@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list): | |||
) | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None): | |||
def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP): | |||
""" | |||
从其它 rank gather 东西到 dst rank 。 | |||
@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None) | |||
>>> output | |||
['foo', 12, {1: 2}] | |||
""" | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
if dist.distributed_c10d._rank_not_in_group(group): | |||
return | |||
@@ -193,7 +188,7 @@ def _to_device(tensor, device): | |||
return tensor.contiguous().to(device) | |||
def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List: | |||
""" | |||
实现任何类型的数据都使用该接口可以进行 all_gather 操作。对于非 tensor 类型的数据,通过 pickle 序列化再反序列化的方式进行传输。 | |||
@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
:param group: | |||
:return: 返回的结果是 [obj0, obj1, ...],其中 obj_i 即为第 i 个 rank 上的 obj 。 | |||
""" | |||
# # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
if isinstance(obj, torch.Tensor): | |||
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))] | |||
dist.all_gather(objs, obj, group=group) | |||
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List: | |||
return objs | |||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP): | |||
""" | |||
将 src 上的 obj 对象广播到其它 rank 上。 | |||
@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None): | |||
:param group: | |||
:return: | |||
""" | |||
if group is None: | |||
group = DEFAULT_TORCH_GROUP | |||
cur_rank = dist.get_rank(group) | |||
if cur_rank == src: | |||
# 如果有 tensor 全部移动到 cpu 上,方便 pickle , 不然 unpickle 的时候可能会 pickle 到发送过来的卡那里 | |||
@@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None): | |||
return | |||
input_tensor, local_size = _object_to_tensor(obj) | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
if is_nccl_backend: | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
input_tensor = input_tensor.to(current_device) | |||
local_size = local_size.to(current_device) | |||
if _TORCH_GREATER_EQUAL_1_8: | |||
current_device = torch.device("cpu") | |||
is_nccl_backend = _check_for_nccl_backend(group) | |||
if is_nccl_backend: | |||
# See note about using torch.cuda.current_device() here in docstring. | |||
# We cannot simply use my_rank since rank == device is not necessarily | |||
# true. | |||
current_device = torch.device("cuda", torch.cuda.current_device()) | |||
input_tensor = input_tensor.to(current_device) | |||
local_size = local_size.to(current_device) | |||
else: | |||
current_device = torch.cuda.current_device() | |||
# Gather all local sizes. This is so that we can find the max size, and index | |||
# until the correct size when deserializing the tensors. | |||
group_size = dist.get_world_size(group=group) | |||
@@ -8,6 +8,7 @@ import numpy as np | |||
import inspect | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.core.samplers import re_instantiate_sampler | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
@@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler): | |||
"manually add the `DistributedSampler` as: " | |||
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`." | |||
) | |||
return type(dataloader)(**reconstruct_args) | |||
@@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler( | |||
""" | |||
batch_sampler = getattr(dataloader, "batch_sampler") | |||
# checking the batch sampler type is different than PyTorch default. | |||
if batch_sampler is not None and type(batch_sampler) is not BatchSampler: | |||
batch_sampler = type(batch_sampler)( | |||
sampler, | |||
batch_size=batch_sampler.batch_size, | |||
drop_last=batch_sampler.drop_last, | |||
) | |||
if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler): | |||
batch_sampler = re_instantiate_sampler(batch_sampler) | |||
return { | |||
"sampler": None, | |||
@@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler): | |||
params = {k: getattr(dataloader, k) for k in params_keys} | |||
params["batch_sampler"] = new_batch_sampler | |||
return type(dataloader)(**params) | |||
# TODO 这里是否可以auto_param_call一下 | |||
# return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()}, | |||
# signature_fn=type(dataloader).__init__) | |||
def optimizer_state_to_device(state, device): | |||
@@ -51,6 +51,7 @@ class LoggerSingleton(type): | |||
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
def __init__(self, name): | |||
super().__init__(name) | |||
self._warning_msgs = set() | |||
def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False, | |||
mode: str = "w"): | |||
@@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
def warning_once(self, msg, *args, **kwargs): | |||
""" | |||
通过 warning 内容只会 warning 一次 | |||
:param msg: | |||
:param args: | |||
:param kwargs: | |||
:return: | |||
""" | |||
if msg not in self._warning_msgs: | |||
if self.isEnabledFor(WARNING): | |||
kwargs = self._add_rank_info(kwargs) | |||
self._log(WARNING, msg, args, **kwargs) | |||
self._warning_msgs.add(msg) | |||
def warn(self, msg, *args, **kwargs): | |||
warnings.warn("The 'warn' method is deprecated, " | |||
"use 'warning' instead", DeprecationWarning, 2) | |||
@@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler): | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__() | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler): | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__(dataset=dataset, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler): | |||
:param kwargs: fastNLP 保留使用 | |||
""" | |||
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) | |||
if isinstance(dataset, DataSet): | |||
length = dataset.get_field(length) | |||
if isinstance(dataset, DataSet) and isinstance(length, str): | |||
length = dataset.get_field(length).content | |||
if not isinstance(length[0], int): | |||
length = list(map(len, length)) | |||
else: | |||
@@ -6,7 +6,7 @@ | |||
import sys | |||
from typing import Any, Union, Optional | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live | |||
from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample | |||
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn | |||
__all__ = [ | |||
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton): | |||
if task_id in self._tasks: | |||
super().stop_task(task_id) | |||
super().remove_task(task_id) | |||
self.refresh() # 使得bar不残留 | |||
def start(self) -> None: | |||
super().start() | |||
self.console.show_cursor(show=True) | |||
def update( | |||
self, | |||
task_id: TaskID, | |||
*, | |||
total: Optional[float] = None, | |||
completed: Optional[float] = None, | |||
advance: Optional[float] = None, | |||
description: Optional[str] = None, | |||
visible: Optional[bool] = None, | |||
refresh: bool = False, | |||
**fields: Any, | |||
) -> None: | |||
"""Update information associated with a task. | |||
Args: | |||
task_id (TaskID): Task id (returned by add_task). | |||
total (float, optional): Updates task.total if not None. | |||
completed (float, optional): Updates task.completed if not None. | |||
advance (float, optional): Add a value to task.completed if not None. | |||
description (str, optional): Change task description if not None. | |||
visible (bool, optional): Set visible flag if not None. | |||
refresh (bool): Force a refresh of progress information. Default is False. | |||
**fields (Any): Additional data fields required for rendering. | |||
""" | |||
with self._lock: | |||
task = self._tasks[task_id] | |||
completed_start = task.completed | |||
if total is not None and total != task.total: | |||
task.total = total | |||
task._reset() | |||
if advance is not None: | |||
task.completed += advance | |||
if completed is not None: | |||
task.completed = completed | |||
if description is not None: | |||
task.description = description | |||
if visible is not None: | |||
task.visible = visible | |||
task.fields.update(fields) | |||
update_completed = task.completed - completed_start | |||
current_time = self.get_time() | |||
old_sample_time = current_time - self.speed_estimate_period | |||
_progress = task._progress | |||
popleft = _progress.popleft | |||
# 这里修改为至少保留一个,防止超长时间的迭代影响判断 | |||
while len(_progress)>1 and _progress[0].timestamp < old_sample_time: | |||
popleft() | |||
if update_completed > 0: | |||
_progress.append(ProgressSample(current_time, update_completed)) | |||
if task.completed >= task.total and task.finished_time is None: | |||
task.finished_time = task.elapsed | |||
if refresh: | |||
self.refresh() | |||
class SpeedColumn(ProgressColumn): | |||
""" | |||
显示 task 的速度。 | |||
""" | |||
def render(self, task: "Task"): | |||
speed = task.speed | |||
if speed is None: | |||
return Text('-- it./s', style='progress.data.speed') | |||
if speed > 0.1: | |||
return Text(str(round(speed, 2))+' it./s', style='progress.data.speed') | |||
else: | |||
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') | |||
if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: | |||
f_rich_progress = FRichProgress().new_progess( | |||
"[progress.description]{task.description}", | |||
"[progress.percentage]{task.percentage:>3.0f}%", | |||
BarColumn(), | |||
SpeedColumn(), | |||
TimeElapsedColumn(), | |||
"/", | |||
TimeRemainingColumn(), | |||
TextColumn("{task.fields[post_desc]}", justify="right"), | |||
transient=True, | |||
disable=False, | |||
speed_estimate_period=1 | |||
speed_estimate_period=30 | |||
) | |||
else: | |||
f_rich_progress = DummyFRichProgress() | |||
@@ -1,9 +1,11 @@ | |||
from abc import ABC | |||
from typing import Any, Union, Optional | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH | |||
from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8 | |||
DEFAULT_TORCH_GROUP = None | |||
if _NEED_IMPORT_TORCH: | |||
import torch | |||
if not _TORCH_GREATER_EQUAL_1_8: | |||
DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD | |||
__all__ = [ | |||
'torch_move_data_to_device' | |||
@@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional | |||
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None, | |||
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any: | |||
r""" | |||
1.该函数用来提供给用户根据字符串匹配从而实现自动计算; | |||
该函数会根据输入函数的形参名从*args(因此都需要是dict类型)中找到匹配的值进行调用,如果传入的数据与fn的形参不匹配,可以通过mapping | |||
参数进行转换。mapping参数中的一对(key,value)表示以这个key在*args中找到值,并将这个值传递给形参名为value的参数。 | |||
1.该函数用来提供给用户根据字符串匹配从而实现自动调用; | |||
2.注意 mapping 默认为 None,如果你希望指定输入和运行函数的参数的对应方式,那么你应当让 mapping 为一个这样的字典传入进来; | |||
如果 mapping 不为 None,那么我们一定会先使用 mapping 将输入的字典的 keys 修改过来,因此请务必亲自检查 mapping 的正确性; | |||
3.如果输入的函数的参数有默认值,那么如果之后的输入中没有该参数对应的值,我们就会使用该参数对应的默认值,否则也会使用之后的输入的值; | |||
@@ -6,13 +6,16 @@ import logging | |||
import re | |||
from fastNLP.envs.env import FASTNLP_LAUNCH_TIME | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
from fastNLP.core.log.logger import logger | |||
from tests.helpers.utils import magic_argv_env_context, recover_logger | |||
# 测试 TorchDDPDriver; | |||
@magic_argv_env_context | |||
def test_add_file_ddp_1(): | |||
@recover_logger | |||
def test_add_file_ddp_1_torch(): | |||
""" | |||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||
@@ -56,11 +59,11 @@ def test_add_file_ddp_1(): | |||
synchronize_safe_rm(filepath) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
logger.removeHandler(handler) | |||
@magic_argv_env_context | |||
def test_add_file_ddp_2(): | |||
@recover_logger | |||
def test_add_file_ddp_2_torch(): | |||
""" | |||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||
""" | |||
@@ -103,14 +106,14 @@ def test_add_file_ddp_2(): | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@magic_argv_env_context | |||
def test_add_file_ddp_3(): | |||
@recover_logger | |||
def test_add_file_ddp_3_torch(): | |||
""" | |||
path = None; | |||
@@ -155,10 +158,10 @@ def test_add_file_ddp_3(): | |||
synchronize_safe_rm(file) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
logger.removeHandler(handler) | |||
@magic_argv_env_context | |||
def test_add_file_ddp_4(): | |||
@recover_logger | |||
def test_add_file_ddp_4_torch(): | |||
""" | |||
测试 path 是文件夹; | |||
""" | |||
@@ -200,7 +203,6 @@ def test_add_file_ddp_4(): | |||
assert len(pattern.findall(line)) == 1 | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
dist.barrier() | |||
dist.destroy_process_group() | |||
@@ -209,12 +211,11 @@ def test_add_file_ddp_4(): | |||
class TestLogger: | |||
msg = 'some test log msg' | |||
@recover_logger | |||
def test_add_file_1(self): | |||
""" | |||
测试 path 是一个文件的地址,但是这个文件所在的文件夹存在; | |||
""" | |||
from fastNLP.core.log.logger import logger | |||
path = Path(tempfile.mkdtemp()) | |||
try: | |||
filepath = path.joinpath('log.txt') | |||
@@ -225,14 +226,12 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_2(self): | |||
""" | |||
测试 path 是一个文件的地址,但是这个文件所在的文件夹不存在; | |||
""" | |||
from fastNLP.core.log.logger import logger | |||
origin_path = Path(tempfile.mkdtemp()) | |||
try: | |||
@@ -245,14 +244,12 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(origin_path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_3(self): | |||
""" | |||
测试 path 是 None; | |||
""" | |||
from fastNLP.core.log.logger import logger | |||
handler = logger.add_file() | |||
logger.info(self.msg) | |||
@@ -264,14 +261,12 @@ class TestLogger: | |||
line = ''.join([l for l in f]) | |||
assert self.msg in line | |||
file.unlink() | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_add_file_4(self): | |||
""" | |||
测试 path 是文件夹; | |||
""" | |||
from fastNLP.core.log.logger import logger | |||
path = Path(tempfile.mkdtemp()) | |||
try: | |||
handler = logger.add_file(path) | |||
@@ -285,16 +280,21 @@ class TestLogger: | |||
assert self.msg in line | |||
finally: | |||
synchronize_safe_rm(path) | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_stdout(self, capsys): | |||
from fastNLP.core.log.logger import logger | |||
handler = logger.set_stdout(stdout="raw") | |||
logger.info(self.msg) | |||
logger.debug('aabbc') | |||
captured = capsys.readouterr() | |||
assert "some test log msg\n" == captured.out | |||
logger.removeHandler(handler) | |||
@recover_logger | |||
def test_warning_once(self, capsys): | |||
logger.warning_once('#') | |||
logger.warning_once('#') | |||
logger.warning_once('@') | |||
captured = capsys.readouterr() | |||
assert captured.out.count('#') == 1 | |||
assert captured.out.count('@') == 1 | |||
@@ -13,6 +13,7 @@ import numpy as np | |||
from fastNLP.envs.env import FASTNLP_GLOBAL_RANK | |||
from fastNLP.core.drivers.utils import distributed_open_proc | |||
from fastNLP.core.log import logger | |||
def get_class_that_defined_method(meth): | |||
@@ -32,6 +33,20 @@ def get_class_that_defined_method(meth): | |||
return getattr(meth, '__objclass__', None) # handle special descriptor objects | |||
def recover_logger(fn): | |||
@wraps(fn) | |||
def wrapper(*args, **kwargs): | |||
# 保存logger的状态 | |||
handlers = [handler for handler in logger.handlers] | |||
level = logger.level | |||
res = fn(*args, **kwargs) | |||
logger.handlers = handlers | |||
logger.setLevel(level) | |||
return res | |||
return wrapper | |||
def magic_argv_env_context(fn): | |||
@wraps(fn) | |||