
Merge branch 'dev0.8.0' of github.com:fastnlp/fastNLP into dev0.8.0

tags/v1.0.0alpha
x54-729 · 2 years ago
parent commit 5acaeabae4
20 changed files with 339 additions and 121 deletions
1. fastNLP/core/callbacks/callback.py (+129 -13)
2. fastNLP/core/callbacks/checkpoint_callback.py (+1 -18)
3. fastNLP/core/collators/collator.py (+1 -1)
4. fastNLP/core/controllers/evaluator.py (+10 -9)
5. fastNLP/core/controllers/loops/train_batch_loop.py (+6 -3)
6. fastNLP/core/controllers/trainer.py (+1 -1)
7. fastNLP/core/dataloaders/torch_dataloader/fdl.py (+1 -1)
8. fastNLP/core/dataset/dataset.py (+4 -2)
9. fastNLP/core/drivers/torch_driver/ddp.py (+3 -1)
10. fastNLP/core/drivers/torch_driver/dist_utils.py (+31 -30)
11. fastNLP/core/drivers/torch_driver/utils.py (+6 -7)
12. fastNLP/core/log/logger.py (+16 -0)
13. fastNLP/core/samplers/reproducible_batch_sampler.py (+2 -2)
14. fastNLP/core/samplers/reproducible_sampler.py (+2 -2)
15. fastNLP/core/samplers/unrepeated_sampler.py (+2 -2)
16. fastNLP/core/utils/rich_progress.py (+77 -2)
17. fastNLP/core/utils/torch_utils.py (+4 -2)
18. fastNLP/core/utils/utils.py (+4 -1)
19. tests/core/log/test_logger.py (+24 -24)
20. tests/helpers/utils.py (+15 -0)

fastNLP/core/callbacks/callback.py (+129 -13)

@@ -32,100 +32,205 @@ class Callback:
def on_sanity_check_end(self, trainer, sanity_check_res):
r"""
Triggered after the 'sanity run' check finishes;

:param trainer:
:param sanity_check_res: the evaluate result of the sanity run
:return:
"""
pass

def on_train_begin(self, trainer):
r"""
Triggered before training starts;

:param trainer:
:return:
"""
pass

def on_train_end(self, trainer):
r"""
Triggered after training completes;

:param trainer:
:return:
"""
pass

def on_train_epoch_begin(self, trainer):
r"""
Triggered before each epoch of training starts;

:param trainer:
:return:
"""
pass

def on_train_epoch_end(self, trainer):
r"""
- Triggered after each epoch of training completes;
+ Triggered after each epoch of training completes; at this point trainer.cur_epoch_idx has already been incremented by 1.

:param trainer:
:return:
"""
pass

def on_fetch_data_begin(self, trainer):
r"""
- Triggered before fetching the concrete current batch during training;
+ Triggered when the next batch of data is about to be fetched during training

:param trainer:
:return:
"""
pass

def on_fetch_data_end(self, trainer):
r"""
- Triggered after the concrete current batch has been fetched during training;
+ Triggered after the current batch of data has been fetched during training;

:param trainer:
:return:
"""
pass

def on_train_batch_begin(self, trainer, batch, indices):
r"""
- Triggered before a concrete batch starts during training;
+ Triggered after the data has been fetched, input_mapping has been executed (if the Trainer was given that parameter), and the tensors
+ in the batch have been moved to the designated device. The data in batch is either in the per-batch format returned by the Dataloader,
+ or the content produced by input_mapping.
+ If batch is of dict type, directly adding/removing its keys or modifying its values will affect the batch data fed into the model.

:param trainer: `fastNLP.Trainer`
- :param batch: the batch currently being run;
- :param indices: the position of the current batch within the epoch, so that users can conveniently locate the concrete data via this callback;
+ :param batch: the batch data, after input_mapping (if any) has been applied and after being moved to the designated device.
+ :param list[int] indices: which samples of the dataset the current batch corresponds to
"""
pass

def on_train_batch_end(self, trainer):
"""
Triggered after one batch completes its training (forward), gradient backpropagation (backward), gradient update (step), gradient
zeroing, and the +1 accumulation of batch_idx_in_epoch and global_forward_batches. The gradient update and gradient zeroing take
accumulation_steps into account, so they are not necessarily executed on the current batch.

:param trainer:
:return:
"""
pass

def on_exception(self, trainer, exception):
"""
Called when an exception is encountered during training.

:param trainer:
:param exception: the exception that was raised.
:return:
"""
pass

def on_save_model(self, trainer):
"""
Called when the model is about to be saved; at this moment the model has not been saved yet.

:param trainer:
:return:
"""
pass

def on_load_model(self, trainer):
"""
Called when the model is about to be loaded; at this moment the model has not been loaded yet.

:param trainer:
:return:
"""
pass

def on_save_checkpoint(self, trainer) -> Dict:
"""
- When two callbacks (before and after resumption) are determined to be the same (same callback_name, meaning they serve the same role),
- this function should save the state the callback needs to work properly; it should not be the place that decides whether two callbacks are the same;
+ Triggered when the Trainer is about to save a checkpoint; this function saves the data the current callback needs for later restoration.

:param trainer:
:return:
"""
pass

def on_load_checkpoint(self, trainer, states: Optional[Dict]):
r"""
- If a callback saved no state before checkpoint resumption, or its `callback_name` collides with another callback's, `states` is None;
+ Triggered when the Trainer is about to restore a checkpoint (the Trainer and the Driver have already loaded their own states); the
+ parameter states is the return value of on_save_checkpoint().

:param trainer:
:param states:
:return:
"""
pass

def on_before_backward(self, trainer, outputs):
"""
Executed before backward.

:param trainer:
:param outputs: the content returned by the model. If output_mapping is set, outputs holds the result after output_mapping was applied.
:return:
"""
pass

def on_after_backward(self, trainer):
"""
Executed after backward. In multi-card settings, because of accumulation_steps, gradient synchronization is only triggered on the
backward pass that actually updates the parameters, so with multiple cards and accumulation_steps the gradients on different cards may be inconsistent on some steps.

:param trainer:
:return:
"""
pass

def on_before_optimizer_step(self, trainer, optimizers):
"""
Called before the optimizer performs its update. This hook is not necessarily triggered on every forward pass; the actual calls are affected by accumulation_steps.

:param trainer:
:param optimizers: the optimizers, holding the values passed in when the Trainer was initialized.
:return:
"""
pass

def on_before_zero_grad(self, trainer, optimizers):
"""
Called before the model's gradients are zeroed. This hook is not necessarily triggered on every forward pass; the actual calls are affected by accumulation_steps.

:param trainer:
:param optimizers: the optimizers, holding the values passed in when the Trainer was initialized.
:return:
"""
pass

def on_validate_begin(self, trainer):
"""
Called when validate is about to run. If the validate frequency is decided by a step count or a custom rule, this hook is called after
on_train_batch_end; if it is decided by an epoch count, this hook is called after on_train_epoch_end.

:param trainer:
:return:
"""
pass

def on_validate_end(self, trainer, results):
"""
Called when validate ends, with the validate results passed in.

:param trainer:
:param results:
:return:
"""
pass

@property
def callback_name(self):
"""
The name of the callback. We use this name to read the corresponding state from the checkpoint and pass it to the on_load_checkpoint() function.

:return:
"""
return self.__class__.__name__
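To make the hook semantics above concrete, here is a minimal sketch of a custom callback (not part of this commit; the threshold and the callback name are illustrative):

from fastNLP.core.callbacks.callback import Callback

class BatchCountCallback(Callback):
    def on_train_batch_end(self, trainer):
        # runs after forward/backward/step/zero_grad, and after
        # batch_idx_in_epoch and global_forward_batches were incremented
        if trainer.global_forward_batches % 100 == 0:
            print(f"finished {trainer.global_forward_batches} batches")

    @property
    def callback_name(self):
        # the name under which this callback's state is stored in a checkpoint
        return "batch_count"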


@@ -226,10 +331,21 @@ class HasMonitorCallback(Callback):
:param keep_if_better: if the passed-in monitor_value is better, keep it.
:return:
"""
better = self.is_former_monitor_value_better(monitor_value, self.monitor_value)
if keep_if_better and better:
self.monitor_value = monitor_value
return better

def is_former_monitor_value_better(self, monitor_value1, monitor_value2):
"""
Of the two passed-in values, whether monitor_value1 is the better one.

:param monitor_value1:
:param monitor_value2:
:return:
"""
better = False
- if (self.larger_better and monitor_value > self.monitor_value) or \
- (not self.larger_better and monitor_value < self.monitor_value):
+ if (self.larger_better and monitor_value1 > monitor_value2) or \
+ (not self.larger_better and monitor_value1 < monitor_value2):
better = True
- if keep_if_better:
- self.monitor_value = monitor_value
return better
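Inlined as a self-contained sketch, the refactored comparison behaves as follows (the values are illustrative):

def is_better(v1, v2, larger_better):
    # mirrors is_former_monitor_value_better: v1 wins if it is larger when
    # larger_better is True, or smaller when larger_better is False
    return (larger_better and v1 > v2) or (not larger_better and v1 < v2)

assert is_better(0.9, 0.8, larger_better=True)   # e.g. an accuracy metric
assert is_better(0.3, 0.5, larger_better=False)  # e.g. a loss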

fastNLP/core/callbacks/checkpoint_callback.py (+1 -18)

@@ -15,7 +15,6 @@ from fastNLP.core.callbacks.utils import _get_monitor_value
from fastNLP.core.log import logger
from fastNLP.envs import FASTNLP_LAUNCH_TIME
from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir
- from fastNLP.core.utils import apply_to_collection


class CheckpointCallback(HasMonitorCallback):
@@ -178,8 +177,7 @@ class CheckpointCallback(HasMonitorCallback):
else:
_least_valuable_model = (min if self.larger_better else max)(self._topk_model,
key=lambda x: self._topk_model[x])
- if (self.larger_better and monitor_value > self._topk_model[_least_valuable_model]) or \
- (self.larger_better is False and monitor_value < self._topk_model[_least_valuable_model]):
+ if self.is_former_monitor_value_better(monitor_value, self._topk_model[_least_valuable_model]):
self._topk_model[folder_name] = monitor_value
_should_save = True
self._topk_model.pop(_least_valuable_model)
@@ -208,21 +206,6 @@ class CheckpointCallback(HasMonitorCallback):
**self.kwargs
)

def _get_validate_metric(self, res: Dict):
"""
This function finds, in the `Evaluator` results, the metric result belonging to the current CheckpointCallback (according to monitor);
if the user's monitor is not found in res, we look through all keys of the validate result dict and use the value of the longest-common-substring match;
:param res:
:return:
"""
use_monitor, value = _get_monitor_value(monitor=self.monitor, real_monitor=self._real_monitor, res=res)
if self._real_monitor != use_monitor:
logger.warning(f"We can not find `{self._real_monitor}` in the evaluation result (with keys as {list(res.keys())}), "
f"we use the `{use_monitor}` as the monitor for {self.__class__.__name__}.")
self._real_monitor = use_monitor

return value

@property
def folder_prefix(self):
raise NotImplementedError("The `folder_prefix` is not specified")


fastNLP/core/collators/collator.py (+1 -1)

@@ -197,7 +197,7 @@ class _MultiCollator:
collator.set_input(*field_names)
flag = False
if flag:
- warnings.warn("AutoCollator is remove, set_input is unavailable!!")
+ warnings.warn("AutoCollator is removed, set_input is unavailable!!")
return self




fastNLP/core/controllers/evaluator.py (+10 -9)

@@ -223,7 +223,6 @@ class Evaluator:
def remove_progress_bar(self, dataloader_name):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
- f_rich_progress.refresh()  # make the final bar disappear
delattr(self, '_rich_task_id')
elif self.progress_bar == 'raw':
desc = 'Evaluation ends'
@@ -234,7 +233,6 @@ class Evaluator:
def finally_progress_bar(self):
if self.progress_bar == 'rich' and hasattr(self, '_rich_task_id'):
f_rich_progress.destroy_task(self._rich_task_id)
- f_rich_progress.refresh()
delattr(self, '_rich_task_id')

@property
@@ -359,20 +357,23 @@ class _MetricsWrapper:
if is_dataclass(outputs):
outputs = dataclass_to_dict(outputs)
for metric in self._metrics:
+ args = []
if not isinstance(batch, dict):
- raise RuntimeError(f"When the output of the DataLoader is of type:`{type(batch)}`, please either directly"
- f" return a dict from your DataLoader or use `input_mapping` to convert it into dict type.")
+ logger.warning_once(f"The output of the DataLoader is of type:`{type(batch)}`, fastNLP will only depend on "
+ f"the output of model to update metric.")
+ else:
+ args.append(batch)
if not isinstance(outputs, dict):
- raise RuntimeError(f"When the output of your model is of type:`{type(batch)}`, please either directly"
+ raise RuntimeError(f"The output of your model is of type:`{type(batch)}`, please either directly"
f" return a dict from your model or use `output_mapping` to convert it into dict type.")
if isinstance(metric, Metric):
- auto_param_call(metric.update, batch, outputs)
+ auto_param_call(metric.update, batch, *args)
elif _is_torchmetrics_metric(metric):
- auto_param_call(metric.update, batch, outputs)
+ auto_param_call(metric.update, batch, *args)
elif _is_allennlp_metric(metric):
- auto_param_call(metric.__call__, batch, outputs)
+ auto_param_call(metric.__call__, batch, *args)
elif _is_paddle_metric(metric):
- res = auto_param_call(metric.compute, batch, outputs)
+ res = auto_param_call(metric.compute, batch, *args)
metric.update(res)
metric.update(res)

def reset(self):

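The `*args` change works because auto_param_call matches the update function's parameter names against the keys of the dicts passed to it; a hedged sketch (the import path and the 'pred'/'target' keys are assumptions for illustration):

from fastNLP.core.utils import auto_param_call

def update(pred, target):
    # 'pred' is looked up in the outputs dict, 'target' in the batch dict
    return int(pred == target)

batch, outputs = {'target': 1}, {'pred': 1}
auto_param_call(update, batch, outputs)  # -> 1; with a non-dict batch, only outputs is passed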

fastNLP/core/controllers/loops/train_batch_loop.py (+6 -3)

@@ -7,6 +7,7 @@ from typing import Optional, Callable
from .loop import Loop
from fastNLP.core.log import logger
from fastNLP.core.utils import match_and_substitute_params
+ from fastNLP.core.utils.exceptions import EarlyStopException


class TrainBatchLoop(Loop):
@@ -23,13 +24,15 @@ class TrainBatchLoop(Loop):
try:
trainer.on_fetch_data_begin()
batch = next(dataloader)
- batch = match_and_substitute_params(trainer.input_mapping, batch)
indices = get_batch_indices()
- batch = trainer.move_data_to_device(batch)
trainer.on_fetch_data_end()
+ batch = match_and_substitute_params(trainer.input_mapping, batch)
+ batch = trainer.move_data_to_device(batch)
except StopIteration:
break
- except BaseException as e:  # TODO write the relevant information in here
+ except EarlyStopException:  # the Trainer handles the earlystop exception
+ break
+ except BaseException as e:
if indices:
logger.debug(f"The following exception happens when running on samples: {indices}")
raise e
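With this change a callback can end training cleanly by raising EarlyStopException; a minimal sketch (the exception's constructor argument and the batch threshold are assumptions):

from fastNLP.core.callbacks.callback import Callback
from fastNLP.core.utils.exceptions import EarlyStopException

class StopEarly(Callback):
    def on_train_batch_end(self, trainer):
        if trainer.global_forward_batches >= 1000:
            # TrainBatchLoop just breaks out of the batch loop; the Trainer
            # is responsible for the rest of the early-stop handling
            raise EarlyStopException("reached 1000 batches")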


fastNLP/core/controllers/trainer.py (+1 -1)

@@ -677,7 +677,7 @@ class Trainer(TrainerEventTrigger):
self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch')

# 5. restore the states of all callbacks;
- self.on_load_checkpoint(states["callback_states"])
+ self.train_stepeckpoint(states["callback_states"])

self.driver.barrier()



fastNLP/core/dataloaders/torch_dataloader/fdl.py (+1 -1)

@@ -54,7 +54,7 @@ class TorchDataLoader(DataLoader):
pin_memory: bool = False, drop_last: bool = False,
timeout: float = 0, worker_init_fn: Optional[Callable] = None,
multiprocessing_context=None, generator=None, prefetch_factor: int = 2,
- persistent_workers: bool = False, as_numpy: bool = False) -> None:
+ persistent_workers: bool = False, as_numpy: bool = False, **kwargs) -> None:
"""

:param dataset: a data container that implements __getitem__ and __len__


fastNLP/core/dataset/dataset.py (+4 -2)

@@ -788,13 +788,14 @@ class DataSet:

def set_pad_val(self, *field_names, val: Optional[int] = 0) -> None:
"""
- Set the padding value for each field_name; defaults to 0. This method only takes effect while Auto_collate exists
+ Set the padding value for each field_name; defaults to 0. This method only takes effect while AutoCollator exists
When val=None, none of the given field_names will attempt padding

:param field_names: field_name present in the dataset
- :param val: defaults to 0
+ :param val: defaults to 0. If None, the field is not padded.
:return:
"""
+ # TODO needs deduplication
for field_name in field_names:
self.collate_fns.set_pad_val(field_name, val=val)
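A usage sketch of the clarified `val` semantics (the field names are illustrative):

ds.set_pad_val('words')               # pad 'words' with the default value 0
ds.set_pad_val('target', val=-100)    # pad 'target' with -100
ds.set_pad_val('raw_text', val=None)  # None: do not pad this field at all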

@@ -805,6 +806,7 @@ class DataSet:
:param field_names:
:return:
"""
+ #
self.collate_fns.set_input(*field_names)

def get_collator(self) -> _MultiCollator:


fastNLP/core/drivers/torch_driver/ddp.py (+3 -1)

@@ -12,6 +12,7 @@ if _NEED_IMPORT_TORCH:
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel
+ from torch.utils.data import BatchSampler

__all__ = [
'TorchDDPDriver'
@@ -524,7 +525,8 @@ class TorchDDPDriver(TorchDriver):
num_replicas=self.world_size,
rank=self.global_rank
)
- return replace_sampler(dataloader, sampler)
+ batch_sampler = BatchSampler(sampler, args.batch_size, drop_last=False)
+ return replace_batch_sampler(dataloader, batch_sampler)
else:
raise ValueError("Parameter `dist_sampler` can only be one of three values: ('dist', 'unrepeatdist', None).")
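The fix wraps the distributed sampler in a standard torch BatchSampler before calling replace_batch_sampler, so the dataloader's batch size is preserved instead of the sampler being swapped in per-sample. A minimal, self-contained illustration of what BatchSampler itself does:

from torch.utils.data import BatchSampler, SequentialSampler

base = SequentialSampler(range(10))
batches = list(BatchSampler(base, batch_size=4, drop_last=False))
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]] -- drop_last=False keeps the short tail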



fastNLP/core/drivers/torch_driver/dist_utils.py (+31 -30)

@@ -3,28 +3,20 @@ import pickle
_pickler = pickle.Pickler
_unpickler = pickle.Unpickler
from typing import Any, List
- from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8


+ from fastNLP.envs.imports import _TORCH_GREATER_EQUAL_1_8
+ from fastNLP.core.utils.torch_utils import DEFAULT_TORCH_GROUP
from fastNLP.envs.imports import _NEED_IMPORT_TORCH
if _NEED_IMPORT_TORCH:
import torch
from torch import distributed as dist
- try:
- from torch._C._distributed_c10d import ProcessGroupMPI
- except ImportError:
- _MPI_AVAILABLE = False

- try:
- from torch._C._distributed_c10d import ProcessGroupNCCL
- except ImportError:
- _NCCL_AVAILABLE = False

- try:
- from torch._C._distributed_c10d import ProcessGroupGloo
- from torch._C._distributed_c10d import _ProcessGroupWrapper
- except ImportError:
- _GLOO_AVAILABLE = False
+ if _TORCH_GREATER_EQUAL_1_8:
+ try:
+ from torch._C._distributed_c10d import ProcessGroupGloo
+ from torch._C._distributed_c10d import _ProcessGroupWrapper
+ except ImportError:
+ pass


from fastNLP.core.utils import apply_to_collection

@@ -42,7 +34,7 @@ def _validate_output_list_for_rank(my_rank, dst, gather_list):
)


- def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None):
+ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=DEFAULT_TORCH_GROUP):
"""
Gather obj from the other ranks onto the dst rank.

@@ -91,6 +83,9 @@ def fastnlp_torch_gather_object(obj, object_gather_list=None, dst=0, group=None)
>>> output
['foo', 12, {1: 2}]
"""
+ if group is None:
+ group = DEFAULT_TORCH_GROUP

if dist.distributed_c10d._rank_not_in_group(group):
return

@@ -193,7 +188,7 @@ def _to_device(tensor, device):
return tensor.contiguous().to(device)


- def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
+ def fastnlp_torch_all_gather(obj: Any, device=None, group=DEFAULT_TORCH_GROUP) ->List:
"""
Any type of data can be all_gathered through this interface. Non-tensor data is transferred by pickle serialization followed by deserialization.

@@ -217,7 +212,8 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
:param group:
:return: the result is [obj0, obj1, ...], where obj_i is the obj on rank i.
"""
- # first move everything to cpu and make it contiguous, to prevent pickle problems
+ if group is None:
+ group = DEFAULT_TORCH_GROUP
if isinstance(obj, torch.Tensor):
objs = [torch.zeros_like(obj) for _ in range(dist.get_world_size(group))]
dist.all_gather(objs, obj, group=group)
@@ -232,7 +228,7 @@ def fastnlp_torch_all_gather(obj: Any, device=None, group=None) ->List:
return objs


- def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
+ def fastnlp_torch_broadcast_object(obj, src, device=None, group=DEFAULT_TORCH_GROUP):
"""
Broadcast the obj object on src to the other ranks.

@@ -242,6 +238,8 @@ def fastnlp_torch_broadcast_object(obj, src, device=None, group=None):
:param group:
:return:
"""
+ if group is None:
+ group = DEFAULT_TORCH_GROUP
cur_rank = dist.get_rank(group)
if cur_rank == src:
# move any tensors to cpu to make pickling easy; otherwise unpickling may end up on the card the data was sent from
@@ -339,15 +337,18 @@ def all_gather_object(object_list, obj, group=None):
return

input_tensor, local_size = _object_to_tensor(obj)
- current_device = torch.device("cpu")
- is_nccl_backend = _check_for_nccl_backend(group)
- if is_nccl_backend:
- # See note about using torch.cuda.current_device() here in docstring.
- # We cannot simply use my_rank since rank == device is not necessarily
- # true.
- current_device = torch.device("cuda", torch.cuda.current_device())
- input_tensor = input_tensor.to(current_device)
- local_size = local_size.to(current_device)
+ if _TORCH_GREATER_EQUAL_1_8:
+ current_device = torch.device("cpu")
+ is_nccl_backend = _check_for_nccl_backend(group)
+ if is_nccl_backend:
+ # See note about using torch.cuda.current_device() here in docstring.
+ # We cannot simply use my_rank since rank == device is not necessarily
+ # true.
+ current_device = torch.device("cuda", torch.cuda.current_device())
+ input_tensor = input_tensor.to(current_device)
+ local_size = local_size.to(current_device)
+ else:
+ current_device = torch.cuda.current_device()
# Gather all local sizes. This is so that we can find the max size, and index
# until the correct size when deserializing the tensors.
group_size = dist.get_world_size(group=group)
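A usage sketch of the gather helper after the default-group change (assumes torch.distributed is already initialized; on older torch versions group=None now falls back to DEFAULT_TORCH_GROUP instead of failing):

import torch.distributed as dist

obj = {'rank': dist.get_rank(), 'data': [1, 2, 3]}
objs = fastnlp_torch_all_gather(obj, group=None)
# objs == [obj_from_rank_0, obj_from_rank_1, ...], one entry per rank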


fastNLP/core/drivers/torch_driver/utils.py (+6 -7)

@@ -8,6 +8,7 @@ import numpy as np
import inspect

from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.core.samplers import re_instantiate_sampler

if _NEED_IMPORT_TORCH:
import torch
@@ -295,7 +296,6 @@ def replace_sampler(dataloader: "DataLoader", sampler):
"manually add the `DistributedSampler` as: "
f"`{dataloader_self_name}(dataset, sampler=DistributedSampler(dataset))`."
)

return type(dataloader)(**reconstruct_args)


@@ -307,12 +307,8 @@ def _dataloader_init_kwargs_resolve_sampler(
"""
batch_sampler = getattr(dataloader, "batch_sampler")
# checking the batch sampler type is different than PyTorch default.
- if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
- batch_sampler = type(batch_sampler)(
- sampler,
- batch_size=batch_sampler.batch_size,
- drop_last=batch_sampler.drop_last,
- )
+ if batch_sampler is not None and not isinstance(batch_sampler, BatchSampler):
+ batch_sampler = re_instantiate_sampler(batch_sampler)

return {
"sampler": None,
@@ -343,6 +339,9 @@ def replace_batch_sampler(dataloader, new_batch_sampler):
params = {k: getattr(dataloader, k) for k in params_keys}
params["batch_sampler"] = new_batch_sampler
return type(dataloader)(**params)
+ # TODO could we use auto_param_call here?
+ # return auto_param_call(type(dataloader), params, {'self': type(dataloader).__new__()},
+ # signature_fn=type(dataloader).__init__)


def optimizer_state_to_device(state, device):


fastNLP/core/log/logger.py (+16 -0)

@@ -51,6 +51,7 @@ class LoggerSingleton(type):
class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
def __init__(self, name):
super().__init__(name)
+ self._warning_msgs = set()

def add_file(self, path: Optional[Union[str, Path]] = None, level='AUTO', remove_other_handlers: bool = False,
mode: str = "w"):
@@ -108,6 +109,21 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton):
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)

def warning_once(self, msg, *args, **kwargs):
"""
A given warning message will only be warned once

:param msg:
:param args:
:param kwargs:
:return:
"""
if msg not in self._warning_msgs:
if self.isEnabledFor(WARNING):
kwargs = self._add_rank_info(kwargs)
self._log(WARNING, msg, args, **kwargs)
self._warning_msgs.add(msg)
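A usage sketch of the new warning_once: deduplication is keyed on the message string itself, so an identical message is only emitted once per process:

logger.warning_once('non-dict batch, metrics will rely on model output only')  # logged
logger.warning_once('non-dict batch, metrics will rely on model output only')  # suppressed
logger.warning_once('a different message')                                     # logged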

def warn(self, msg, *args, **kwargs):
warnings.warn("The 'warn' method is deprecated, "
"use 'warning' instead", DeprecationWarning, 2)


fastNLP/core/samplers/reproducible_batch_sampler.py (+2 -2)

@@ -166,8 +166,8 @@ class BucketedBatchSampler(ReproducibleBatchSampler):
:param kwargs: reserved for use by fastNLP
"""
super().__init__()
- if isinstance(dataset, DataSet):
- length = dataset.get_field(length)
+ if isinstance(dataset, DataSet) and isinstance(length, str):
+ length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:

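With this fix `length` may be the name of a field on a fastNLP DataSet, whose .content the sampler reads; a hedged sketch (the batch_size argument and the 'seq_len' field are assumptions about the constructor and the data):

from fastNLP.core.dataset import DataSet
from fastNLP.core.samplers import BucketedBatchSampler

ds = DataSet({'x': [[1, 2], [3], [4, 5, 6]], 'seq_len': [2, 1, 3]})
sampler = BucketedBatchSampler(ds, length='seq_len', batch_size=2)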

fastNLP/core/samplers/reproducible_sampler.py (+2 -2)

@@ -295,8 +295,8 @@ class SortedSampler(SequentialSampler):
:param kwargs: reserved for use by fastNLP
"""
super().__init__(dataset=dataset, **kwargs)
- if isinstance(dataset, DataSet):
- length = dataset.get_field(length)
+ if isinstance(dataset, DataSet) and isinstance(length, str):
+ length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:


fastNLP/core/samplers/unrepeated_sampler.py (+2 -2)

@@ -105,8 +105,8 @@ class UnrepeatedSortedSampler(UnrepeatedRandomSampler):
:param kwargs: reserved for use by fastNLP
"""
super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs)
- if isinstance(dataset, DataSet):
- length = dataset.get_field(length)
+ if isinstance(dataset, DataSet) and isinstance(length, str):
+ length = dataset.get_field(length).content
if not isinstance(length[0], int):
length = list(map(len, length))
else:


fastNLP/core/utils/rich_progress.py (+77 -2)

@@ -6,7 +6,7 @@
import sys
from typing import Any, Union, Optional

- from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live
+ from rich.progress import Progress, Console, GetTimeCallable, get_console, TaskID, Live, Text, ProgressSample
from rich.progress import ProgressColumn, TimeRemainingColumn, BarColumn, TimeElapsedColumn, TextColumn

__all__ = [
@@ -146,24 +146,99 @@ class FRichProgress(Progress, metaclass=Singleton):
if task_id in self._tasks:
super().stop_task(task_id)
super().remove_task(task_id)
+ self.refresh()  # so that the bar does not linger

def start(self) -> None:
super().start()
self.console.show_cursor(show=True)

def update(
self,
task_id: TaskID,
*,
total: Optional[float] = None,
completed: Optional[float] = None,
advance: Optional[float] = None,
description: Optional[str] = None,
visible: Optional[bool] = None,
refresh: bool = False,
**fields: Any,
) -> None:
"""Update information associated with a task.

Args:
task_id (TaskID): Task id (returned by add_task).
total (float, optional): Updates task.total if not None.
completed (float, optional): Updates task.completed if not None.
advance (float, optional): Add a value to task.completed if not None.
description (str, optional): Change task description if not None.
visible (bool, optional): Set visible flag if not None.
refresh (bool): Force a refresh of progress information. Default is False.
**fields (Any): Additional data fields required for rendering.
"""
with self._lock:
task = self._tasks[task_id]
completed_start = task.completed

if total is not None and total != task.total:
task.total = total
task._reset()
if advance is not None:
task.completed += advance
if completed is not None:
task.completed = completed
if description is not None:
task.description = description
if visible is not None:
task.visible = visible
task.fields.update(fields)
update_completed = task.completed - completed_start

current_time = self.get_time()
old_sample_time = current_time - self.speed_estimate_period
_progress = task._progress

popleft = _progress.popleft
# modified to always keep at least one sample, so that extremely long iterations do not skew the estimate
while len(_progress)>1 and _progress[0].timestamp < old_sample_time:
popleft()
if update_completed > 0:
_progress.append(ProgressSample(current_time, update_completed))
if task.completed >= task.total and task.finished_time is None:
task.finished_time = task.elapsed

if refresh:
self.refresh()


class SpeedColumn(ProgressColumn):
"""
Show the speed of the task.

"""
def render(self, task: "Task"):
speed = task.speed
if speed is None:
return Text('-- it./s', style='progress.data.speed')
if speed > 0.1:
return Text(str(round(speed, 2))+' it./s', style='progress.data.speed')
else:
return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed')
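A usage sketch of the singleton progress bar with the new SpeedColumn (assumes an interactive terminal on global rank 0, otherwise f_rich_progress is the dummy object; post_desc must be supplied because the TextColumn below renders it):

task_id = f_rich_progress.add_task('evaluation', total=100, post_desc='')
for _ in range(100):
    f_rich_progress.update(task_id, advance=1, refresh=True)
f_rich_progress.destroy_task(task_id)  # also refreshes so the bar disappears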


if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0:
f_rich_progress = FRichProgress().new_progess(
"[progress.description]{task.description}",
"[progress.percentage]{task.percentage:>3.0f}%",
BarColumn(),
SpeedColumn(),
TimeElapsedColumn(),
"/",
TimeRemainingColumn(),
TextColumn("{task.fields[post_desc]}", justify="right"),
transient=True,
disable=False,
- speed_estimate_period=1
+ speed_estimate_period=30
)
else:
f_rich_progress = DummyFRichProgress()


fastNLP/core/utils/torch_utils.py (+4 -2)

@@ -1,9 +1,11 @@
from abc import ABC
from typing import Any, Union, Optional
- from fastNLP.envs.imports import _NEED_IMPORT_TORCH
+ from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _TORCH_GREATER_EQUAL_1_8
+ DEFAULT_TORCH_GROUP = None
if _NEED_IMPORT_TORCH:
import torch
+ if not _TORCH_GREATER_EQUAL_1_8:
+ DEFAULT_TORCH_GROUP = torch.distributed.distributed_c10d.group.WORLD

__all__ = [
'torch_move_data_to_device'


fastNLP/core/utils/utils.py (+4 -1)

@@ -81,7 +81,10 @@ def check_fn_not_empty_params(fn: Optional[Callable] = None, param_num: Optional
def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None,
mapping: Optional[Dict[AnyStr, AnyStr]] = None) -> Any:
r"""
- 1. This function is provided so that users can achieve automatic computation via string matching;
+ According to the called function's parameter names, this function finds matching values from *args (each of which must therefore be a
+ dict) and calls the function with them. If the passed-in data does not match fn's parameters, the mapping argument can convert them:
+ a (key, value) pair in mapping means the value found under key in *args is passed to the parameter named value.
+
+ 1. This function is provided so that users can achieve automatic calling via string matching;
2. Note that mapping defaults to None; if you want to specify how the inputs correspond to the called function's parameters, you should pass mapping in as such a dict;
if mapping is not None, we will always first use mapping to rewrite the input dicts' keys, so please check mapping's correctness yourself;
3. If a parameter of the called function has a default value, then when the subsequent inputs carry no value for it the default is used; otherwise the input value is used;
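A hedged sketch of points 2 and 3 above (the function and dict keys are illustrative):

def fn(x, y=10):
    return x + y

auto_param_call(fn, {'input': 1}, mapping={'input': 'x'})          # -> 11, y falls back to its default
auto_param_call(fn, {'input': 1, 'y': 5}, mapping={'input': 'x'})  # -> 6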


tests/core/log/test_logger.py (+24 -24)

@@ -6,13 +6,16 @@ import logging
import re

from fastNLP.envs.env import FASTNLP_LAUNCH_TIME
- from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm
from fastNLP.core.log.logger import logger

+ from tests.helpers.utils import magic_argv_env_context, recover_logger


# test TorchDDPDriver;
@magic_argv_env_context
- def test_add_file_ddp_1():
+ @recover_logger
+ def test_add_file_ddp_1_torch():
"""
Test where path is a file path whose containing folder exists;

@@ -56,11 +59,11 @@ def test_add_file_ddp_1():
synchronize_safe_rm(filepath)
dist.barrier()
dist.destroy_process_group()
- logger.removeHandler(handler)


@magic_argv_env_context
- def test_add_file_ddp_2():
+ @recover_logger
+ def test_add_file_ddp_2_torch():
"""
Test where path is a file path whose containing folder does not exist;
"""
@@ -103,14 +106,14 @@ def test_add_file_ddp_2():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
- logger.removeHandler(handler)

dist.barrier()
dist.destroy_process_group()


@magic_argv_env_context
- def test_add_file_ddp_3():
+ @recover_logger
+ def test_add_file_ddp_3_torch():
"""
path = None;

@@ -155,10 +158,10 @@ def test_add_file_ddp_3():
synchronize_safe_rm(file)
dist.barrier()
dist.destroy_process_group()
- logger.removeHandler(handler)

@magic_argv_env_context
- def test_add_file_ddp_4():
+ @recover_logger
+ def test_add_file_ddp_4_torch():
"""
Test where path is a folder;
"""
@@ -200,7 +203,6 @@ def test_add_file_ddp_4():
assert len(pattern.findall(line)) == 1
finally:
synchronize_safe_rm(path)
- logger.removeHandler(handler)

dist.barrier()
dist.destroy_process_group()
@@ -209,12 +211,11 @@ def test_add_file_ddp_4():
class TestLogger:
msg = 'some test log msg'

+ @recover_logger
def test_add_file_1(self):
"""
Test where path is a file path whose containing folder exists;
"""
- from fastNLP.core.log.logger import logger

path = Path(tempfile.mkdtemp())
try:
filepath = path.joinpath('log.txt')
@@ -225,14 +226,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
- logger.removeHandler(handler)

+ @recover_logger
def test_add_file_2(self):
"""
Test where path is a file path whose containing folder does not exist;
"""
- from fastNLP.core.log.logger import logger

origin_path = Path(tempfile.mkdtemp())

try:
@@ -245,14 +244,12 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(origin_path)
- logger.removeHandler(handler)

+ @recover_logger
def test_add_file_3(self):
"""
Test where path is None;
"""
- from fastNLP.core.log.logger import logger

handler = logger.add_file()
logger.info(self.msg)

@@ -264,14 +261,12 @@ class TestLogger:
line = ''.join([l for l in f])
assert self.msg in line
file.unlink()
- logger.removeHandler(handler)

+ @recover_logger
def test_add_file_4(self):
"""
Test where path is a folder;
"""
- from fastNLP.core.log.logger import logger

path = Path(tempfile.mkdtemp())
try:
handler = logger.add_file(path)
@@ -285,16 +280,21 @@ class TestLogger:
assert self.msg in line
finally:
synchronize_safe_rm(path)
- logger.removeHandler(handler)

+ @recover_logger
def test_stdout(self, capsys):
- from fastNLP.core.log.logger import logger

handler = logger.set_stdout(stdout="raw")
logger.info(self.msg)
logger.debug('aabbc')
captured = capsys.readouterr()
assert "some test log msg\n" == captured.out

- logger.removeHandler(handler)
+ @recover_logger
def test_warning_once(self, capsys):
logger.warning_once('#')
logger.warning_once('#')
logger.warning_once('@')
captured = capsys.readouterr()
assert captured.out.count('#') == 1
assert captured.out.count('@') == 1


tests/helpers/utils.py (+15 -0)

@@ -13,6 +13,7 @@ import numpy as np

from fastNLP.envs.env import FASTNLP_GLOBAL_RANK
from fastNLP.core.drivers.utils import distributed_open_proc
+ from fastNLP.core.log import logger


def get_class_that_defined_method(meth):
@@ -32,6 +33,20 @@ def get_class_that_defined_method(meth):
return getattr(meth, '__objclass__', None) # handle special descriptor objects


def recover_logger(fn):
@wraps(fn)
def wrapper(*args, **kwargs):
# save the logger's state
handlers = [handler for handler in logger.handlers]
level = logger.level
res = fn(*args, **kwargs)
logger.handlers = handlers
logger.setLevel(level)
return res

return wrapper


def magic_argv_env_context(fn):

@wraps(fn)

