diff --git a/README.md b/README.md index 2fd27048..74090646 100644 --- a/README.md +++ b/README.md @@ -6,4 +6,133 @@ ![Hex.pm](https://img.shields.io/hexpm/l/plug.svg) [![Documentation Status](https://readthedocs.org/projects/fastnlp/badge/?version=latest)](http://fastnlp.readthedocs.io/?badge=latest) -dev0.8.0正在开发中 \ No newline at end of file +fastNLP是一款轻量级的自然语言处理(NLP)工具包,目标是快速实现NLP任务以及构建复杂模型。 + +fastNLP具有如下的特性: + +- 统一的Tabular式数据容器,简化数据预处理过程; +- 内置多种数据集的Loader和Pipe,省去预处理代码; +- 各种方便的NLP工具,例如Embedding加载(包括ELMo和BERT)、中间数据cache等; +- 部分[数据集与预训练模型](https://docs.qq.com/sheet/DVnpkTnF6VW9UeXdh?c=A1A0A0)的自动下载; +- 提供多种神经网络组件以及复现模型(涵盖中文分词、命名实体识别、句法分析、文本分类、文本匹配、指代消解、摘要等任务); +- Trainer提供多种内置Callback函数,方便实验记录、异常捕获等。 + +## 安装指南 + +fastNLP 依赖以下包: + ++ numpy>=1.14.2 ++ torch>=1.0.0 ++ tqdm>=4.28.1 ++ nltk>=3.4.1 ++ requests ++ spacy ++ prettytable>=0.7.2 + +其中torch的安装可能与操作系统及 CUDA 的版本相关,请参见 [PyTorch 官网](https://pytorch.org/) 。 +在依赖包安装完成后,您可以在命令行执行如下指令完成安装 + +```shell +pip install fastNLP +python -m spacy download en +``` + + +## fastNLP教程 +中文[文档](https://fastnlp.readthedocs.io/)、[教程](https://fastnlp.readthedocs.io/zh/latest/user/tutorials.html) + +### 快速入门 + +- [0. 快速入门](https://fastnlp.readthedocs.io/zh/latest/user/quickstart.html) + +### 详细使用教程 + +- [1. 使用DataSet预处理文本](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_1_data_preprocess.html) +- [2. 使用Vocabulary转换文本与index](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_2_vocabulary.html) +- [3. 使用Embedding模块将文本转成向量](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_3_embedding.html) +- [4. 使用Loader和Pipe加载并处理数据集](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_4_load_dataset.html) +- [5. 动手实现一个文本分类器I-使用Trainer和Tester快速训练和测试](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_5_loss_optimizer.html) +- [6. 动手实现一个文本分类器II-使用DataSetIter实现自定义训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_6_datasetiter.html) +- [7. 使用Metric快速评测你的模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_7_metrics.html) +- [8. 使用Modules和Models快速搭建自定义模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_8_modules_models.html) +- [9. 快速实现序列标注模型](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_9_seq_labeling.html) +- [10. 使用Callback自定义你的训练过程](https://fastnlp.readthedocs.io/zh/latest/tutorials/tutorial_10_callback.html) + +### 扩展教程 + +- [Extend-1. BertEmbedding的各种用法](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_1_bert_embedding.html) +- [Extend-2. 分布式训练简介](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_2_dist.html) +- [Extend-3. 使用fitlog 辅助 fastNLP 进行科研](https://fastnlp.readthedocs.io/zh/latest/tutorials/extend_3_fitlog.html) + + +## 内置组件 + +大部分用于的 NLP 任务神经网络都可以看做由词嵌入(embeddings)和两种模块:编码器(encoder)、解码器(decoder)组成。 + +以文本分类任务为例,下图展示了一个BiLSTM+Attention实现文本分类器的模型流程图: + + +![](./docs/source/figures/text_classification.png) + +fastNLP 在 embeddings 模块中内置了几种不同的embedding:静态embedding(GloVe、word2vec)、上下文相关embedding +(ELMo、BERT)、字符embedding(基于CNN或者LSTM的CharEmbedding) + +与此同时,fastNLP 在 modules 模块中内置了两种模块的诸多组件,可以帮助用户快速搭建自己所需的网络。 两种模块的功能和常见组件如下: + + + + + + + + + + + + + + + + +
类型 功能 例子
encoder 将输入编码为具有具有表示能力的向量 Embedding, RNN, CNN, Transformer, ... +
decoder 将具有某种表示意义的向量解码为需要的输出形式 MLP, CRF, ...
+ + +## 项目结构 + +
+ + + +fastNLP的大致工作流程如上图所示,而项目结构如下: + + + + + + + + + + + + + + + + + + + + + + + + + + +
fastNLP 开源的自然语言处理库
fastNLP.core 实现了核心功能,包括数据处理组件、训练器、测试器等
fastNLP.models 实现了一些完整的神经网络模型
fastNLP.modules 实现了用于搭建神经网络模型的诸多组件
fastNLP.embeddings 实现了将序列index转为向量序列的功能,包括读取预训练embedding等
fastNLP.io 实现了读写功能,包括数据读入与预处理,模型读写,数据与模型自动下载等
+ +
+ +*In memory of @FengZiYjun. May his soul rest in peace. We will miss you very very much!* diff --git a/fastNLP/core/callbacks/__init__.py b/fastNLP/core/callbacks/__init__.py index f45cf5e0..a47ab998 100644 --- a/fastNLP/core/callbacks/__init__.py +++ b/fastNLP/core/callbacks/__init__.py @@ -4,7 +4,8 @@ __all__ = [ 'EventsList', 'Filter', 'CallbackManager', - 'CheckpointCallback', + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback', 'choose_progress_callback', 'ProgressCallback', 'RichCallback', @@ -16,7 +17,7 @@ __all__ = [ from .callback import Callback from .callback_events import EventsList, Events, Filter from .callback_manager import CallbackManager -from .checkpoint_callback import CheckpointCallback +from .checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from .progress_callback import choose_progress_callback, ProgressCallback, RichCallback from .lr_scheduler_callback import LRSchedCallback from .load_best_model_callback import LoadBestModelCallback diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index c239f8b1..8b53c70b 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -8,7 +8,7 @@ __all__ = [ from .callback_events import Events from .callback import Callback -from .checkpoint_callback import CheckpointCallback +from .checkpoint_callback import TrainerCheckpointCallback from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.core.log import logger @@ -98,7 +98,7 @@ class CallbackManager: :return: """ for each_callback in self.class_callbacks: - if isinstance(each_callback, CheckpointCallback) and each_callback.is_trainer_checkpoint: + if isinstance(each_callback, TrainerCheckpointCallback): self._has_trainer_checkpoint = True self.dissect_one_callback(each_callback) @@ -210,7 +210,7 @@ class CallbackManager: each_callback.on_load_checkpoint(trainer, None) @property - def has_trainer_chechpoint(self) -> bool: + def has_trainer_checkpoint(self) -> bool: return self._has_trainer_checkpoint @_transfer diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 5fcc7e26..d3a3b52d 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -1,12 +1,13 @@ +__all__ = [ + 'ModelCheckpointCallback', + 'TrainerCheckpointCallback' +] import os -from typing import Union, Optional, Callable, Dict, Sequence +from typing import Union, Optional, Callable, Dict, Sequence, Any, Mapping from pathlib import Path -from functools import partial -from time import sleep +from abc import ABC +import sys -__all__ = [ - 'CheckpointCallback' -] import fastNLP from .callback import Callback, Filter @@ -14,35 +15,37 @@ from fastNLP.core.callbacks.utils import _get_monitor_value from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_LAUNCH_TIME from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir +from fastNLP.core.utils import apply_to_collection -class CheckpointCallback(Callback): +class CanItemDataType(ABC): """ - 1. 因为只有 'Trainer' 才有 callback,因此评测 metric 实际上就是 validate 时干的事情; - 2. 默认 'save_last' 为 True,即 model_checkpoint 的默认逻辑是在每一个 epoch 下保存最后的一个模型,模型名字为 last.pth.tar; - 3. 理论上一个 model_checkpoint 的实例只会负责一个 monitor 的监视,如果用户在训练过程中指定了多个 monitor 的监视,例如 "acc1", - "acc2", ... 那么我们会为用户创建多个 model_checkpoint 的实例; - 4. 理论上,在实际保存的过程中,topk 模式和 固定频率保存的模式是完全独立的,我们确实应当采取一些措施至少保证两者的名字不一样; + 检测可以进行传输的对象。 + """ + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is CanItemDataType: + item = getattr(subclass, 'item', None) + return callable(item) + return NotImplemented + + + +class CheckpointCallback(Callback): def __init__( self, monitor, - is_trainer_checkpoint: Optional[bool] = False, - save_folder: Optional[Union[str, Path]] = None, - save_every_n_epochs: Optional[int] = None, - save_every_n_global_batches: Optional[int] = None, + save_every_n_batches: Optional[int] = None, save_last: bool = True, save_topk: Optional[int] = None, save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, - larger_better: bool = True, only_state_dict: bool = True, - model_save_fn: Optional[Callable] = None, - **kwargs, ): if monitor is None and save_topk is not None: @@ -51,9 +54,6 @@ class CheckpointCallback(Callback): if monitor is not None and not isinstance(monitor, str): raise ValueError("Parameter `monitor` should be of 'str' type.") - if not isinstance(is_trainer_checkpoint, bool): - raise TypeError("Parameter 'is_trainer_checkpoint' can only be `bool` type.") - if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") @@ -67,15 +67,15 @@ class CheckpointCallback(Callback): if not isinstance(save_every_n_epochs, int) or save_every_n_epochs < 1: raise ValueError("parameter save_after_epoch_num should be an int and greater than or equal to 1.") - # 突然发现有一个骚操作在于 'Filter' 内部记载的状态值例如 'num_called' 是这个类全局的,而每次调用 __call__ 中输入的 - # 函数却是及时传入的,也就是说,我们可以保证 'Filter' 的正常控制频率的逻辑,然后每一次运行的函数都不一样; - self._filter_every_n_epochs = Filter(every=save_every_n_epochs) + else: + save_every_n_epochs = sys.maxsize # 使得没有数字可以整除 - if save_every_n_global_batches is not None: - if not isinstance(save_every_n_global_batches, int) or save_every_n_global_batches < 1: + if save_every_n_batches is not None: + if not isinstance(save_every_n_batches, int) or save_every_n_batches < 1: raise ValueError( - "parameter save_every_n_global_batches should be an int and greater than or equal to 1.") - self._filter_every_n_global_batches = Filter(every=save_every_n_global_batches) + "parameter save_every_n_batches should be an int and greater than or equal to 1.") + else: + save_every_n_batches = sys.maxsize # 使得没有数字可以整除 if save_topk is not None: if not isinstance(save_topk, int) or save_topk < 1: @@ -89,12 +89,12 @@ class CheckpointCallback(Callback): if not issubclass(exception, BaseException): raise TypeError("Each exception in parameter `save_on_exception` can only be " "`BaseException` type.") - + else: + save_on_exception = [] self.monitor = monitor - self.is_trainer_checkpoint = is_trainer_checkpoint self.save_folder = Path(save_folder) self.save_every_n_epochs = save_every_n_epochs - self.save_every_n_global_batches = save_every_n_global_batches + self.save_every_n_batches = save_every_n_batches self.save_last = save_last self.save_topk = save_topk self.larger_better = larger_better @@ -107,7 +107,7 @@ class CheckpointCallback(Callback): self._topk_model = {} self._topn = 0 # 表示目前已经保存了几个最好的模型; - # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用模糊匹配找到的第一个 + # 因为我们在 `_get_validate_metric` 函数中,当在返回的 `validate_res` 字典中找不到 `monitor` 时,是使用匹配找到的 # key 对应的 value 当做结果;但是这样存在的一个问题在于如果用户传入的 metric 返回的 sub_metric 的名字可能会混淆,并且其在下一次 # 训练的代码中修改了这些 sub_metric 返回的顺序,那么就会导致模糊匹配拿到的 key 和 value 与之前的不是同一个,这显然不是合理的行为; # 因此我们通过该变量来表示我们通过模糊匹配拿到的 key; @@ -115,76 +115,83 @@ class CheckpointCallback(Callback): # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; - self.log_filepath = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) + self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; - synchronize_mkdir(self.log_filepath) + synchronize_mkdir(self.timestamp_path) def on_validate_end(self, trainer, validate_res): self._save_topk(trainer, validate_res) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): - self._save_every_n_epochs(trainer) - self._save_last(trainer) + if trainer.cur_epoch_idx % self.save_every_n_epochs == 0: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}' + self.save(trainer, folder_name=folder_name) + if self.save_last: + folder_name = f'{self.folder_prefix}-last' + self.save(trainer, folder_name=folder_name) def on_train_batch_end(self, trainer): - self._save_every_n_global_batches(trainer) + if trainer.global_forward_batches % self.save_every_n_batches == 0: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}' + self.save(trainer, folder_name=folder_name) def on_exception(self, trainer, exception: BaseException): - if self.save_on_exception is not None and exception.__class__ in self.save_on_exception: - folder = self._get_checkpoint_real_save_folder(trainer=trainer, topk=False, metric=None) - folder = folder + f"_{exception.__class__.__name__}" - self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder=folder) + if exception.__class__ in self.save_on_exception: + folder_name = f'{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}-' \ + f'exception_{exception.__class__.__name__}' + self.save(trainer=trainer, folder_name=folder_name) def on_sanity_check_end(self, trainer, sanity_check_res): + # 主要核对一下 monitor 是否存在。 self._get_validate_metric(sanity_check_res) def on_save_checkpoint(self, trainer) -> Dict: """ - 我们需要保存 CheckpointCallback 内部的几个 filter 的状态; + 保存 timestamp_path 使得之后可以继续训练并保存到该文件夹。 + topk_model的状态 + _real_monitor的值 """ + states = {} - if self.save_every_n_epochs is not None: - states["_filter_every_n_epochs"] = self._filter_every_n_epochs.state_dict() - if self.save_every_n_global_batches is not None: - states["_filter_every_n_global_batches"] = self._filter_every_n_global_batches.state_dict() - states["real_monitor"] = self._real_monitor + states['timestamp_path'] = str(self.timestamp_path.absolute()) + states['_topk_model'] = apply_to_collection(self._topk_model, dtype=CanItemDataType, + function=lambda x:x.item()) + states['save_topk'] = 0 if self.save_topk is None else self.save_topk + states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states: Optional[Dict]): - if self.save_every_n_epochs is not None: - self._filter_every_n_epochs.load_state_dict(states["_filter_every_n_epochs"]) - if self.save_every_n_global_batches is not None: - self._filter_every_n_global_batches.load_state_dict(states["_filter_every_n_global_batches"]) + timestamp_path = states['timestamp_path'] + if not os.path.exists(timestamp_path): + logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " + f" {self.timestamp_path.absolute()}.") + else: + logger.info(f"Resume to save in path: {timestamp_path}.") + self.timestamp_path = Path(timestamp_path) + _topk_model = states['_topk_model'] + save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) + if save_topk is not None and self.save_topk is not None: + assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ + f"as {save_topk}." + self._topk_model.update(self._topk_model) self._real_monitor = states["real_monitor"] - def _save_every_n_epochs(self, trainer: "fastNLP.Trainer"): - if self.save_every_n_epochs is not None: - if self.is_trainer_checkpoint: - _fn_every_n_epochs = trainer.save - else: - _fn_every_n_epochs = trainer.save_model - _fn_every_n_epochs = partial(self._save_fn, trainer, False, None, _fn_every_n_epochs, None) - _fn_every_n_epochs = self._filter_every_n_epochs(_fn_every_n_epochs) - _fn_every_n_epochs() - - def _save_every_n_global_batches(self, trainer: "fastNLP.Trainer"): - if self.save_every_n_global_batches is not None: - if self.is_trainer_checkpoint: - _fn_every_n_global_batches = trainer.save - else: - _fn_every_n_global_batches = trainer.save_model - _fn_every_n_global_batches = partial(self._save_fn, trainer, False, None, _fn_every_n_global_batches, None) - _fn_every_n_global_batches = self._filter_every_n_global_batches(_fn_every_n_global_batches) - _fn_every_n_global_batches() - def _save_topk(self, trainer: "fastNLP.Trainer", validate_res: Dict): + """ + 根据validate_res决定保存哪些model的函数。会自动移除掉不满足topk的文件夹。 + + :param trainer: + :param validate_res: + :return: + """ if self.save_topk is not None: _metric_value = self._get_validate_metric(validate_res) - _saved_name = self._get_checkpoint_real_save_folder(trainer=trainer, topk=True, metric=_metric_value) + folder_name = f"{self.folder_prefix}-epoch_{trainer.cur_epoch_idx}-batch_{trainer.global_forward_batches}" \ + f"-{self._real_monitor}_{_metric_value}" _should_save = False if self._topn < self.save_topk: - self._topk_model[_saved_name] = _metric_value + self._topk_model[folder_name] = _metric_value self._topn += 1 _should_save = True else: @@ -192,39 +199,27 @@ class CheckpointCallback(Callback): key=lambda x: self._topk_model[x]) if (self.larger_better and _metric_value > self._topk_model[_least_valuable_model]) or \ (self.larger_better is False and _metric_value < self._topk_model[_least_valuable_model]): - self._topk_model[_saved_name] = _metric_value + self._topk_model[folder_name] = _metric_value _should_save = True self._topk_model.pop(_least_valuable_model) - synchronize_safe_rm(self.log_filepath.joinpath(_least_valuable_model)) + synchronize_safe_rm(self.timestamp_path.joinpath(_least_valuable_model)) assert len(self._topk_model) == self.save_topk == self._topn if _should_save: - self._save_fn(trainer=trainer, topk=True, metric=_metric_value, substitute_folder=_saved_name) + self.save(trainer, folder_name=folder_name) - def _save_last(self, trainer: "fastNLP.Trainer"): - if self.save_last: - self._save_fn(trainer=trainer, topk=False, metric=None, substitute_folder="last") - - def _save_fn(self, trainer, topk: bool = False, metric: Optional[Union[int, float]] = None, - substitute_fn: Optional[Callable] = None, substitute_folder: Optional[str] = None): - # 首先根据当前的 epoch 和 batch 在 parent_path/FASTNLP_LAUNCH_TIME 下创建子文件夹 epoch-batch-monitor 或者 - # epoch-batch-monitor-monitor_value; - if substitute_folder is None: - folder = self.log_filepath.joinpath(self._get_checkpoint_real_save_folder(trainer, topk, metric)) - else: - folder = self.log_filepath.joinpath(substitute_folder) + def save(self, trainer, folder_name): + """ + 执行保存的函数,将数据保存在 save_folder/timestamp/folder_name 下。 + :param trainer: + :param folder_name: + :return: + """ + folder = self.timestamp_path.joinpath(folder_name) synchronize_mkdir(folder) - - # 然后再调用 trainer 的 save_model(用于保存模型)或者 save(用于断点重训)函数; - if substitute_fn is not None: - _fn = substitute_fn - else: - if self.is_trainer_checkpoint: - _fn = trainer.save - else: - _fn = trainer.save_model + _fn = getattr(trainer, self.save_fn_name) _fn( folder=folder, only_state_dict=self.only_state_dict, @@ -243,18 +238,95 @@ class CheckpointCallback(Callback): self._real_monitor = use_monitor return value - def _get_checkpoint_real_save_folder(self, trainer: "fastNLP.Trainer", topk: bool = False, - metric: Optional[Union[int, float]] = None) -> str: + @property + def folder_prefix(self): + raise NotImplementedError("The `folder_prefix` is not specified") + + @property + def save_fn_name(self): + raise NotImplementedError("The `save_fn_name` is not specified.") + + +class ModelCheckpointCallback(CheckpointCallback): + """ + 保存模型 checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 + + - save_folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - model-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 + - model-last/ # 最后一个 epoch 的保存 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 + - model-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 + 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 + + :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ + @property + def save_fn_name(self): + return 'save_model' + + @property + def callback_name(self): """ - 获取当前保存模型的真正地名字; - metric 参数仅当 mode 为 'topk' 时起作用; + 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; + :return: """ - cur_epoch_idx = trainer.cur_epoch_idx - global_forward_batches = trainer.global_forward_batches - _other = "" - if topk: - _other = f"_{metric}" - return f"epoch_{cur_epoch_idx}-global_batch_{global_forward_batches}-{self._real_monitor}{_other}" + return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + + @property + def folder_prefix(self): + return 'model' + + +class TrainerCheckpointCallback(CheckpointCallback): + """ + 保存 Trainer checkpoint 的 callback ,其保存的文件目录以及文件名命名规则如下 + + - save_folder/ + - YYYY-mm-dd-HH_MM_SS_fffff/ # 自动根据当前脚本的启动时间创建的 + - trainer-epoch_{epoch_idx}/ # 满足 save_every_n_epochs 条件保存的模型 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}/ # 满足 save_every_n_batches 保存的模型 + - trainer-last/ # 最后一个 epoch 的保存 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-exception_{exception_type}/ # exception时保存。 + - trainer-epoch_{epoch_idx}-batch_{global_batch_idx}-{monitor}_{monitor_value}/ # 满足topk条件存储文件名 + + model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 + 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 + + :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无意义。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ + @property + def save_fn_name(self): + return 'save' @property def callback_name(self): @@ -262,6 +334,8 @@ class CheckpointCallback(Callback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"monitor-{self.monitor}#trainer_checkpoint-{self.is_trainer_checkpoint}#only_state_dict-{self.only_state_dict}" - + return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + @property + def folder_prefix(self): + return 'trainer' diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index b4ef4e62..e7b94f8c 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -31,7 +31,7 @@ class LoadBestModelCallback(Callback): 请在函数内完成对模型的保存。 :param model_load_fn: 加载 model 的函数,与 model_save_fn 必须同时不为空。本函数的输入为一个已经创建好的文件夹,没有输出, 请在函数内完成对模型的加载。 - :param delete_after_train: 在加载了最佳模型之后是否删掉模型。 + :param delete_after_train: 在训练结束后是否删掉模型。 """ if model_load_fn is not None: assert callable(model_load_fn), "`model_load_fn` must be a callable object." diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index f58a7faf..bd66d0a0 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -133,17 +133,18 @@ class Evaluator: self.driver.barrier() - def run(self, num_eval_batch_per_dl: int = -1) -> Dict: + def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 - 如果存在多个metric,一个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name - 如果存在多个数据集,一个metric的情况,key的命名规则是 - metric_indicator_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 - 如果存在多个metric,多个dataloader的情况,key的命名规则是 - metric_indicator_name#metric_name#dataloader_name - :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 + 如果存在多个metric,一个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name + 如果存在多个数据集,一个metric的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name (其中 # 是默认的 separator ,可以通过 Evaluator 初始化参数修改)。 + 如果存在多个metric,多个dataloader的情况,key的命名规则是 + metric_indicator_name#metric_name#dataloader_name + 其中 metric_indicator_name 可能不存在。 + :param num_eval_batch_per_dl: 每个 dataloader 测试多少个 batch 的数据,-1 为测试所有数据。 :return: """ assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." @@ -157,7 +158,6 @@ class Evaluator: assert self.driver.has_test_dataloaders() metric_results = {} - self.reset() evaluate_context = self.driver.get_evaluate_context() self.driver.set_model_mode(mode='eval' if self.model_use_eval_mode else 'train') diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9e1ccfbf..11697bdc 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -23,7 +23,7 @@ from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import check_fn_not_empty_params, get_fn_arg_names, match_and_substitute_params, nullcontext from fastNLP.envs import rank_zero_call -from fastNLP.core.samplers import ReproducibleIterator, ReproducibleBatchSampler +from fastNLP.core.samplers import ReproducibleSampler, RandomBatchSampler from fastNLP.core.log import logger from fastNLP.envs import FASTNLP_MODEL_FILENAME @@ -251,7 +251,7 @@ class Trainer(TrainerEventTrigger): self.driver.set_deterministic_dataloader(self.dataloader) self.dataloader = self.driver.set_dist_repro_dataloader(dataloader=self.train_dataloader, dist=_dist_sampler, - reproducible=self.callback_manager.has_trainer_chechpoint) + reproducible=self.callback_manager.has_trainer_checkpoint) self.set_grad_to_none = kwargs.get("set_grad_to_none", True) self.on_after_trainer_initialized(self.driver) @@ -291,6 +291,7 @@ class Trainer(TrainerEventTrigger): raise FileNotFoundError("You are using `resume_from`, but we can not find your specific file.") if self.evaluator is not None and num_eval_sanity_batch > 0: + logger.info(f"Running evaluator sanity check for {num_eval_sanity_batch} batches.") self.on_sanity_check_begin() sanity_check_res = self.evaluator.run(num_eval_batch_per_dl=num_eval_sanity_batch) self.on_sanity_check_end(sanity_check_res) @@ -509,7 +510,7 @@ class Trainer(TrainerEventTrigger): :param folder: 保存模型的地址; :param only_state_dict: 是否只保存模型的 `state_dict`; - :param save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; + :param model_save_fn: 用户自己定制的用来替换该保存函数本身保存逻辑的函数; :param kwargs: 一些 driver 的保存模型的函数的参数另有其它; """ @@ -534,7 +535,16 @@ class Trainer(TrainerEventTrigger): def load_model(self, folder: Union[str, Path, BinaryIO, io.BytesIO], only_state_dict: bool = False, model_load_fn: Optional[Callable] = None, **kwargs): + """ + 加载模型 + :param folder: 读取 model 的文件夹,默认会尝试读取该文件夹下的 fastnlp_model.pkl.tar 文件。在 model_load_fn 不为空时, + 直接将该 folder 传递到 model_load_fn 中。 + :param only_state_dict: 要读取的文件中是否仅包含模型权重。在 model_load_fn 不为 None 时,该参数无意义。 + :param model_load_fn: callable 的函数,接受一个 folder 作为参数,不返回任何内容。 + :param kwargs: + :return: + """ self.on_load_model() self.driver.barrier() if not isinstance(folder, (io.BytesIO, BinaryIO)): @@ -555,7 +565,13 @@ class Trainer(TrainerEventTrigger): def save(self, folder: Union[str, Path], only_state_dict: bool = True, model_save_fn: Optional[Callable] = None, **kwargs): r""" - 用于断点重训的保存函数; + 用于断点重训 Trainer 的保存函数; + + :param folder: + :param only_state_dict: + :param model_save_fn: + :param kwargs: + :return: """ self.driver.barrier() @@ -594,7 +610,7 @@ class Trainer(TrainerEventTrigger): r""" 用于断点重训的加载函数; 注意在 fastNLP 中断点重训的保存和加载逻辑是分开的,因此可能存在一种情况:用户只希望加载一个断点重训的状态,而在之后不再进行断点重训的 - 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleIterator; + 保存;在这种情况下,dataloader 的 sampler 就不一定会被替换成我们的 ReproducibleSampler; 注意我们目前不支持单卡到多卡的断点重训; diff --git a/fastNLP/core/dataloaders/torch_dataloader/fdl.py b/fastNLP/core/dataloaders/torch_dataloader/fdl.py index 0cae39ac..d56dbac9 100644 --- a/fastNLP/core/dataloaders/torch_dataloader/fdl.py +++ b/fastNLP/core/dataloaders/torch_dataloader/fdl.py @@ -24,6 +24,7 @@ class _FDataSet: 对Dataset的封装,主要是修改dataset的__getitem__函数,增加返回下标idx,值得注意的是dataset需要实现__getattribute__函数才能在_FDataset 中调用dataset的方法 """ + def __init__(self, dataset) -> None: self.dataset = dataset @@ -45,6 +46,7 @@ class TorchDataLoader(DataLoader): 提供给使用pytorch框架的DataLoader函数,若是配套使用FastNLP的dataset则可以自动使用AutoCollate函数对数据进行自动padding操作,用户也可以通过 提供的方法调节设置collate_fn的若干参数。 """ + def __init__(self, dataset, batch_size: int = 1, shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, batch_sampler: Optional["Sampler[Sequence[int]]"] = None, @@ -175,17 +177,17 @@ class TorchDataLoader(DataLoader): def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataSet], Mapping[str, DataSet]], - batch_size: int = 1, - shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, - batch_sampler: Optional["Sampler[Sequence[int]]"] = None, - num_workers: int = 0, collate_fn: Optional[Callable] = None, - pin_memory: bool = False, drop_last: bool = False, - timeout: float = 0, worker_init_fn: Optional[Callable] = None, - multiprocessing_context=None, generator=None, prefetch_factor: int = 2, - persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, - non_train_batch_size: int = 16, as_numpy: bool = False, - input_fields: Union[List, str] = None)\ - -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: + batch_size: int = 1, + shuffle: bool = False, sampler: Optional["Sampler[int]"] = None, + batch_sampler: Optional["Sampler[Sequence[int]]"] = None, + num_workers: int = 0, collate_fn: Optional[Callable] = None, + pin_memory: bool = False, drop_last: bool = False, + timeout: float = 0, worker_init_fn: Optional[Callable] = None, + multiprocessing_context=None, generator=None, prefetch_factor: int = 2, + persistent_workers: bool = False, non_train_sampler: Optional["Sampler[int]"] = None, + non_train_batch_size: int = 16, as_numpy: bool = False, + input_fields: Union[List, str, None] = None) \ + -> Union[TorchDataLoader, Dict[str, TorchDataLoader], Sequence[TorchDataLoader]]: """ 传入dataset或者data_bundle后,将其处理返回相对应的FdataLoader实例化对象 @@ -221,7 +223,8 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS multiprocessing_context=multiprocessing_context, generator=generator, prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, as_numpy=as_numpy) - dl.set_input(*input_fields) + if input_fields: + dl.set_input(*input_fields) return dl elif isinstance(ds_or_db, DataBundle): @@ -233,17 +236,21 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, - shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=non_train_sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) - dl_bundle[name].set_input(*input_fields) + if input_fields: + dl_bundle[name].set_input(*input_fields) return dl_bundle elif isinstance(ds_or_db, Sequence): @@ -269,8 +276,9 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, as_numpy=as_numpy) ) - for dl in dl_bundle: - dl.set_input(*input_fields) + if input_fields: + for dl in dl_bundle: + dl.set_input(*input_fields) return dl_bundle elif isinstance(ds_or_db, Mapping): @@ -282,18 +290,22 @@ def prepare_torch_dataloader(ds_or_db: Union[DataSet, DataBundle, Sequence[DataS num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) else: dl_bundle[name] = TorchDataLoader(dataset=ds, batch_size=non_train_batch_size, - shuffle=shuffle, sampler=non_train_sampler, batch_sampler=batch_sampler, + shuffle=shuffle, sampler=non_train_sampler, + batch_sampler=batch_sampler, num_workers=num_workers, collate_fn=collate_fn, pin_memory=pin_memory, drop_last=drop_last, timeout=timeout, worker_init_fn=worker_init_fn, multiprocessing_context=multiprocessing_context, generator=generator, - prefetch_factor=prefetch_factor, persistent_workers=persistent_workers, + prefetch_factor=prefetch_factor, + persistent_workers=persistent_workers, as_numpy=as_numpy) - dl_bundle[name].set_input(*input_fields) + if input_fields: + dl_bundle[name].set_input(*input_fields) return dl_bundle else: diff --git a/fastNLP/core/dataset/dataset.py b/fastNLP/core/dataset/dataset.py index 037fde00..9630a3a0 100644 --- a/fastNLP/core/dataset/dataset.py +++ b/fastNLP/core/dataset/dataset.py @@ -8,9 +8,8 @@ __all__ = [ import _pickle as pickle from copy import deepcopy -from typing import Optional, List, Callable, Union, Dict, Any +from typing import Optional, List, Callable, Union, Dict, Any, Mapping from functools import partial -import warnings import numpy as np from threading import Thread @@ -197,6 +196,20 @@ class DataSet: else: raise KeyError("Unrecognized type {} for idx in __getitem__ method".format(type(idx))) + def __setitem__(self, key, value): + assert isinstance(key, int) and keyList: """ # # 首先将所有的都移动到cpu上并且连续,防止有 pickle 出问题 # obj = apply_to_collection(obj, torch.Tensor, _to_device, device=torch.device('cpu')) + if device is None: + device = torch.cuda.current_device() if _TORCH_GREATER_EQUAL_1_8: objs = [None for _ in range(dist.get_world_size(group))] dist.all_gather_object(objs, obj) + objs = apply_to_collection(objs, torch.Tensor, _to_device, device=device) # 保证如果有tensor的话,所有tensor都在当前卡上 return objs - if device is None: - device = torch.cuda.current_device() group = group if group is not None else torch.distributed.group.WORLD data = convert_to_tensors(obj, device=device) data = apply_to_collection(data, (torch.Tensor, tuple), _all_gather, group=group) diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index 35b20b72..2c9c5162 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -27,7 +27,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic # world_size 和 rank if FASTNLP_BACKEND_LAUNCH in os.environ: if device is not None: - logger.warning("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " + logger.info("Parameter `device` would be ignored when you are using `torch.distributed.run` to pull " "up your script. And we will directly get the local device via " "`os.environ['LOCAL_RANK']`.") return TorchDDPDriver(model, torch.device(f"cuda:{os.environ['LOCAL_RANK']}"), True, **kwargs) @@ -39,11 +39,14 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic if isinstance(device, str): device = torch.device(device) elif isinstance(device, int): - if device < 0 and device != -1: - raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") - if device >= _could_use_device_num: + if device < 0: + if device != -1: + raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") + device = [torch.device(f"cuda:{w}") for w in range(_could_use_device_num)] + elif device >= _could_use_device_num: raise ValueError("The gpu device that parameter `device` specifies is not existed.") - device = torch.device(f"cuda:{device}") + else: + device = torch.device(f"cuda:{device}") elif isinstance(device, Sequence): device = list(set(device)) for each in device: @@ -62,7 +65,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic if not isinstance(device, List): return TorchSingleDriver(model, device, **kwargs) else: - logger.warning("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " + logger.info("Notice you are using `torch` driver but your chosen `device` are multi gpus, we will use " "`TorchDDPDriver` by default. But if you mean using `TorchDDPDriver`, you should choose parameter" "`driver` as `TorchDDPDriver`.") return TorchDDPDriver(model, device, **kwargs) diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 952712be..cf8c19a8 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -13,9 +13,8 @@ __all__ = [ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleIterator +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleSampler, re_instantiate_sampler from fastNLP.core.log import logger -from fastNLP.core.samplers import re_instantiate_sampler class TorchSingleDriver(TorchDriver): @@ -130,25 +129,31 @@ class TorchSingleDriver(TorchDriver): else: return self._test_step(batch) - def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleIterator]=None, + def set_dist_repro_dataloader(self, dataloader, dist: Union[str, RandomBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): - if isinstance(dist, ReproducibleBatchSampler): + + # 如果 dist 为 RandomBatchSampler, ReproducibleIterator 说明是在断点重训时 driver.load 函数调用; + if isinstance(dist, RandomBatchSampler): return replace_batch_sampler(dataloader, dist) - elif isinstance(dist, ReproducibleIterator): + elif isinstance(dist, ReproducibleSampler): return replace_sampler(dataloader, dist) + # 如果 dist 为 str 或者 None,说明是在 trainer 初试化时调用; + args = self.get_dataloader_args(dataloader) + if isinstance(args.batch_sampler, RandomBatchSampler): + batch_sampler = re_instantiate_sampler(args.batch_sampler) + return replace_batch_sampler(dataloader, batch_sampler) + elif isinstance(args.sampler, ReproducibleSampler): + sampler = re_instantiate_sampler(args.sampler) + return replace_sampler(dataloader, sampler) + if reproducible: - args = self.get_dataloader_args(dataloader) - if isinstance(args.sampler, ReproducibleIterator): - sampler = re_instantiate_sampler(args.sampler) - return replace_sampler(dataloader, sampler) - else: - batch_sampler = ReproducibleBatchSampler( - batch_sampler=args.batch_sampler, - batch_size=args.batch_size, - drop_last=args.drop_last - ) - return replace_batch_sampler(dataloader, batch_sampler) + batch_sampler = RandomBatchSampler( + batch_sampler=args.batch_sampler, + batch_size=args.batch_size, + drop_last=args.drop_last + ) + return replace_batch_sampler(dataloader, batch_sampler) else: return dataloader diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 96d11761..b3386f5a 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, ReproducibleIterator class TorchDriver(Driver): @@ -143,8 +143,6 @@ class TorchDriver(Driver): :param filepath: 保存到哪个文件夹; :param only_state_dict: 是否只保存权重; - :param model_save_fn: - :return: """ model = self.unwrap_model() @@ -184,10 +182,10 @@ class TorchDriver(Driver): # trainer.dataloader 来改变 dataloader 的状态,从而适配训练或者评测环境; # 1. sampler 的状态,因为我们支持 resume training,即精确恢复到具体的一个 batch; - # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `replace_sampler` 中将 dataloader 的 - # sampler 替换为 `ReproducibleIterator`;否则就是在单卡情况下将 batch_sampler 替换为 `ReproducibleBatchSampler`; + # 首先 pytorch 的 DataLoader 一定会有 sampler;另一方面,我们在断点重训的时候一定会在 `set_` 中将 dataloader 的 + # sampler 替换为 `ReproducibleSampler`;否则就是在单卡情况下将 batch_sampler 替换为 `RandomBatchSampler`; dataloader_args = self.get_dataloader_args(dataloader) - if isinstance(dataloader_args.batch_sampler, ReproducibleBatchSampler): + if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): sampler = dataloader_args.batch_sampler elif dataloader_args.sampler: sampler = dataloader_args.sampler @@ -247,25 +245,25 @@ class TorchDriver(Driver): # 3. 恢复 sampler 的状态; dataloader_args = self.get_dataloader_args(dataloader) - - sampler = dataloader_args.sampler - if not (hasattr(sampler, 'load_state_dict') and callable(sampler.load_state_dict)): - # 说明这里需要使用 ReproduceSampler 来弄一下了 - if self.is_distributed(): - raise RuntimeError( - "It is not allowed to use single device checkpoint retraining before but ddp now.") - sampler = ReproducibleBatchSampler( - batch_sampler=sampler, + if isinstance(dataloader_args.batch_sampler, RandomBatchSampler): + sampler = dataloader_args.batch_sampler + elif isinstance(dataloader_args.sampler, ReproducibleIterator): + sampler = dataloader_args.sampler + elif self.is_distributed(): + raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our " + "`RandomBatchSampler` or `ReproducibleIterator`.") + else: + sampler = RandomBatchSampler( + batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last ) sampler.load_state_dict(states['sampler_states']) - states["dataloader"] = self.set_dist_repro_dataloader(dataloader, sampler) # 4. 修改 trainer_state.batch_idx_in_epoch # sampler 是类似 RandomSampler 的sampler,不是 batch_sampler; - if not isinstance(sampler, ReproducibleBatchSampler): + if not isinstance(sampler, RandomBatchSampler): if dataloader_args.drop_last: batch_idx_in_epoch = len( sampler) // dataloader_args.batch_size - sampler.num_left_samples // dataloader_args.batch_size @@ -293,7 +291,7 @@ class TorchDriver(Driver): @staticmethod def worker_init_function(worker_id: int, rank: Optional[int] = None) -> None: # pragma: no cover - """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set set the seed + """The worker_init_fn that Lightning automatically adds to your dataloader if you previously set the seed with ``seed_everything(seed, workers=True)``. See also the PyTorch documentation on diff --git a/fastNLP/core/metrics/backend/torch_backend/backend.py b/fastNLP/core/metrics/backend/torch_backend/backend.py index 06304a98..8945ab01 100644 --- a/fastNLP/core/metrics/backend/torch_backend/backend.py +++ b/fastNLP/core/metrics/backend/torch_backend/backend.py @@ -33,7 +33,7 @@ class TorchBackend(Backend): if dist.is_initialized(): if method is None: raise AggregateMethodError(should_have_aggregate_method=True) - tensor = self._gather_all(tensor) + tensor = fastnlp_torch_all_gather(tensor) if isinstance(tensor[0], torch.Tensor): tensor = torch.stack(tensor) # 第一步, aggregate结果 @@ -68,59 +68,6 @@ class TorchBackend(Backend): def get_scalar(self, tensor) -> float: return tensor.item() - @staticmethod - def _gather_all(result, group: Optional[Any] = None) -> List: - """Function to gather all tensors from several ddp processes onto a list that is broadcasted to all processes. - Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case - tensors are padded, gathered and then trimmed to secure equal workload for all processes. - - Args: - result: the value to sync - group: the process group to gather results from. Defaults to all processes (world) - - Return: - gathered_result: list with size equal to the process group where - gathered_result[i] corresponds to result tensor from process i - """ - - if group is None: - group = dist.group.WORLD - - # convert tensors to contiguous format - result = result.contiguous() - - world_size = dist.get_world_size(group) - dist.barrier(group=group) - - # if the tensor is scalar, things are easy - if result.ndim == 0: - return _simple_gather_all_tensors(result, group, world_size) - - # 1. Gather sizes of all tensors - local_size = torch.tensor(result.shape, device=result.device) - local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)] - dist.all_gather(local_sizes, local_size, group=group) - max_size = torch.stack(local_sizes).max(dim=0).values - all_sizes_equal = all(all(ls == max_size) for ls in local_sizes) - - # 2. If shapes are all the same, then do a simple gather: - if all_sizes_equal: - return _simple_gather_all_tensors(result, group, world_size) - - # 3. If not, we need to pad each local tensor to maximum size, gather and then truncate - pad_dims = [] - pad_by = (max_size - local_size).detach().cpu() - for val in reversed(pad_by): - pad_dims.append(0) - pad_dims.append(val.item()) - result_padded = torch.nn.functional.pad(result, pad_dims) - gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)] - dist.all_gather(gathered_result, result_padded, group) - for idx, item_size in enumerate(local_sizes): - slice_param = [slice(dim_size) for dim_size in item_size] - gathered_result[idx] = gathered_result[idx][slice_param] - return gathered_result - def tensor2numpy(self, tensor) -> np.array: """ 将对应的tensor转为numpy对象 diff --git a/fastNLP/core/metrics/element.py b/fastNLP/core/metrics/element.py index b3a496bf..22ba2635 100644 --- a/fastNLP/core/metrics/element.py +++ b/fastNLP/core/metrics/element.py @@ -11,12 +11,12 @@ from fastNLP.envs.env import FASTNLP_GLOBAL_RANK class Element: - def __init__(self, value: float, aggregate_method, backend: Backend, name=None): + def __init__(self, name, value: float, aggregate_method, backend: Backend): + self.name = name self.init_value = value self.aggregate_method = aggregate_method - self.name = name if backend == 'auto': - raise RuntimeError("You have to specify the backend.") + raise RuntimeError(f"You have to specify the backend for Element:{self.name}.") elif isinstance(backend, AutoBackend): self.backend = backend else: @@ -34,20 +34,16 @@ class Element: 自动aggregate对应的元素 """ + self._check_value_initialized() try: self._value = self.backend.aggregate(self._value, self.aggregate_method) except AggregateMethodError as e: msg = 'If you see this message, please report a bug.' if self.name and e.should_have_aggregate_method: msg = f"Element:{self.name} has no specified `aggregate_method`." - elif e.should_have_aggregate_method: - msg = "Element has no specified `aggregate_method`." elif self.name and not e.should_have_aggregate_method: msg = f"Element:{self.name}'s backend:{self.backend.__class__.__name__} does not support " \ f'aggregate_method:{self.aggregate_method}.' - elif not e.should_have_aggregate_method: - msg = f"Element's backend:{self.backend.__class__.__name__} does not support " \ - f'aggregate_method:{self.aggregate_method}.' if e.only_warn: if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: logger.warning(msg) @@ -74,6 +70,7 @@ class Element: return self._value def get_scalar(self) -> float: + self._check_value_initialized() return self.backend.get_scalar(self._value) def fill_value(self, value): @@ -95,7 +92,7 @@ class Element: def _check_value_when_call(self): if self.value is None: - prefix = f'Element:`{self.name}`' if self.name else 'Element' + prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " "element, or use it after it being used by the `Metric.compute()` method.") @@ -273,9 +270,10 @@ class Element: """ try: if self._value is None: - prefix = f'Element:`{self.name}`' if self.name else 'Element' + prefix = f'Element:`{self.name}`' raise RuntimeError(prefix + " is not initialized. Please either specify backend when creating this " "element, or use it after it being used by the `Metric.compute()` method.") return getattr(self._value, item) except AttributeError as e: + logger.error(f"Element:{self.name} has no `{item}` attribute.") raise e diff --git a/fastNLP/core/metrics/metric.py b/fastNLP/core/metrics/metric.py index 097671da..2fb575fc 100644 --- a/fastNLP/core/metrics/metric.py +++ b/fastNLP/core/metrics/metric.py @@ -35,7 +35,7 @@ class Metric: def elements(self) -> dict: return self._elements - def register_element(self, name=None, value: float = 0, aggregate_method=None, backend='auto') -> Element: + def register_element(self, name, value: float = 0, aggregate_method=None, backend='auto') -> Element: """ 注册一个 element 对象,注册之后便可以通过在 Metric 中直接通过 self.{name} 进行调用,可以认为该对象即为对应 backend 的 tensor 直接进行加减乘除计算即可。 @@ -57,11 +57,9 @@ class Metric: else: backend = AutoBackend(backend) - # 当name为None,默认为变量取得变量名 - if name is None: - name = f'ele_var_{len(self._elements)}' + assert name is not None and name not in self.elements - element = Element(value=value, aggregate_method=aggregate_method, backend=backend, name=name) + element = Element(name=name, value=value, aggregate_method=aggregate_method, backend=backend) self.elements[name] = element setattr(self, name, element) return element diff --git a/fastNLP/core/metrics/span_f1_pre_rec_metric.py b/fastNLP/core/metrics/span_f1_pre_rec_metric.py index 45b412c8..716cea30 100644 --- a/fastNLP/core/metrics/span_f1_pre_rec_metric.py +++ b/fastNLP/core/metrics/span_f1_pre_rec_metric.py @@ -216,9 +216,26 @@ def _compute_f_pre_rec(beta_square, tp, fn, fp): class SpanFPreRecMetric(Metric): - def __init__(self, backend: Union[str, Backend, None] = 'auto', tag_vocab: Vocabulary = None, - encoding_type: str = None, ignore_labels: List[str] = None, only_gross: bool = True, f_type='micro', - beta=1, aggregate_when_get_metric: bool = True,) -> None: + def __init__(self, tag_vocab: Vocabulary, encoding_type: str = None, ignore_labels: List[str] = None, + only_gross: bool = True, f_type='micro', + beta=1, backend: Union[str, Backend, None] = 'auto', aggregate_when_get_metric: bool = True,) -> None: + r""" + + :param tag_vocab: 标签的 :class:`~fastNLP.Vocabulary` 。支持的标签为"B"(没有label);或"B-xxx"(xxx为某种label,比如POS中的NN), + 在解码时,会将相同xxx的认为是同一个label,比如['B-NN', 'E-NN']会被合并为一个'NN'. + :param str pred: 用该key在evaluate()时从传入dict中取出prediction数据。 为None,则使用 `pred` 取数据 + :param str target: 用该key在evaluate()时从传入dict中取出target数据。 为None,则使用 `target` 取数据 + :param str seq_len: 用该key在evaluate()时从传入dict中取出sequence length数据。为None,则使用 `seq_len` 取数据。 + :param str encoding_type: 目前支持bio, bmes, bmeso, bioes。默认为None,通过tag_vocab自动判断. + :param list ignore_labels: str 组成的list. 这个list中的class不会被用于计算。例如在POS tagging时传入['NN'],则不会计算'NN'个label + :param bool only_gross: 是否只计算总的f1, precision, recall的值;如果为False,不仅返回总的f1, pre, rec, 还会返回每个label的f1, pre, rec + :param str f_type: `micro` 或 `macro` . `micro` :通过先计算总体的TP,FN和FP的数量,再计算f, precision, recall; `macro` : 分布计算每个类别的f, precision, recall,然后做平均(各类别f的权重相同) + :param float beta: f_beta分数, :math:`f_{beta} = \frac{(1 + {beta}^{2})*(pre*rec)}{({beta}^{2}*pre + rec)}` . 常用为 `beta=0.5, 1, 2` 若为0.5则精确率的权重高于召回率;若为1,则两者平等;若为2,则召回率权重高于精确率。 + :param str backend: 目前支持四种类型的backend, ['auto', 'torch', 'paddle', 'jittor']。其中 auto 表示根据实际调用 Metric.update() + 函数时传入的参数决定具体的 backend ,一般情况下直接使用 'auto' 即可。 + :param bool aggregate_when_get_metric: 在计算 metric 的时候是否自动将各个进程上的相同的 element 的数字聚合后再得到metric, + 当 backend 不支持分布式时,该参数无意义。 + """ super(SpanFPreRecMetric, self).__init__(backend=backend, aggregate_when_get_metric=aggregate_when_get_metric) if f_type not in ('micro', 'macro'): raise ValueError("f_type only supports `micro` or `macro`', got {}.".format(f_type)) @@ -249,16 +266,25 @@ class SpanFPreRecMetric(Metric): self.only_gross = only_gross self.tag_vocab = tag_vocab - self._true_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) - self._false_positives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) - self._false_negatives = defaultdict(partial(self.register_element, aggregate_method='sum', name=None)) + self._true_positives = {} + self._false_positives = {} + self._false_negatives = {} + for word, _ in tag_vocab: + word = word.lower() + if word != 'o': + word = word[2:] + if word in self._true_positives: + continue + self._true_positives[word] = self.register_element(name=f'tp_{word}', aggregate_method='sum', backend=backend) + self._false_negatives[word] = self.register_element(name=f'fn_{word}', aggregate_method='sum', backend=backend) + self._false_positives[word] = self.register_element(name=f'fp_{word}', aggregate_method='sum', backend=backend) def get_metric(self) -> dict: evaluate_result = {} if not self.only_gross or self.f_type == 'macro': tags = set(self._false_negatives.keys()) - tags.update(set(self._false_positives.keys())) - tags.update(set(self._true_positives.keys())) + tags.update(self._false_positives.keys()) + tags.update(self._true_positives.keys()) f_sum = 0 pre_sum = 0 rec_sum = 0 @@ -266,6 +292,9 @@ class SpanFPreRecMetric(Metric): tp = self._true_positives[tag].get_scalar() fn = self._false_negatives[tag].get_scalar() fp = self._false_positives[tag].get_scalar() + if tp == fn == fp == 0: + continue + f, pre, rec = _compute_f_pre_rec(self.beta_square, tp, fn, fp) f_sum += f pre_sum += pre @@ -284,10 +313,17 @@ class SpanFPreRecMetric(Metric): evaluate_result['rec'] = rec_sum / len(tags) if self.f_type == 'micro': + tp, fn, fp = [], [], [] + for val in self._true_positives.values(): + tp.append(val.get_scalar()) + for val in self._false_negatives.values(): + fn.append(val.get_scalar()) + for val in self._false_positives.values(): + fp.append(val.get_scalar()) f, pre, rec = _compute_f_pre_rec(self.beta_square, - sum(val.get_scalar() for val in self._true_positives.values()), - sum(val.get_scalar() for val in self._false_negatives.values()), - sum(val.get_scalar() for val in self._false_positives.values())) + sum(tp), + sum(fn), + sum(fp)) evaluate_result['f'] = f evaluate_result['pre'] = pre evaluate_result['rec'] = rec diff --git a/fastNLP/core/samplers/__init__.py b/fastNLP/core/samplers/__init__.py index 68928b66..3d6813f7 100644 --- a/fastNLP/core/samplers/__init__.py +++ b/fastNLP/core/samplers/__init__.py @@ -3,19 +3,30 @@ __all__ = [ 'SortedSampler', 'ConstTokenNumSampler', 'ConstantTokenNumSampler', - 'UnrepeatedDistributedSampler', + 'MixSampler', - 'InnerSampler', 'DopedSampler', 'MixSequentialSampler', 'PollingSampler', - 'ReproducibleIterator', + + 'ReproducibleSampler', 'RandomSampler', - 're_instantiate_sampler' + "SequentialSampler", + "SortedSampler", + + 'UnrepeatedSampler', + 'UnrepeatedRandomSampler', + "UnrepeatedSortedSampler", + "UnrepeatedSequentialSampler", + + "re_instantiate_sampler", + "conversion_between_reproducible_and_unrepeated_sampler" ] -from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler, UnrepeatedDistributedSampler -from .mix_sampler import MixSampler, InnerSampler, DopedSampler, MixSequentialSampler, PollingSampler -from .reproducible_sampler import ReproducibleIterator, RandomSampler, re_instantiate_sampler -from .reproducible_batch_sampler import ReproducibleBatchSampler, BucketedBatchSampler +from .sampler import BucketSampler, SortedSampler, ConstTokenNumSampler, ConstantTokenNumSampler +from .unrepeated_sampler import UnrepeatedSampler, UnrepeatedRandomSampler, UnrepeatedSortedSampler, UnrepeatedSequentialSampler +from .mix_sampler import MixSampler, DopedSampler, MixSequentialSampler, PollingSampler +from .reproducible_sampler import ReproducibleSampler, RandomSampler, SequentialSampler, SortedSampler +from .utils import re_instantiate_sampler, conversion_between_reproducible_and_unrepeated_sampler +from .reproducible_batch_sampler import RandomBatchSampler, BucketedBatchSampler diff --git a/fastNLP/core/samplers/mix_sampler.py b/fastNLP/core/samplers/mix_sampler.py index e219b6e2..f53c06a5 100644 --- a/fastNLP/core/samplers/mix_sampler.py +++ b/fastNLP/core/samplers/mix_sampler.py @@ -4,7 +4,6 @@ from typing import Union, List, Iterable, Dict __all__ = [ 'MixSampler', - 'InnerSampler', 'DopedSampler', 'MixSequentialSampler', 'PollingSampler' diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index 3e39aca5..5a25110b 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -1,6 +1,6 @@ __all__ = [ 'BucketedBatchSampler', - "ReproducibleBatchSampler" + "RandomBatchSampler" ] import math @@ -16,7 +16,7 @@ from fastNLP.core.log import logger from abc import abstractmethod -class ReproducibleBatchIterator: +class ReproducibleBatchSampler: @abstractmethod def set_distributed(self, num_replicas, rank, pad=True): raise NotImplementedError("Each specific batch_sampler should implement its own `set_distributed` method.") @@ -42,13 +42,13 @@ class ReproducibleBatchIterator: pass -class ReproducibleBatchSampler(ReproducibleBatchIterator): +class RandomBatchSampler(ReproducibleBatchSampler): # 这两个参数的值应当交给 driver 的 get_dataloader_args 函数去拿; def __init__(self, batch_sampler, batch_size: int, drop_last: bool, **kwargs): """ 可以使得 batch_sampler 对象状态恢复的 wrapper 。 - :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。ReproducibleBatchSampler 将首先遍历一边该对象,然后将迭代 + :param batch_sampler: 可迭代出 数字 或 数字列表 的可迭代对象。RandomBatchSampler 将首先遍历一边该对象,然后将迭代 出来的序号暂存起来,使用时按照 batch_size 的 batch 大小吐出序号列表。 :param batch_size: 每个 batch 的大小是多少。 :param drop_last: 如果最后一个 batch 无法构成 batch_size 那么多个 sample ,是否丢掉。 @@ -138,7 +138,7 @@ class ReproducibleBatchSampler(ReproducibleBatchIterator): (len(self.index_list) - self.data_idx + self.batch_size - 1) // self.batch_size -class BucketedBatchSampler(ReproducibleBatchIterator): +class BucketedBatchSampler(ReproducibleBatchSampler): def __init__(self, dataset, length: Union[List[int], str], batch_size:int = 32, num_batch_per_bucket:int = 10, shuffle: bool = True, drop_last: bool = False, seed: int = 0, **kwargs): """ diff --git a/fastNLP/core/samplers/reproducible_sampler.py b/fastNLP/core/samplers/reproducible_sampler.py index 0a4ac7bf..1dc226a5 100644 --- a/fastNLP/core/samplers/reproducible_sampler.py +++ b/fastNLP/core/samplers/reproducible_sampler.py @@ -1,25 +1,21 @@ -from typing import Dict, List +from typing import Dict, List, Union import math import numpy as np from fastNLP.core.log import logger +from fastNLP.core.dataset import DataSet __all__ = [ - 'ReproducibleIterator', + 'ReproducibleSampler', 'RandomSampler', - 're_instantiate_sampler' + "SortedSampler", + "SequentialSampler" ] -def re_instantiate_sampler(sampler): - all_attributes = vars(sampler) - return type(sampler)(**all_attributes) - - - -class ReproducibleIterator: +class ReproducibleSampler: """ - 注意所有继承 `ReproducibleIterator` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler + 注意所有继承 `ReproducibleSampler` 的类的 `__init__` 方法中都需要加入参数 `**kwargs`,用来使我们再断点重训时重新实例化这个 sampler 或者 batch_sampler;注意,所有在 init 中初始化的变量,都不能含有 _ 下横线作为开头;所有不在 init 中设置的变量都必须以下横线开头。 """ @@ -47,7 +43,7 @@ class ReproducibleIterator: pass -class RandomSampler(ReproducibleIterator): +class RandomSampler(ReproducibleSampler): def __init__(self, dataset, shuffle: bool = True, seed: int = 0, **kwargs): """ @@ -157,8 +153,8 @@ class RandomSampler(ReproducibleIterator): f"we cannot use {self.__class__.__name__} to load it." length = states['length'] - assert length == len(self.dataset), "The number of samples is different between the checkpoint record " \ - "and current dataset." + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." self.seed = states['seed'] self.epoch = states['epoch'] self.num_consumed_samples = states['num_consumed_samples'] @@ -215,9 +211,132 @@ class RandomSampler(ReproducibleIterator): self.pad else math.floor(((len(self.dataset) - num_consumed_samples) / self.num_replicas)) +class SequentialSampler(RandomSampler): + def __init__(self, dataset, dist_mode:str='interval', **kwargs): + """ + 按照顺序读取 dataset 。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super().__init__(dataset=dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + + if self.pad: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + + assert len(indices) == self.total_size + + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 + + def generate_indices(self) -> List[int]: + """ + 生成随机序列 + + :return: + """ + return list(range(len(self.dataset))) + + def state_dict(self) -> Dict: + states = { + 'num_consumed_samples': self.num_consumed_samples, # 注意该值是计算所有 rank 上训练的所有数据; + 'sampler_type': self.__class__.__name__, + 'length': len(self.dataset), + } + return states + + def load_state_dict(self, states: Dict): + # 如果 self.during_iter 是 True,那么 data_idx 一定是 0; + assert self.during_iter is False, "Cannot call load_state_dict() when it is " \ + "during an unfinished iteration." + + assert states['sampler_type'] == self.__class__.__name__, f"The sampler type in checkpoint is {states['sampler_type']}," \ + f"we cannot use {self.__class__.__name__} to load it." + + length = states['length'] + assert length == len(self.dataset), f"The number of samples is different between the checkpoint record({length}) " \ + f"and current dataset({len(self.dataset)})." + self.num_consumed_samples = states['num_consumed_samples'] + if self.num_consumed_samples >= length: # 如果保存的时候已经到达了最后一个sample了,则直接将结果重置为0 + self.num_consumed_samples = 0 + + +class SortedSampler(SequentialSampler): + def __init__(self, dataset, length:Union[str, List], **kwargs): + """ + 将 dataset 中的数据根据 length 从长到短进行迭代。在多卡情况下,由于padding 最后一个 sample 可能是最长的那个 sample。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param length: 如果为 List,应当与 dataset 有一样的长度,表示 dataset 中每个元素的数量;仅当传入的 dataset 为 fastNLP 的 + DataSet 时支持传入 str,会将该str理解为 dataset 的 field 名称,若 field 中的元素为 int,则认为该值是 sample 的长度。 + :param seed: 设置的随机数种子 + :param kwargs: fastNLP 保留使用 + """ + super().__init__(dataset=dataset, **kwargs) + if isinstance(dataset, DataSet): + length = dataset.get_field(length) + if not isinstance(length[0], int): + length = list(map(len, length)) + else: + assert len(length) == len(dataset), "When the dataset is not fastNLP.DataSet, " \ + "the length parameter can only be List[int]" + + assert len(length) == len(dataset), "The length of `data` and `length` should be equal." + + self.length = np.array(length, dtype=int) # 按照长到短排列的序号。 + self.sorted_indices = np.argsort(self.length)[::-1].tolist() # 按长度从高到低排序的 + + def generate_indices(self) -> List[int]: + return self.sorted_indices + def __iter__(self): + if self.during_iter: # 如果发现_during_iter为True,说明之前的还没结束,只有强制重新初始化了 + self.num_consumed_samples = 0 + self.during_iter = True + indices = self.generate_indices() + if self.pad: + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[:self.total_size] + assert len(indices) == self.total_size + # subsample + indices = indices[self.num_consumed_samples:] + indices = indices[self.rank:len(indices):self.num_replicas] + assert len(indices) == self.num_left_samples + for index in indices: + self.num_consumed_samples += self.num_replicas + yield index + self.during_iter = False + self.num_consumed_samples = 0 diff --git a/fastNLP/core/samplers/sampler.py b/fastNLP/core/samplers/sampler.py index e41472bf..89751884 100644 --- a/fastNLP/core/samplers/sampler.py +++ b/fastNLP/core/samplers/sampler.py @@ -7,7 +7,6 @@ __all__ = [ "SortedSampler", 'ConstTokenNumSampler', "ConstantTokenNumSampler", - "UnrepeatedDistributedSampler", ] from itertools import chain @@ -18,7 +17,7 @@ import numpy as np from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: - from torch.utils.data import SequentialSampler, Sampler, RandomSampler + from torch.utils.data import Sampler else: from fastNLP.core.utils.dummy_class import DummyClass as Sampler @@ -727,87 +726,3 @@ def k_means_bucketing(lengths, buckets): if buckets[bucket_id] is None or lengths[idx] <= buckets[bucket_id]: bucket_data[bucket_id].append(idx) return bucket_data - - -class UnrepeatedDistributedSampler: - def __init__(self, dataset, shuffle: bool = False, seed: int = 0): - """ - 考虑在多卡evaluate的场景下,不能重复sample。 - - :param dataset: - :param shuffle: - :param seed: - """ - self.dataset = dataset - self.shuffle = shuffle - self.seed = seed - - # 多卡的相关的参数 - self.num_replicas = 1 - self.rank = 0 - self.epoch = -1 - - def __len__(self): - """ - 返回 sampler 一次完整的迭代过程会产生多少个index。多卡的情况下,只考虑当前rank; - :return: - """ - num_common = len(self.dataset)//self.num_replicas - self.num_samples = num_common + int(self.rank < (len(self.dataset)-num_common*self.num_replicas)) - return self.num_samples - - def __iter__(self): - r""" - 当前使用num_consumed_samples做法会在交替使用的时候遇到问题; - Example: - >>> sampler = RandomSampler() - >>> iter1 = iter(sampler) - >>> iter2 = iter(sampler) - >>> next(iter1) - >>> next(iter2) # 当前num_consumed_samples的数量会发生变化 - """ - - indices = self.generate_indices() - - # subsample - indices = indices[self.rank:len(indices):self.num_replicas] - assert len(indices) == len(self) - - for index in indices: - yield index - - def generate_indices(self) -> List[int]: - """ - 生成随机序列 - - :return: - """ - if self.shuffle: - indices = list(range(len(self.dataset))) - seed = self.seed + self.epoch - rng = np.random.default_rng(abs(seed)) - rng.shuffle(indices) - if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 - self.epoch -= 1 - else: - indices = list(range(len(self.dataset))) - return indices - - def set_epoch(self, epoch: int) -> None: - self.epoch = epoch - - def set_distributed(self, num_replicas, rank): - """ - 该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; - - :param num_replicas: - :param rank: - :return: - """ - assert num_replicas>0 and isinstance(num_replicas, int) - assert isinstance(rank, int) and 0<=rank List[int]: + """ + 生成随机序列 + + :return: + """ + if self.shuffle: + indices = list(range(len(self.dataset))) + seed = self.seed + self.epoch + rng = np.random.default_rng(abs(seed)) + rng.shuffle(indices) + if self.epoch < 0: # 防止用户忘记调用 set_epoch,至少这样可以保证每次epoch出来的index顺序不同。 + self.epoch -= 1 + else: + indices = list(range(len(self.dataset))) + return indices + + def set_epoch(self, epoch: int) -> None: + self.epoch = epoch + + def set_distributed(self, num_replicas, rank): + """ + 该方法本质上等同于 ddp 情形下的没有完成的初始化,应当在初始化该 sampler 本身后立即被调用; + + :param num_replicas: + :param rank: + :return: + """ + assert num_replicas>0 and isinstance(num_replicas, int) + assert isinstance(rank, int) and 0<=rank List[int]: + return self.sorted_indices + + +class UnrepeatedSequentialSampler(UnrepeatedRandomSampler): + def __init__(self, dataset, **kwargs): + """ + 按照顺序读取 dataset。在多卡情况下,间隔读取,例如,在两卡情况下,卡0取 [0,2,4,..], 卡1取 [1,3,5...]。 + + :param dataset: 实现了 __len__ 方法的数据容器。 + :param kwargs: + """ + super(UnrepeatedSequentialSampler, self).__init__(dataset, shuffle=False, seed=0, **kwargs) + + def __iter__(self): + indices = self.generate_indices() + indices = indices[self.rank:len(indices):self.num_replicas] + for index in indices: + yield index + + def generate_indices(self) -> List[int]: + return list(range(len(self.dataset))) + diff --git a/fastNLP/core/samplers/utils.py b/fastNLP/core/samplers/utils.py new file mode 100644 index 00000000..dd90fe7c --- /dev/null +++ b/fastNLP/core/samplers/utils.py @@ -0,0 +1,42 @@ +__all__ = [ + 're_instantiate_sampler', + 'conversion_between_reproducible_and_unrepeated_sampler' +] + +from fastNLP.core.samplers.unrepeated_sampler import * +from fastNLP.core.samplers.reproducible_sampler import * + + +def conversion_between_reproducible_and_unrepeated_sampler(sampler): + """ + 将 sampler 替换成其对应的 reproducible 版本或 unrepeated 版本。如果输入是 UnrepeatedSampler 但是没找到对应的 + ReproducibleSampler, + + :param sampler: + :return: + """ + assert isinstance(sampler, UnrepeatedSampler) or isinstance(sampler, ReproducibleSampler), \ + "The sampler must be UnrepeatedSampler or ReproducibleSampler" + if isinstance(sampler, UnrepeatedSampler): + if isinstance(sampler, UnrepeatedRandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=RandomSampler) + elif isinstance(sampler, UnrepeatedSequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SequentialSampler) + elif isinstance(sampler, UnrepeatedSortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=SortedSampler) + raise TypeError(f"{sampler.__class__} has no unrepeated version.") + else: + if isinstance(sampler, RandomSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedRandomSampler) + elif isinstance(sampler, SequentialSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSequentialSampler) + elif isinstance(sampler, SortedSampler): + return re_instantiate_sampler(sampler, new_sampler_class=UnrepeatedSortedSampler) + raise TypeError(f"{sampler.__class__} has no reproducible version.") + + +def re_instantiate_sampler(sampler, new_sampler_class=None): + all_attributes = vars(sampler) + if new_sampler_class is not None: + return new_sampler_class(**all_attributes) + return type(sampler)(**all_attributes) \ No newline at end of file diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 20330d02..256cc906 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -96,6 +96,7 @@ class FRichProgress(Progress, metaclass=Singleton): # start new self.start() + self.console.show_cursor(show=True) return self def set_transient(self, transient: bool = True): @@ -149,6 +150,9 @@ class FRichProgress(Progress, metaclass=Singleton): super().stop_task(task_id) super().remove_task(task_id) + def start(self) -> None: + super().start() + self.console.show_cursor(show=True) if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( @@ -161,7 +165,7 @@ if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: TextColumn("{task.fields[post_desc]}", justify="right"), transient=True, disable=False, - speed_estimate_period=10 + speed_estimate_period=1 ) else: f_rich_progress = DummyFRichProgress() diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 73267e7f..1a7e0ee5 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -44,6 +44,9 @@ __all__ = [ ] + + + def get_fn_arg_names(fn: Callable) -> List[str]: r""" 返回一个函数的所有参数的名字; diff --git a/fastNLP/envs/set_backend.py b/fastNLP/envs/set_backend.py index a9e82c74..6da62334 100644 --- a/fastNLP/envs/set_backend.py +++ b/fastNLP/envs/set_backend.py @@ -153,7 +153,7 @@ def seed_jittor_global_seed(global_seed): pass -def dump_fastnlp_backend(default:bool = False): +def dump_fastnlp_backend(default:bool = False, backend=None): """ 将 fastNLP 的设置写入到 ~/.fastNLP/envs/ 文件夹下, 若 default 为 True,则保存的文件为 ~/.fastNLP/envs/default.json 。 @@ -165,6 +165,7 @@ def dump_fastnlp_backend(default:bool = False): 会保存的环境变量为 FASTNLP_BACKEND 。 :param default: + :param backend: 保存使用的 backend 为哪个值,允许的值有 ['torch', 'paddle', 'jittor']。如果为 None ,则使用环境变量中的值。 :return: """ if int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: @@ -179,10 +180,16 @@ def dump_fastnlp_backend(default:bool = False): os.makedirs(os.path.dirname(env_path), exist_ok=True) envs = {} - if FASTNLP_BACKEND in os.environ: - envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] + assert backend in SUPPORT_BACKENDS, f"fastNLP only supports {SUPPORT_BACKENDS} right now." + if backend is None: + if FASTNLP_BACKEND in os.environ: + envs[FASTNLP_BACKEND] = os.environ[FASTNLP_BACKEND] + else: + envs[FASTNLP_BACKEND] = backend if len(envs): with open(env_path, 'w', encoding='utf8') as f: json.dump(fp=f, obj=envs) print(f"Writing the default fastNLP backend:{envs[FASTNLP_BACKEND]} to {env_path}.") + else: + raise RuntimeError("No backend specified.") \ No newline at end of file diff --git a/fastNLP/envs/set_env_on_import.py b/fastNLP/envs/set_env_on_import.py index f94bef50..8b5f6394 100644 --- a/fastNLP/envs/set_env_on_import.py +++ b/fastNLP/envs/set_env_on_import.py @@ -47,7 +47,8 @@ def set_env_on_import_paddle(): # TODO jittor may need set this def set_env_on_import_jittor(): # todo 需要设置 FASTNLP_GLOBAL_RANK 和 FASTNLP_BACKEND_LAUNCH - pass + if 'log_silent' not in os.environ: + os.environ['log_silent'] = '1' def set_env_on_import(): @@ -63,7 +64,7 @@ def set_env_on_import(): # fastNLP 内部使用的一些变量 if FASTNLP_LAUNCH_TIME not in os.environ: - cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%M_%f')}" + cur_time = f"{datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S_%f')}" os.environ[FASTNLP_LAUNCH_TIME] = cur_time # 设置对应的值 diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index f7cc6e5f..1f404bb8 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -8,7 +8,7 @@ import torch.distributed as dist from pathlib import Path import re -from fastNLP.core.callbacks.checkpoint_callback import CheckpointCallback +from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from fastNLP.core.controllers.trainer import Trainer from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK @@ -80,16 +80,23 @@ def test_model_checkpoint_callback_1( version, only_state_dict ): +# def test_model_checkpoint_callback_1( +# model_and_optimizers: TrainerParameters, +# driver='torch_ddp', +# device=[0, 1], +# version=1, +# only_state_dict=True +# ): path = Path.cwd().joinpath(f"test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) if version == 0: callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc", save_folder=path, save_every_n_epochs=1, - save_every_n_global_batches=123, # 避免和 epoch 的保存重复; + save_every_n_batches=123, # 避免和 epoch 的保存重复; save_topk=None, save_last=False, save_on_exception=None, @@ -98,11 +105,11 @@ def test_model_checkpoint_callback_1( ] elif version == 1: callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc", save_folder=path, save_every_n_epochs=3, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=2, save_last=True, save_on_exception=None, @@ -121,7 +128,6 @@ def test_model_checkpoint_callback_1( input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, - n_epochs=10, callbacks=callbacks, output_from_new_proc="all" @@ -134,31 +140,31 @@ def test_model_checkpoint_callback_1( if version == 0: if driver == "torch": - assert "epoch_10-global_batch_250-acc" in all_saved_model_paths - assert "epoch_4-global_batch_123-acc" in all_saved_model_paths + assert "model-epoch_10" in all_saved_model_paths + assert "model-epoch_4-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_10-global_batch_250-acc"] - step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["model-epoch_10"] + step_save_path = all_saved_model_paths["model-epoch_4-batch_123"] assert len(all_saved_model_paths) == 12 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_6-global_batch_78-acc" in all_saved_model_paths - assert "epoch_9-global_batch_123-acc" in all_saved_model_paths + assert "model-epoch_6" in all_saved_model_paths + assert "model-epoch_9-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_6-global_batch_78-acc"] - step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["model-epoch_6"] + step_save_path = all_saved_model_paths["model-epoch_9-batch_123"] assert len(all_saved_model_paths) == 11 all_state_dicts = [epoch_save_path, step_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("model-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") if driver == "torch": - assert "epoch_9-global_batch_225-acc" in all_saved_model_paths - assert "last" in all_saved_model_paths + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -166,15 +172,15 @@ def test_model_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - epoch_save_path = all_saved_model_paths["epoch_9-global_batch_225-acc"] - last_save_path = all_saved_model_paths["last"] + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 6 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_9-global_batch_117-acc" in all_saved_model_paths - assert "last" in all_saved_model_paths + assert "model-epoch_9" in all_saved_model_paths + assert "model-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -183,8 +189,8 @@ def test_model_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - epoch_save_path = all_saved_model_paths["epoch_9-global_batch_117-acc"] - last_save_path = all_saved_model_paths["last"] + epoch_save_path = all_saved_model_paths["model-epoch_9"] + last_save_path = all_saved_model_paths["model-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 6 @@ -212,7 +218,7 @@ def test_model_checkpoint_callback_1( finally: synchronize_safe_rm(path) - # pass + pass if dist.is_initialized(): dist.destroy_process_group() @@ -238,11 +244,11 @@ def test_model_checkpoint_callback_2( raise NotImplementedError callbacks = [ - CheckpointCallback( + ModelCheckpointCallback( monitor="acc1", save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=None, save_last=False, save_on_exception=NotImplementedError, @@ -279,12 +285,12 @@ def test_model_checkpoint_callback_2( all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "epoch_4-global_batch_100-acc_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["epoch_4-global_batch_100-acc_NotImplementedError"] + assert "model-epoch_4-batch_100-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_100-exception_NotImplementedError"] # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_4-global_batch_52-acc_NotImplementedError" in all_saved_model_paths - exception_model_path = all_saved_model_paths["epoch_4-global_batch_52-acc_NotImplementedError"] + assert "model-epoch_4-batch_52-exception_NotImplementedError" in all_saved_model_paths + exception_model_path = all_saved_model_paths["model-epoch_4-batch_52-exception_NotImplementedError"] assert len(all_saved_model_paths) == 1 all_state_dicts = [exception_model_path] @@ -332,12 +338,11 @@ def test_trainer_checkpoint_callback_1( if version == 0: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=7, - save_every_n_global_batches=123, # 避免和 epoch 的保存重复; + save_every_n_batches=123, # 避免和 epoch 的保存重复; save_topk=None, save_last=False, save_on_exception=None, @@ -346,12 +351,11 @@ def test_trainer_checkpoint_callback_1( ] elif version == 1: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=2, save_last=True, save_on_exception=None, @@ -383,31 +387,31 @@ def test_trainer_checkpoint_callback_1( if version == 0: if driver == "torch": - assert "epoch_7-global_batch_175-acc" in all_saved_model_paths - assert "epoch_4-global_batch_123-acc" in all_saved_model_paths + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_4-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_7-global_batch_175-acc"] - step_save_path = all_saved_model_paths["epoch_4-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_4-batch_123"] assert len(all_saved_model_paths) == 3 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_7-global_batch_91-acc" in all_saved_model_paths - assert "epoch_9-global_batch_123-acc" in all_saved_model_paths + assert "trainer-epoch_7" in all_saved_model_paths + assert "trainer-epoch_9-batch_123" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_7-global_batch_91-acc"] - step_save_path = all_saved_model_paths["epoch_9-global_batch_123-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_7"] + step_save_path = all_saved_model_paths["trainer-epoch_9-batch_123"] assert len(all_saved_model_paths) == 2 all_state_dicts = [epoch_save_path, step_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -415,13 +419,13 @@ def test_trainer_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 3 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -430,7 +434,7 @@ def test_trainer_checkpoint_callback_1( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 2 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 3 @@ -474,10 +478,11 @@ def test_trainer_checkpoint_callback_2( device, version ): + pytest.skip("Skip transformers test for now.") path = Path.cwd().joinpath(f"test_model_checkpoint") path.mkdir(exist_ok=True, parents=True) - import transformers + import transformers # 版本4.16.2 import torch from torchmetrics import Accuracy from transformers import AutoModelForSequenceClassification @@ -587,12 +592,11 @@ def test_trainer_checkpoint_callback_2( if version == 0: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=50, + save_every_n_batches=50, save_topk=None, save_last=False, save_on_exception=None, @@ -601,12 +605,11 @@ def test_trainer_checkpoint_callback_2( ] elif version == 1: callbacks = [ - CheckpointCallback( + TrainerCheckpointCallback( monitor="acc", - is_trainer_checkpoint=True, save_folder=path, save_every_n_epochs=None, - save_every_n_global_batches=None, + save_every_n_batches=None, save_topk=1, save_last=True, save_on_exception=None, @@ -638,27 +641,27 @@ def test_trainer_checkpoint_callback_2( if version == 0: if driver == "torch": - assert "epoch_1-global_batch_200-acc" in all_saved_model_paths + assert "trainer-epoch_1-batch_200" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_1-global_batch_200-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_200"] assert len(all_saved_model_paths) == 4 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "epoch_1-global_batch_100-acc" in all_saved_model_paths + assert "trainer-epoch_1-batch_100" in all_saved_model_paths - epoch_save_path = all_saved_model_paths["epoch_1-global_batch_100-acc"] + epoch_save_path = all_saved_model_paths["trainer-epoch_1-batch_100"] assert len(all_saved_model_paths) == 2 all_state_dicts = [epoch_save_path] elif version == 1: - pattern = re.compile("epoch_[0-9]+-global_batch_[0-9]+-[a-z|A-Z]+_[0-9]*.?[0-9]*") + pattern = re.compile("trainer-epoch_[0-9]+-batch_[0-9]+-[a-zA-Z#]+_[0-9]*.?[0-9]*") # all_saved_model_paths = {w.name: w for w in path.joinpath(os.environ[FASTNLP_LAUNCH_TIME]).iterdir()} if driver == "torch": - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: each_folder_name = pattern.findall(each_folder_name) @@ -666,13 +669,13 @@ def test_trainer_checkpoint_callback_2( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 1 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 2 # ddp 下的文件名不同,因为同样的数据,ddp 用了更少的步数跑完; else: - assert "last" in all_saved_model_paths + assert "trainer-last" in all_saved_model_paths aLL_topk_folders = [] for each_folder_name in all_saved_model_paths: @@ -681,7 +684,7 @@ def test_trainer_checkpoint_callback_2( aLL_topk_folders.append(each_folder_name[0]) assert len(aLL_topk_folders) == 1 - last_save_path = all_saved_model_paths["last"] + last_save_path = all_saved_model_paths["trainer-last"] topk_save_path = all_saved_model_paths[aLL_topk_folders[0]] assert len(all_saved_model_paths) == 2 diff --git a/tests/core/dataloaders/paddle_dataloader/test_fdl.py b/tests/core/dataloaders/paddle_dataloader/test_fdl.py index dbca394b..20795166 100644 --- a/tests/core/dataloaders/paddle_dataloader/test_fdl.py +++ b/tests/core/dataloaders/paddle_dataloader/test_fdl.py @@ -1,4 +1,4 @@ -import unittest +import pytest from fastNLP.core.dataloaders.paddle_dataloader.fdl import PaddleDataLoader from fastNLP.core.dataset import DataSet @@ -17,7 +17,7 @@ class RandomDataset(Dataset): return 10 -class TestPaddle(unittest.TestCase): +class TestPaddle: def test_init(self): # ds = DataSet({'x': [[1, 2], [2, 3, 4], [1]] * 10, 'y': [0, 1, 1] * 10}) diff --git a/tests/core/dataloaders/torch_dataloader/test_fdl.py b/tests/core/dataloaders/torch_dataloader/test_fdl.py index 0cd17ddd..baa3781a 100644 --- a/tests/core/dataloaders/torch_dataloader/test_fdl.py +++ b/tests/core/dataloaders/torch_dataloader/test_fdl.py @@ -1,25 +1,25 @@ -import unittest +import pytest -from fastNLP.core.dataloaders.torch_dataloader import FDataLoader, prepare_dataloader +from fastNLP.core.dataloaders.torch_dataloader import TorchDataLoader, prepare_torch_dataloader from fastNLP.core.dataset import DataSet from fastNLP.io.data_bundle import DataBundle -class TestFdl(unittest.TestCase): +class TestFdl: def test_init_v1(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) - fdl = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) + fdl = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True) # for batch in fdl: # print(batch) - fdl1 = FDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) + fdl1 = TorchDataLoader(ds, batch_size=3, shuffle=True, drop_last=True, as_numpy=True) # for batch in fdl1: # print(batch) def test_set_padding(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) ds.set_pad_val("x", val=-1) - fdl = FDataLoader(ds, batch_size=3) + fdl = TorchDataLoader(ds, batch_size=3) fdl.set_input("x", "y") for batch in fdl: print(batch) @@ -36,7 +36,7 @@ class TestFdl(unittest.TestCase): _dict["Y"].append(ins['y']) return _dict - fdl = FDataLoader(ds, batch_size=3, as_numpy=True) + fdl = TorchDataLoader(ds, batch_size=3, as_numpy=True) fdl.set_input("x", "y") fdl.add_collator(collate_fn) for batch in fdl: @@ -44,7 +44,7 @@ class TestFdl(unittest.TestCase): def test_get_batch_indices(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) - fdl = FDataLoader(ds, batch_size=3, shuffle=True) + fdl = TorchDataLoader(ds, batch_size=3, shuffle=True) fdl.set_input("y", "x") for batch in fdl: print(fdl.get_batch_indices()) @@ -67,30 +67,30 @@ class TestFdl(unittest.TestCase): return object.__getattribute__(self, item) dataset = _DataSet() - dl = FDataLoader(dataset, batch_size=2, shuffle=True) + dl = TorchDataLoader(dataset, batch_size=2, shuffle=True) # dl.set_inputs('data', 'labels') # dl.set_pad_val('labels', val=None) for batch in dl: print(batch) print(dl.get_batch_indices()) - def test_prepare_dataloader(self): + def test_prepare_torch_dataloader(self): ds = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) - dl = prepare_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) - assert isinstance(dl, FDataLoader) + dl = prepare_torch_dataloader(ds, batch_size=8, shuffle=True, num_workers=2) + assert isinstance(dl, TorchDataLoader) ds1 = DataSet({"x": [[1, 2], [2, 3, 4], [4, 5, 6, 7]] * 10, "y": [1, 0, 1] * 10}) dbl = DataBundle(datasets={'train': ds, 'val': ds1}) - dl_bundle = prepare_dataloader(dbl) - assert isinstance(dl_bundle['train'], FDataLoader) - assert isinstance(dl_bundle['val'], FDataLoader) + dl_bundle = prepare_torch_dataloader(dbl) + assert isinstance(dl_bundle['train'], TorchDataLoader) + assert isinstance(dl_bundle['val'], TorchDataLoader) ds_dict = {'train_1': ds, 'val': ds1} - dl_dict = prepare_dataloader(ds_dict) - assert isinstance(dl_dict['train_1'], FDataLoader) - assert isinstance(dl_dict['val'], FDataLoader) + dl_dict = prepare_torch_dataloader(ds_dict) + assert isinstance(dl_dict['train_1'], TorchDataLoader) + assert isinstance(dl_dict['val'], TorchDataLoader) sequence = [ds, ds1] - seq_ds = prepare_dataloader(sequence) - assert isinstance(seq_ds[0], FDataLoader) - assert isinstance(seq_ds[1], FDataLoader) + seq_ds = prepare_torch_dataloader(sequence) + assert isinstance(seq_ds[0], TorchDataLoader) + assert isinstance(seq_ds[1], TorchDataLoader) diff --git a/tests/core/dataset/test_dataset.py b/tests/core/dataset/test_dataset.py index 78c48c54..8ff64d04 100644 --- a/tests/core/dataset/test_dataset.py +++ b/tests/core/dataset/test_dataset.py @@ -1,12 +1,12 @@ import os -import unittest +import pytest import numpy as np from fastNLP.core.dataset import DataSet, FieldArray, Instance, ApplyResultException -class TestDataSetInit(unittest.TestCase): +class TestDataSetInit: """初始化DataSet的办法有以下几种: 1) 用dict: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -24,46 +24,46 @@ class TestDataSetInit(unittest.TestCase): def test_init_v1(self): # 一维list ds = DataSet([Instance(x=[1, 2, 3, 4], y=[5, 6])] * 40) - self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) - self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) - self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) + assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True + assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 + assert ds.field_arrays["y"].content == [[5, 6], ] * 40 def test_init_v2(self): # 用dict ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue("x" in ds.field_arrays and "y" in ds.field_arrays) - self.assertEqual(ds.field_arrays["x"].content, [[1, 2, 3, 4], ] * 40) - self.assertEqual(ds.field_arrays["y"].content, [[5, 6], ] * 40) + assert ("x" in ds.field_arrays and "y" in ds.field_arrays) == True + assert ds.field_arrays["x"].content == [[1, 2, 3, 4], ] * 40 + assert ds.field_arrays["y"].content == [[5, 6], ] * 40 def test_init_assert(self): - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): _ = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 100}) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): _ = DataSet([[1, 2, 3, 4]] * 10) - with self.assertRaises(ValueError): + with pytest.raises(ValueError): _ = DataSet(0.00001) -class TestDataSetMethods(unittest.TestCase): +class TestDataSetMethods: def test_append(self): dd = DataSet() for _ in range(3): dd.append(Instance(x=[1, 2, 3, 4], y=[5, 6])) - self.assertEqual(len(dd), 3) - self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3, 4]] * 3) - self.assertEqual(dd.field_arrays["y"].content, [[5, 6]] * 3) + assert len(dd) == 3 + assert dd.field_arrays["x"].content == [[1, 2, 3, 4]] * 3 + assert dd.field_arrays["y"].content == [[5, 6]] * 3 def test_add_field(self): dd = DataSet() dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.add_field("z", [[5, 6]] * 10) - self.assertEqual(len(dd), 10) - self.assertEqual(dd.field_arrays["x"].content, [[1, 2, 3]] * 10) - self.assertEqual(dd.field_arrays["y"].content, [[1, 2, 3, 4]] * 10) - self.assertEqual(dd.field_arrays["z"].content, [[5, 6]] * 10) + assert len(dd) == 10 + assert dd.field_arrays["x"].content == [[1, 2, 3]] * 10 + assert dd.field_arrays["y"].content == [[1, 2, 3, 4]] * 10 + assert dd.field_arrays["z"].content == [[5, 6]] * 10 - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): dd.add_field("??", [[1, 2]] * 40) def test_delete_field(self): @@ -71,8 +71,8 @@ class TestDataSetMethods(unittest.TestCase): dd.add_field("x", [[1, 2, 3]] * 10) dd.add_field("y", [[1, 2, 3, 4]] * 10) dd.delete_field("x") - self.assertFalse("x" in dd.field_arrays) - self.assertTrue("y" in dd.field_arrays) + assert ("x" in dd.field_arrays) == False + assert "y" in dd.field_arrays def test_delete_instance(self): dd = DataSet() @@ -80,99 +80,113 @@ class TestDataSetMethods(unittest.TestCase): dd.add_field("x", [[1, 2, 3]] * old_length) dd.add_field("y", [[1, 2, 3, 4]] * old_length) dd.delete_instance(0) - self.assertEqual(len(dd), old_length - 1) + assert len(dd) == old_length - 1 dd.delete_instance(0) - self.assertEqual(len(dd), old_length - 2) + assert len(dd) == old_length - 2 def test_getitem(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ins_1, ins_0 = ds[0], ds[1] - self.assertTrue(isinstance(ins_1, Instance) and isinstance(ins_0, Instance)) - self.assertEqual(ins_1["x"], [1, 2, 3, 4]) - self.assertEqual(ins_1["y"], [5, 6]) - self.assertEqual(ins_0["x"], [1, 2, 3, 4]) - self.assertEqual(ins_0["y"], [5, 6]) + assert isinstance(ins_1, Instance) and isinstance(ins_0, Instance) == True + assert ins_1["x"] == [1, 2, 3, 4] + assert ins_1["y"] == [5, 6] + assert ins_0["x"] == [1, 2, 3, 4] + assert ins_0["y"] == [5, 6] sub_ds = ds[:10] - self.assertTrue(isinstance(sub_ds, DataSet)) - self.assertEqual(len(sub_ds), 10) + assert isinstance(sub_ds, DataSet) == True + assert len(sub_ds) == 10 sub_ds_1 = ds[[10, 0, 2, 3]] - self.assertTrue(isinstance(sub_ds_1, DataSet)) - self.assertEqual(len(sub_ds_1), 4) + assert isinstance(sub_ds_1, DataSet) == True + assert len(sub_ds_1) == 4 field_array = ds['x'] - self.assertTrue(isinstance(field_array, FieldArray)) - self.assertEqual(len(field_array), 40) + assert isinstance(field_array, FieldArray) == True + assert len(field_array) == 40 + + def test_setitem(self): + ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) + ds.add_field('i', list(range(len(ds)))) + assert ds.get_field('i').content == list(range(len(ds))) + import random + random.shuffle(ds) + import numpy as np + np.random.shuffle(ds) + assert ds.get_field('i').content != list(range(len(ds))) + + ins1 = ds[1] + ds[2] = ds[1] + assert ds[2]['x'] == ins1['x'] and ds[2]['y'] == ins1['y'] def test_get_item_error(self): - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) _ = ds[40:] - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) _ = ds["kom"] def test_len_(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertEqual(len(ds), 40) + assert len(ds) == 40 ds = DataSet() - self.assertEqual(len(ds), 0) + assert len(ds) == 0 def test_add_fieldarray(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*40)) - self.assertEqual(ds['z'].content, [[7, 8]]*40) + ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 40)) + assert ds['z'].content == [[7, 8]] * 40 - with self.assertRaises(RuntimeError): - ds.add_fieldarray('z', FieldArray('z', [[7, 8]]*10)) + with pytest.raises(RuntimeError): + ds.add_fieldarray('z', FieldArray('z', [[7, 8]] * 10)) - with self.assertRaises(TypeError): + with pytest.raises(TypeError): ds.add_fieldarray('z', [1, 2, 4]) def test_copy_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) ds.copy_field('x', 'z') - self.assertEqual(ds['x'].content, ds['z'].content) + assert ds['x'].content == ds['z'].content def test_has_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue(ds.has_field('x')) - self.assertFalse(ds.has_field('z')) + assert ds.has_field('x') == True + assert ds.has_field('z') == False def test_get_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds.get_field('z') x_array = ds.get_field('x') - self.assertEqual(x_array.content, [[1, 2, 3, 4]] * 40) + assert x_array.content == [[1, 2, 3, 4]] * 40 def test_get_all_fields(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) field_arrays = ds.get_all_fields() - self.assertEqual(field_arrays["x"], [[1, 2, 3, 4]] * 40) - self.assertEqual(field_arrays['y'], [[5, 6]] * 40) + assert field_arrays["x"].content == [[1, 2, 3, 4]] * 40 + assert field_arrays['y'].content == [[5, 6]] * 40 def test_get_field_names(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) field_names = ds.get_field_names() - self.assertTrue('x' in field_names) - self.assertTrue('y' in field_names) + assert 'x' in field_names + assert 'y' in field_names def test_apply(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 4000, "y": [[5, 6]] * 4000}) ds.apply(lambda ins: ins["x"][::-1], new_field_name="rx", progress_desc='rx') - self.assertTrue("rx" in ds.field_arrays) - self.assertEqual(ds.field_arrays["rx"].content[0], [4, 3, 2, 1]) + assert ("rx" in ds.field_arrays) == True + assert ds.field_arrays["rx"].content[0] == [4, 3, 2, 1] ds.apply(lambda ins: len(ins["y"]), new_field_name="y", show_progress_bar=False) - self.assertEqual(ds.field_arrays["y"].content[0], 2) + assert ds.field_arrays["y"].content[0] == 2 res = ds.apply(lambda ins: len(ins["x"]), num_proc=0, progress_desc="len") - self.assertTrue(isinstance(res, list) and len(res) > 0) - self.assertTrue(res[0], 4) + assert (isinstance(res, list) and len(res) > 0) == True + assert res[0] == 4 ds.apply(lambda ins: (len(ins["x"]), "hahaha"), new_field_name="k") # expect no exception raised @@ -192,6 +206,7 @@ class TestDataSetMethods(unittest.TestCase): def modify_inplace(instance): instance['words'] = 1 + ds.apply(modify_inplace) # with self.assertRaises(TypeError): # ds.apply(modify_inplace) @@ -216,48 +231,48 @@ class TestDataSetMethods(unittest.TestCase): T.apply_more(func_1) # print(T['c'][0, 1, 2]) - self.assertEqual(list(T["c"].content), [2, 4, 6]) - self.assertEqual(list(T["d"].content), [1, 4, 9]) + assert list(T["c"].content) == [2, 4, 6] + assert list(T["d"].content) == [1, 4, 9] res = T.apply_field_more(func_2, "a", modify_fields=False) - self.assertEqual(list(T["c"].content), [2, 4, 6]) - self.assertEqual(list(T["d"].content), [1, 4, 9]) - self.assertEqual(list(res["c"]), [3, 6, 9]) - self.assertEqual(list(res["d"]), [1, 8, 27]) + assert list(T["c"].content) == [2, 4, 6] + assert list(T["d"].content) == [1, 4, 9] + assert list(res["c"]) == [3, 6, 9] + assert list(res["d"]) == [1, 8, 27] - with self.assertRaises(ApplyResultException) as e: + with pytest.raises(ApplyResultException) as e: T.apply_more(func_err_1) print(e) - with self.assertRaises(ApplyResultException) as e: + with pytest.raises(ApplyResultException) as e: T.apply_field_more(func_err_2, "a") print(e) def test_drop(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6], [7, 8, 9, 0]] * 20}) ds.drop(lambda ins: len(ins["y"]) < 3, inplace=True) - self.assertEqual(len(ds), 20) + assert len(ds) == 20 def test_contains(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 40, "y": [[5, 6]] * 40}) - self.assertTrue("x" in ds) - self.assertTrue("y" in ds) - self.assertFalse("z" in ds) + assert ("x" in ds) == True + assert ("y" in ds) == True + assert ("z" in ds) == False def test_rename_field(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds.rename_field("x", "xx") - self.assertTrue("xx" in ds) - self.assertFalse("x" in ds) + assert ("xx" in ds) == True + assert ("x" in ds) == False - with self.assertRaises(KeyError): + with pytest.raises(KeyError): ds.rename_field("yyy", "oo") def test_split(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) d1, d2 = ds.split(0.1) - self.assertEqual(len(d1), len(ds)*0.9) - self.assertEqual(len(d2), len(ds)*0.1) + assert len(d2) == (len(ds) * 0.9) + assert len(d1) == (len(ds) * 0.1) def test_add_field_v2(self): ds = DataSet({"x": [3, 4]}) @@ -268,14 +283,14 @@ class TestDataSetMethods(unittest.TestCase): def test_save_load(self): ds = DataSet({"x": [[1, 2, 3, 4]] * 10, "y": [[5, 6]] * 10}) ds.save("./my_ds.pkl") - self.assertTrue(os.path.exists("./my_ds.pkl")) + assert os.path.exists("./my_ds.pkl") == True ds_1 = DataSet.load("./my_ds.pkl") os.remove("my_ds.pkl") def test_add_null(self): ds = DataSet() - with self.assertRaises(RuntimeError) as RE: + with pytest.raises(RuntimeError) as RE: ds.add_field('test', []) def test_concat(self): @@ -287,16 +302,16 @@ class TestDataSetMethods(unittest.TestCase): ds2 = DataSet({"x": [[4, 3, 2, 1] for _ in range(10)], "y": [[6, 5] for _ in range(10)]}) ds3 = ds1.concat(ds2) - self.assertEqual(len(ds3), 20) + assert len(ds3) == 20 - self.assertListEqual(ds1[9]['x'], [1, 2, 3, 4]) - self.assertListEqual(ds1[10]['x'], [4, 3, 2, 1]) + assert ds1[9]['x'] == [1, 2, 3, 4] + assert ds1[10]['x'] == [4, 3, 2, 1] ds2[0]['x'][0] = 100 - self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 + assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 ds3[10]['x'][0] = -100 - self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 + assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 # 测试inplace ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) @@ -304,19 +319,19 @@ class TestDataSetMethods(unittest.TestCase): ds3 = ds1.concat(ds2, inplace=True) ds2[0]['x'][0] = 100 - self.assertEqual(ds3[10]['x'][0], 4) # 不改变copy后的field了 + assert ds3[10]['x'][0] == 4 # 不改变copy后的field了 ds3[10]['x'][0] = -100 - self.assertEqual(ds2[0]['x'][0], 100) # 不改变copy前的field了 + assert ds2[0]['x'][0] == 100 # 不改变copy前的field了 ds3[0]['x'][0] = 100 - self.assertEqual(ds1[0]['x'][0], 100) # 改变copy前的field了 + assert ds1[0]['x'][0] == 100 # 改变copy前的field了 # 测试mapping ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)], "Y": [[6, 5] for i in range(10)]}) ds3 = ds1.concat(ds2, field_mapping={'X': 'x', 'Y': 'y'}) - self.assertEqual(len(ds3), 20) + assert len(ds3) == 20 # 测试忽略掉多余的 ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) @@ -326,7 +341,7 @@ class TestDataSetMethods(unittest.TestCase): # 测试报错 ds1 = DataSet({"x": [[1, 2, 3, 4] for i in range(10)], "y": [[5, 6] for i in range(10)]}) ds2 = DataSet({"X": [[4, 3, 2, 1] for i in range(10)]}) - with self.assertRaises(RuntimeError): + with pytest.raises(RuntimeError): ds3 = ds1.concat(ds2, field_mapping={'X': 'x'}) def test_instance_field_disappear_bug(self): @@ -334,7 +349,7 @@ class TestDataSetMethods(unittest.TestCase): data.copy_field(field_name='raw_chars', new_field_name='chars') _data = data[:1] for field_name in ['raw_chars', 'target', 'chars']: - self.assertTrue(_data.has_field(field_name)) + assert _data.has_field(field_name) == True def test_from_pandas(self): import pandas as pd @@ -342,8 +357,8 @@ class TestDataSetMethods(unittest.TestCase): df = pd.DataFrame({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds = DataSet.from_pandas(df) print(ds) - self.assertEqual(ds['x'].content, [1, 2, 3]) - self.assertEqual(ds['y'].content, [4, 5, 6]) + assert ds['x'].content == [1, 2, 3] + assert ds['y'].content == [4, 5, 6] def test_to_pandas(self): ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) @@ -352,7 +367,7 @@ class TestDataSetMethods(unittest.TestCase): def test_to_csv(self): ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) ds.to_csv("1.csv") - self.assertTrue(os.path.exists("1.csv")) + assert os.path.exists("1.csv") == True os.remove("1.csv") def test_add_collate_fn(self): @@ -360,27 +375,26 @@ class TestDataSetMethods(unittest.TestCase): def collate_fn(item): return item - ds.add_collate_fn(collate_fn) - self.assertEqual(len(ds.collate_fns.collators), 2) + ds.add_collate_fn(collate_fn) def test_get_collator(self): from typing import Callable ds = DataSet({'x': [1, 2, 3], 'y': [4, 5, 6]}) collate_fn = ds.get_collator() - self.assertEqual(isinstance(collate_fn, Callable), True) + assert isinstance(collate_fn, Callable) == True def test_add_seq_len(self): - ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) + ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.add_seq_len('x') print(ds) def test_set_target(self): - ds = DataSet({'x': [[1, 2], [2, 3 , 4], [3]], 'y': [4, 5, 6]}) + ds = DataSet({'x': [[1, 2], [2, 3, 4], [3]], 'y': [4, 5, 6]}) ds.set_target('x') -class TestFieldArrayInit(unittest.TestCase): +class TestFieldArrayInit: """ 1) 如果DataSet使用dict初始化,那么在add_field中会构造FieldArray: 1.1) 二维list DataSet({"x": [[1, 2], [3, 4]]}) @@ -428,7 +442,6 @@ class TestFieldArrayInit(unittest.TestCase): # list of array fa = FieldArray("x", [np.array([[1, 2], [3, 4]]), np.array([[1, 2], [3, 4]])]) - def test_init_v8(self): # 二维list val = np.array([[1, 2], [3, 4]]) @@ -436,78 +449,78 @@ class TestFieldArrayInit(unittest.TestCase): fa.append(val) -class TestFieldArray(unittest.TestCase): +class TestFieldArray: def test_main(self): fa = FieldArray("x", [1, 2, 3, 4, 5]) - self.assertEqual(len(fa), 5) + assert len(fa) == 5 fa.append(6) - self.assertEqual(len(fa), 6) + assert len(fa) == 6 - self.assertEqual(fa[-1], 6) - self.assertEqual(fa[0], 1) + assert fa[-1] == 6 + assert fa[0] == 1 fa[-1] = 60 - self.assertEqual(fa[-1], 60) + assert fa[-1] == 60 - self.assertEqual(fa.get(0), 1) - self.assertTrue(isinstance(fa.get([0, 1, 2]), np.ndarray)) - self.assertListEqual(list(fa.get([0, 1, 2])), [1, 2, 3]) + assert fa.get(0) == 1 + assert isinstance(fa.get([0, 1, 2]), np.ndarray) == True + assert list(fa.get([0, 1, 2])) == [1, 2, 3] def test_getitem_v1(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) - self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) + assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] ans = fa[[0, 1]] - self.assertTrue(isinstance(ans, np.ndarray)) - self.assertTrue(isinstance(ans[0], np.ndarray)) - self.assertEqual(ans[0].tolist(), [1.1, 2.2, 3.3, 4.4, 5.5]) - self.assertEqual(ans[1].tolist(), [1, 2, 3, 4, 5]) - self.assertEqual(ans.dtype, np.float64) + assert isinstance(ans, np.ndarray) == True + assert isinstance(ans[0], np.ndarray) == True + assert ans[0].tolist() == [1.1, 2.2, 3.3, 4.4, 5.5] + assert ans[1].tolist() == [1, 2, 3, 4, 5] + assert ans.dtype == np.float64 def test_getitem_v2(self): x = np.random.rand(10, 5) fa = FieldArray("my_field", x) indices = [0, 1, 3, 4, 6] for a, b in zip(fa[indices], x[indices]): - self.assertListEqual(a.tolist(), b.tolist()) + assert a.tolist() == b.tolist() def test_append(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa.append([1.2, 2.3, 3.4, 4.5, 5.6]) - self.assertEqual(len(fa), 3) - self.assertEqual(fa[2], [1.2, 2.3, 3.4, 4.5, 5.6]) + assert len(fa) == 3 + assert fa[2] == [1.2, 2.3, 3.4, 4.5, 5.6] def test_pop(self): fa = FieldArray("y", [[1.1, 2.2, 3.3, 4.4, 5.5], [1.0, 2.0, 3.0, 4.0, 5.0]]) fa.pop(0) - self.assertEqual(len(fa), 1) - self.assertEqual(fa[0], [1.0, 2.0, 3.0, 4.0, 5.0]) + assert len(fa) == 1 + assert fa[0] == [1.0, 2.0, 3.0, 4.0, 5.0] fa[0] = [1.1, 2.2, 3.3, 4.4, 5.5] - self.assertEqual(fa[0], [1.1, 2.2, 3.3, 4.4, 5.5]) + assert fa[0] == [1.1, 2.2, 3.3, 4.4, 5.5] -class TestCase(unittest.TestCase): +class TestCase: def test_init(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6]} ins = Instance(x=[1, 2, 3], y=[4, 5, 6]) - self.assertTrue(isinstance(ins.fields, dict)) - self.assertEqual(ins.fields, fields) + assert isinstance(ins.fields, dict) == True + assert ins.fields == fields ins = Instance(**fields) - self.assertEqual(ins.fields, fields) + assert ins.fields == fields def test_add_field(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6]} ins = Instance(**fields) ins.add_field("z", [1, 1, 1]) fields.update({"z": [1, 1, 1]}) - self.assertEqual(ins.fields, fields) + assert ins.fields == fields def test_get_item(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} ins = Instance(**fields) - self.assertEqual(ins["x"], [1, 2, 3]) - self.assertEqual(ins["y"], [4, 5, 6]) - self.assertEqual(ins["z"], [1, 1, 1]) + assert ins["x"] == [1, 2, 3] + assert ins["y"] == [4, 5, 6] + assert ins["z"] == [1, 1, 1] def test_repr(self): fields = {"x": [1, 2, 3], "y": [4, 5, 6], "z": [1, 1, 1]} diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index 791b1203..8e21c20f 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -6,7 +6,7 @@ from fastNLP.core.drivers.paddle_driver.single_device import PaddleSingleDriver from fastNLP.core.samplers.reproducible_sampler import RandomSampler from fastNLP.core.samplers import ReproducibleBatchSampler from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 -from tests.helpers.datasets.paddle_data import PaddleRandomMaxDataset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset, PaddleRandomMaxDataset from tests.helpers.datasets.torch_data import TorchNormalDataset from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 from fastNLP.core import synchronize_safe_rm diff --git a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py index 81d693fc..161bbfe8 100644 --- a/tests/core/drivers/torch_driver/test_torch_replace_sampler.py +++ b/tests/core/drivers/torch_driver/test_torch_replace_sampler.py @@ -30,7 +30,7 @@ class SequenceDataSet: def check_replace_sampler(driver): - # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,ReproducibleBatchSampler + # dist_sampler 可以选择的有['dist', 'unrepeatdist', None]或者是ReproducibleSampler,RandomBatchSampler # reproducible 是 True 和 False # 需要 check 返回的 sampler 和 dataloader 都不同了 diff --git a/tests/core/metrics/test_accuracy_torch.py b/tests/core/metrics/test_accuracy_torch.py index 33fc791a..b62200db 100644 --- a/tests/core/metrics/test_accuracy_torch.py +++ b/tests/core/metrics/test_accuracy_torch.py @@ -118,7 +118,6 @@ class TestAccuracy: def test_v1(self, is_ddp: bool, dataset: DataSet, metric_class: Type['Metric'], metric_kwargs: Dict[str, Any]) -> None: global pool - print(pool) if is_ddp: if sys.platform == "win32": pytest.skip("DDP not supported on windows") diff --git a/tests/core/metrics/test_f1_rec_acc_torch.py b/tests/core/metrics/test_span_f1_rec_acc_torch.py similarity index 72% rename from tests/core/metrics/test_f1_rec_acc_torch.py rename to tests/core/metrics/test_span_f1_rec_acc_torch.py index 121f9530..bc711a54 100644 --- a/tests/core/metrics/test_f1_rec_acc_torch.py +++ b/tests/core/metrics/test_span_f1_rec_acc_torch.py @@ -1,5 +1,5 @@ import pytest -import unittest + from collections import Counter import os, sys import copy @@ -14,6 +14,7 @@ from torch.multiprocessing import Pool, set_start_method from fastNLP.core.vocabulary import Vocabulary from fastNLP.core.metrics import SpanFPreRecMetric from fastNLP.core.dataset import DataSet + set_start_method("spawn", force=True) @@ -45,7 +46,6 @@ def setup_ddp(rank: int, world_size: int, master_port: int) -> None: os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(master_port) - print(torch.cuda.device_count()) if torch.distributed.is_available() and sys.platform not in ("win32", "cygwin"): torch.distributed.init_process_group("gloo", rank=rank, world_size=world_size) @@ -64,15 +64,15 @@ def find_free_network_port() -> int: return port -@pytest.fixture(scope='class', autouse=True) -def pre_process(): - global pool - pool = Pool(processes=NUM_PROCESSES) - master_port = find_free_network_port() - pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) - yield - pool.close() - pool.join() +# @pytest.fixture(scope='class', autouse=True) +# def pre_process(): +# global pool +# pool = Pool(processes=NUM_PROCESSES) +# master_port = find_free_network_port() +# pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) +# yield +# pool.close() +# pool.join() def _test(local_rank: int, @@ -87,18 +87,19 @@ def _test(local_rank: int, # dataset 也类似(每个进程有自己的一个) dataset = copy.deepcopy(dataset) metric.to(device) - print(os.environ.get("MASTER_PORT", "xx")) # 把数据拆到每个 GPU 上,有点模仿 DistributedSampler 的感觉,但这里数据单位是一个 batch(即每个 i 取了一个 batch 到自己的 GPU 上) for i in range(local_rank, len(dataset), world_size): pred, tg, seq_len = dataset[i]['pred'].to(device), dataset[i]['tg'].to(device), dataset[i]['seq_len'] + print(tg, seq_len) metric.update(pred, tg, seq_len) my_result = metric.get_metric() + print(my_result) + print(sklearn_metric) assert my_result == sklearn_metric -class SpanFPreRecMetricTest(unittest.TestCase): - global pool +class TestSpanFPreRecMetric: def test_case1(self): from fastNLP.core.metrics.span_f1_pre_rec_metric import _bmes_tag_to_spans @@ -135,38 +136,36 @@ class SpanFPreRecMetricTest(unittest.TestCase): fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) - bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, - -0.3782, 0.8240], - [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, + bio_sequence = torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, + -0.3782, 0.8240], + [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, -0.3562, -1.4116], - [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, - 2.0023, 0.7075], - [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, + [ 1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, + 2.0023, 0.7075], + [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, 0.3832, -0.1540], - [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, + [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, -1.3508, -0.9513], - [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, - -0.0842, -0.4294]], + [ 1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, + -0.0842, -0.4294]], - [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, + [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, -1.4138, -0.8853], - [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, - -1.0726, 0.0364], - [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, - -0.8836, -0.9320], - [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, - -1.6857, 1.1571], - [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, - -0.5837, 1.0184], - [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, - -0.9025, 0.0864]]]) - bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], - [4, 1, 7, 0, 4, 7]]) + [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, + -1.0726, 0.0364], + [ 0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, + -0.8836, -0.9320], + [ 0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, + -1.6857, 1.1571], + [ 1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, + -0.5837, 1.0184], + [ 1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, + -0.9025, 0.0864]]]) + bio_target = torch.LongTensor([[3, 6, 0, 8, 2, 4], [4, 1, 7, 0, 4, 7]]) fastnlp_bio_metric.update(bio_sequence, bio_target, [6, 6]) expect_bio_res = {'pre-1': 0.333333, 'rec-1': 0.333333, 'f-1': 0.333333, 'pre-2': 0.5, 'rec-2': 0.5, 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} - assert expect_bio_res == fastnlp_bio_metric.get_metric() # print(fastnlp_bio_metric.get_metric()) @@ -253,7 +252,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): # print(expected_metric) metric_value = metric.get_metric() for key, value in expected_metric.items(): - self.assertAlmostEqual(value, metric_value[key], places=5) + np.allclose(value, metric_value[key]) def test_auto_encoding_type_infer(self): # 检查是否可以自动check encode的类型 @@ -270,9 +269,8 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab.add_word('o') vocabs[encoding_type] = vocab for e in ['bio', 'bioes', 'bmeso']: - with self.subTest(e=e): - metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) - assert metric.encoding_type == e + metric = SpanFPreRecMetric(tag_vocab=vocabs[e]) + assert metric.encoding_type == e bmes_vocab = _generate_tags('bmes') vocab = Vocabulary() @@ -285,7 +283,7 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab = Vocabulary() for i in range(10): vocab.add_word(str(i)) - with self.assertRaises(Exception): + with pytest.raises(Exception): metric = SpanFPreRecMetric(vocab) def test_encoding_type(self): @@ -304,65 +302,72 @@ class SpanFPreRecMetricTest(unittest.TestCase): vocab.add_word('o') vocabs[encoding_type] = vocab for e1, e2 in product(['bio', 'bioes', 'bmeso'], ['bio', 'bioes', 'bmeso']): - with self.subTest(e1=e1, e2=e2): - if e1 == e2: + if e1 == e2: + metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) + else: + s2 = set(e2) + s2.update(set(e1)) + if s2 == set(e2): + continue + with pytest.raises(AssertionError): metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) - else: - s2 = set(e2) - s2.update(set(e1)) - if s2 == set(e2): - continue - with self.assertRaises(AssertionError): - metric = SpanFPreRecMetric(tag_vocab=vocabs[e1], encoding_type=e2) for encoding_type in ['bio', 'bioes', 'bmeso']: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): metric = SpanFPreRecMetric(tag_vocab=vocabs[encoding_type], encoding_type='bmes') - with self.assertWarns(Warning): + with pytest.warns(Warning): vocab = Vocabulary(unknown=None, padding=None).add_word_lst(list('bmes')) metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') vocab = Vocabulary().add_word_lst(list('bmes')) metric = SpanFPreRecMetric(tag_vocab=vocab, encoding_type='bmeso') def test_case5(self): - global pool - # pool = Pool(NUM_PROCESSES) - # master_port = find_free_network_port() - # pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) + # global pool + pool = Pool(NUM_PROCESSES) + master_port = find_free_network_port() + pool.starmap(setup_ddp, [(rank, NUM_PROCESSES, master_port) for rank in range(NUM_PROCESSES)]) number_labels = 4 # bio tag fastnlp_bio_vocab = Vocabulary(unknown=None, padding=None) fastnlp_bio_vocab.word_count = Counter(_generate_tags('BIO', number_labels)) # fastnlp_bio_metric = SpanFPreRecMetric(tag_vocab=fastnlp_bio_vocab, only_gross=False) - dataset = DataSet({'pred': [torch.FloatTensor( - [[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, - -0.3782, 0.8240], - [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, - -0.3562, -1.4116], - [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, - 2.0023, 0.7075], - [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, - 0.3832, -0.1540], - [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, - -1.3508, -0.9513], - [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, - -0.0842, -0.4294]], - - [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, - -1.4138, -0.8853], - [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, - -1.0726, 0.0364], - [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, - -0.8836, -0.9320], - [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, - -1.6857, 1.1571], - [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, - -0.5837, 1.0184], - [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, - -0.9025, 0.0864]]])] * 100, - 'tg': [torch.LongTensor([[3, 6, 0, 8, 2, 4], - [4, 1, 7, 0, 4, 7]])] * 100, - 'seq_len': [[6, 6]] * 100}) + dataset = DataSet({'pred': [ + torch.FloatTensor([[[-0.4424, -0.4579, -0.7376, 1.8129, 0.1316, 1.6566, -1.2169, + -0.3782, 0.8240], + [-1.2348, -0.1876, -0.1462, -0.4834, -0.6692, -0.9735, 1.1563, + -0.3562, -1.4116], + [1.6550, -0.9555, 0.3782, -1.3160, -1.5835, -0.3443, -1.7858, + 2.0023, 0.7075], + [-0.3772, -0.5447, -1.5631, 1.1614, 1.4598, -1.2764, 0.5186, + 0.3832, -0.1540], + [-0.1011, 0.0600, 1.1090, -0.3545, 0.1284, 1.1484, -1.0120, + -1.3508, -0.9513], + [1.8948, 0.8627, -2.1359, 1.3740, -0.7499, 1.5019, 0.6919, + -0.0842, -0.4294]] + + ]), + torch.FloatTensor([ + [[-0.2802, 0.6941, -0.4788, -0.3845, 1.7752, 1.2950, -1.9490, + -1.4138, -0.8853], + [-1.3752, -0.5457, -0.5305, 0.4018, 0.2934, 0.7931, 2.3845, + -1.0726, 0.0364], + [0.3621, 0.2609, 0.1269, -0.5950, 0.7212, 0.5959, 1.6264, + -0.8836, -0.9320], + [0.2003, -1.0758, -1.1560, -0.6472, -1.7549, 0.1264, 0.6044, + -1.6857, 1.1571], + [1.4277, -0.4915, 0.4496, 2.2027, 0.0730, -3.1792, -0.5125, + -0.5837, 1.0184], + [1.9495, 1.7145, -0.2143, -0.1230, -0.2205, 0.8250, 0.4943, + -0.9025, 0.0864]] + ]) + ], + 'tg': [ + torch.LongTensor([[3, 6, 0, 8, 2, 4]]), + torch.LongTensor([[4, 1, 7, 0, 4, 7]]) + ], + 'seq_len': [ + [6], [6] + ]}) metric_kwargs = { 'tag_vocab': fastnlp_bio_vocab, 'only_gross': False, @@ -372,7 +377,6 @@ class SpanFPreRecMetricTest(unittest.TestCase): 'f-2': 0.5, 'pre-0': 0.0, 'rec-0': 0.0, 'f-0': 0.0, 'pre-3': 0.0, 'rec-3': 0.0, 'f-3': 0.0, 'pre': 0.222222, 'rec': 0.181818, 'f': 0.2} processes = NUM_PROCESSES - print(torch.cuda.device_count()) pool.starmap( partial( @@ -384,3 +388,5 @@ class SpanFPreRecMetricTest(unittest.TestCase): ), [(rank, processes, torch.device(f'cuda:{rank}')) for rank in range(processes)] ) + pool.close() + pool.join() diff --git a/tests/core/samplers/test_reproducible_batch_sampler.py b/tests/core/samplers/test_reproducible_batch_sampler.py index 42b86dcd..d51dd912 100644 --- a/tests/core/samplers/test_reproducible_batch_sampler.py +++ b/tests/core/samplers/test_reproducible_batch_sampler.py @@ -4,7 +4,7 @@ import numpy as np import pytest from itertools import chain -from fastNLP.core.samplers import ReproducibleBatchSampler, BucketedBatchSampler +from fastNLP.core.samplers import RandomBatchSampler, BucketedBatchSampler from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler from tests.helpers.datasets.torch_data import TorchNormalDataset @@ -18,7 +18,7 @@ class TestReproducibleBatchSampler: before_batch_size = 7 dataset = TorchNormalDataset(num_of_data=100) dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) forward_steps = 3 @@ -28,15 +28,15 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() assert state == {"index_list": array("I", list(range(100))), "data_idx": forward_steps*before_batch_size, - "sampler_type": "ReproducibleBatchSampler"} + "sampler_type": "RandomBatchSampler"} # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -53,7 +53,7 @@ class TestReproducibleBatchSampler: # 改变 batch_size; after_batch_size = 3 dataloader = DataLoader(dataset, batch_size=after_batch_size) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -99,7 +99,7 @@ class TestReproducibleBatchSampler: dataset = TorchNormalDataset(num_of_data=100) # 开启 shuffle,来检验断点重训后的第二轮的 index list 是不是重新生成的; dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) dataloader = replace_batch_sampler(dataloader, re_batchsampler) # 将一轮的所有数据保存下来,看是否恢复的是正确的; @@ -111,13 +111,13 @@ class TestReproducibleBatchSampler: # 1. 保存状态 _get_re_batchsampler = dataloader.batch_sampler - assert isinstance(_get_re_batchsampler, ReproducibleBatchSampler) + assert isinstance(_get_re_batchsampler, RandomBatchSampler) state = _get_re_batchsampler.state_dict() # 2. 断点重训,重新生成一个 dataloader; # 不改变 batch_size; dataloader = DataLoader(dataset, batch_size=before_batch_size, shuffle=True) - re_batchsampler = ReproducibleBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) + re_batchsampler = RandomBatchSampler(dataloader.batch_sampler, dataloader.batch_size, drop_last=False) re_batchsampler.load_state_dict(state) dataloader = replace_batch_sampler(dataloader, re_batchsampler) @@ -416,7 +416,6 @@ class TestBucketedBatchSampler: @pytest.mark.parametrize('num_replica', [2, 3]) def test_multi_same_bucket(self, shuffle, drop_last, pad, num_samples, num_replica): # def test_multi_same_bucket(self, shuffle=True, drop_last=True, pad=True, num_samples=623, num_replica=2): - # TODO 两个 rank 上的长度是要在同一个bucket的 dataset = DatasetWithVaryLength(num_of_data=num_samples) batch_size = 6 if num_replica*batch_size > num_samples: diff --git a/tests/core/samplers/test_reproducible_sampler.py b/tests/core/samplers/test_reproducible_sampler.py index 0a3697d3..981d6a03 100644 --- a/tests/core/samplers/test_reproducible_sampler.py +++ b/tests/core/samplers/test_reproducible_sampler.py @@ -1,18 +1,14 @@ -import unittest - -from itertools import product import numpy as np +import pytest from functools import partial -from array import array +from itertools import chain -from fastNLP.core.samplers.reproducible_sampler import RandomSampler -from fastNLP.core.drivers.torch_driver.utils import replace_batch_sampler +from fastNLP.core.samplers.reproducible_sampler import RandomSampler, SortedSampler, SequentialSampler from tests.helpers.datasets.torch_data import TorchNormalDataset - -class TestRandomSamplerYh(unittest.TestCase): +class TestRandomSamplerYh: def test_init(self): # 测试能否正确初始化 dataset = TorchNormalDataset(num_of_data=100) @@ -24,7 +20,7 @@ class TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset) for i in sampler: - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): sampler.set_distributed(1, 0) break @@ -37,39 +33,39 @@ class TestRandomSamplerYh(unittest.TestCase): dataset = TorchNormalDataset(num_of_data=100) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 50) + assert count == 50 sampler.set_distributed(num_replicas=2, rank=1, pad=False) - self.assertEqual(len(sampler), 50) + assert len(sampler)==50 count = 0 for i in sampler: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 50) + assert count==50 dataset = TorchNormalDataset(num_of_data=101) sampler = RandomSampler(dataset, shuffle=False) sampler.set_distributed(num_replicas=2, rank=0, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler)==51 count = 0 for i in sampler: - self.assertEqual(i%2, 0) + assert i%2==0 count += 1 - self.assertEqual(count, 51) + assert count == 51 sampler.set_distributed(num_replicas=2, rank=1, pad=True) - self.assertEqual(len(sampler), 51) + assert len(sampler) == 51 count = 0 for i in sampler: if i!=0: - self.assertEqual(i%2, 1) + assert i%2==1 count += 1 - self.assertEqual(count, 51) + assert count == 51 def test_state_dict_check_length(self): dataset = TorchNormalDataset(num_of_data=100) @@ -77,7 +73,7 @@ class TestRandomSamplerYh(unittest.TestCase): states = sampler.state_dict() new_ds = TorchNormalDataset(num_of_data=10) - with self.assertRaises(AssertionError): + with pytest.raises(AssertionError): new_sampler = RandomSampler(new_ds) new_sampler.load_state_dict(states) @@ -85,99 +81,107 @@ class TestRandomSamplerYh(unittest.TestCase): new_sampler = RandomSampler(new_ds) new_sampler.load_state_dict(states) - def test_state_dict(self): + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples, size=3).tolist() - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - sampler = RandomSampler(dataset, shuffle=pre_shuffle) - sampler.set_epoch(0) - already_numbers = set() - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - new_sampler.set_epoch(0) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - def test_state_dict_2(self): + sampler = RandomSampler(dataset, shuffle=pre_shuffle) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('pre_shuffle', [True, False]) + @pytest.mark.parametrize('post_shuffle', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, pre_shuffle, post_shuffle, num_consumed_samples): # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 num_samples = 100 dataset = TorchNormalDataset(num_of_data=num_samples) # 测试使用 前后shuffle不一致的load操作 - lst = [0]+np.random.randint(1, num_samples//2, size=3).tolist() # lst = [30] - for pre_shuffle, post_shuffle, num_consumed_samples in product([True, False], [True, False], - lst): - with self.subTest(pre_shuffle=pre_shuffle, post_shuffle=post_shuffle, num_consumed_samples=num_consumed_samples): - already_numbers = set() - sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) - sampler.set_distributed(num_replicas=2, rank=0) - sampler.set_epoch(0) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) - sampler.set_epoch(0) - sampler.set_distributed(num_replicas=2, rank=1) - if num_consumed_samples>0: - for i, j in enumerate(sampler, start=1): - already_numbers.add(j) - if i == num_consumed_samples: - break - self.assertEqual(len(already_numbers), num_consumed_samples*2) - - states = sampler.state_dict() - - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - for i in new_sampler: - self.assertNotIn(i, already_numbers) - - # 测试切换成多卡也没有问题 - other_rank_number = set() - for rank in range(3): - new_sampler = RandomSampler(dataset, shuffle=post_shuffle) - new_sampler.load_state_dict(states) - new_sampler.set_epoch(0) - new_sampler.set_distributed(num_replicas=3, rank=rank, pad=False) - count = 0 - for i in new_sampler: - self.assertNotIn(i, other_rank_number) - other_rank_number.add(i) - self.assertNotIn(i, already_numbers) - count += 1 - - -class TestRandomSampler(unittest.TestCase): + already_numbers = set() + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = RandomSampler(dataset, shuffle=pre_shuffle, seed=0) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = RandomSampler(dataset, shuffle=post_shuffle) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + for i in new_sampler: + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + + +class TestRandomSampler: # 测试单卡; def test_seed_work_when_shuffle_is_true(self): data_length = 100 @@ -360,4 +364,324 @@ class TestRandomSampler(unittest.TestCase): ... +class DatasetWithVaryLength: + def __init__(self, num_of_data=100, reverse=False): + self.data = np.arange(num_of_data) + if reverse: + self.data = self.data[::-1] + + def __getitem__(self, item): + return self.data[item] + + def __len__(self): + return len(self.data) + + +class TestSortedSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SortedSampler(data, length=data.data) + indexes = list(sampler) + assert indexes==list(range(num_of_data-1, -1, -1)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SortedSampler(dataset=data, length=data.data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for sampler in samplers: + larger_count = 0 # 这里为 0 就可以,因为最后补充的index一定是比较大的数。 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + larger_count += int(index <= prev_index) + prev_index = index + assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 + assert seen_in_other_rank <= 1 if pad else seen_in_other_rank == 0 + already_seen_index.update(cur_set) + + indexes = list(chain(*samplers)) + indexes = set(indexes) + if pad: + assert indexes == set(range(num_of_data)) + else: + assert len(indexes) <= num_of_data + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, num_consumed_samples): + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j= max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller<=1 if pad else smaller==0 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j<=max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SortedSampler(dataset, length=dataset.data) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i < max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SortedSampler(dataset, length=dataset.data) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i>=max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=1 # 因为pad可能重复 + assert smaller <= 1 if pad else smaller == 0 + + +class TestSequentialSampler: + def test_single(self): + num_of_data = 100 + data = DatasetWithVaryLength(num_of_data) + sampler = SequentialSampler(data) + indexes = list(sampler) + assert indexes==list(range(num_of_data)) + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_replica', [2, 3]) + @pytest.mark.parametrize('num_of_data', [2, 3, 4, 100]) + def test_multi(self, pad, num_replica, num_of_data): + data = DatasetWithVaryLength(num_of_data=num_of_data) + samplers = [] + for i in range(num_replica): + sampler = SequentialSampler(dataset=data) + sampler.set_distributed(num_replica, rank=i, pad=pad) + samplers.append(sampler) + + # 保证顺序是没乱的 + already_seen_index = set() + for idx, sampler in enumerate(samplers): + larger_count = 1 + prev_index = float('inf') + cur_set = set() + seen_in_other_rank = 0 + for index in sampler: + seen_in_other_rank += int(index in already_seen_index) # 不同的卡不交叉 + cur_set.add(index) + larger_count += int(index >= prev_index) + prev_index = index + assert larger_count+1 >= len(sampler) # 除了最后一个可能乱掉,其它都必须要保持这个顺序 + assert seen_in_other_rank <= idx if pad else seen_in_other_rank == 0 + already_seen_index.update(cur_set) + + indexes = list(chain(*samplers)) + indexes = set(indexes) + if pad: + assert indexes == set(range(num_of_data)) + else: + assert len(indexes) <= num_of_data + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100, size=3).tolist()) + def test_state_dict(self, pad, num_consumed_samples): + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + sampler = SequentialSampler(dataset=dataset) + sampler.set_epoch(0) + already_numbers = set() + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j>max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples + + states = sampler.state_dict() + + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i > max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + new_sampler.set_epoch(0) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i <= max(already_numbers)) + seen_in_other_rank += int(i in other_rank_number) + other_rank_number.add(i) + seen += int(i in already_numbers) + count += 1 + assert seen <= 1 if pad else seen == 0 + assert seen_in_other_rank<=rank # 因为pad可能重复 + assert smaller<=1 if pad else smaller==0 + + @pytest.mark.parametrize('pad', [True, False]) + @pytest.mark.parametrize('num_consumed_samples', [0]+np.random.randint(1, 100//2, size=3).tolist()) + def test_state_dict_2(self, pad, num_consumed_samples): + # 测试一下从多卡切换到单卡,或者切换到不同卡数量的多卡 + num_samples = 100 + dataset = DatasetWithVaryLength(num_of_data=num_samples) + # 测试使用 前后shuffle不一致的load操作 + # lst = [30] + already_numbers = set() + sampler = SequentialSampler(dataset=dataset) + sampler.set_distributed(num_replicas=2, rank=0) + sampler.set_epoch(0) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + if already_numbers: + assert j>max(already_numbers) + already_numbers.add(j) + if i == num_consumed_samples: + break + sampler = SequentialSampler(dataset=dataset) + sampler.set_epoch(0) + sampler.set_distributed(num_replicas=2, rank=1) + if num_consumed_samples>0: + for i, j in enumerate(sampler, start=1): + already_numbers.add(j) + if i == num_consumed_samples: + break + assert len(already_numbers) == num_consumed_samples*2 + + states = sampler.state_dict() + + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + for i in new_sampler: + if already_numbers: + assert i > max(already_numbers) + assert i not in already_numbers + + # 测试切换成多卡也没有问题 + other_rank_number = set() + for rank in range(3): + new_sampler = SequentialSampler(dataset=dataset) + new_sampler.load_state_dict(states) + new_sampler.set_epoch(0) + new_sampler.set_distributed(num_replicas=3, rank=rank, pad=pad) + count = 0 + seen = 0 + seen_in_other_rank = 0 + smaller = 0 + for i in new_sampler: + if already_numbers: + smaller += int(i=prev_index + prev_index = index + + indexes = list(chain(*samplers)) + assert len(indexes) == num_of_data + indexes = set(indexes) + assert indexes == set(range(num_of_data))