From 711bbf469c8b5c8357990767fbbd9b0f59e6724c Mon Sep 17 00:00:00 2001 From: YWMditto Date: Sat, 23 Apr 2022 14:44:18 +0800 Subject: [PATCH 01/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86=20trainer.loa?= =?UTF-8?q?d=5Fmodel/load=20=E5=8F=AA=E6=9C=89=E5=8D=95=E5=8D=A1=E5=8A=A0?= =?UTF-8?q?=E8=BD=BD=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e4cd2817..afd5d06a 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -576,7 +576,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable` type when it is not None.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) else: if isinstance(folder, str): folder = Path(folder) @@ -653,7 +653,7 @@ class Trainer(TrainerEventTrigger): if model_load_fn is not None: if not callable(model_load_fn): raise ValueError("Parameter `model_save_fn` should be `Callable`.") - rank_zero_call(model_load_fn)(folder) + model_load_fn(folder) states = self.driver.load(folder=folder, dataloader=dataloader, should_load_model=False, **kwargs) else: states = self.driver.load(folder=folder, dataloader=dataloader, only_state_dict=only_state_dict, should_load_model=True, **kwargs) From e85cbb067eb83a3c739cfd443f589f67806080fc Mon Sep 17 00:00:00 2001 From: yh_cc Date: Sun, 24 Apr 2022 12:09:19 +0800 Subject: [PATCH 02/16] =?UTF-8?q?Rich=E6=94=AF=E6=8C=81jupyter?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback_manager.py | 13 +-------- fastNLP/core/callbacks/progress_callback.py | 32 +++++++++++++++------ fastNLP/core/controllers/evaluator.py | 2 +- fastNLP/core/controllers/trainer.py | 31 ++++++++++++++------ fastNLP/core/utils/rich_progress.py | 23 ++++++++++++++- fastNLP/core/utils/utils.py | 21 +++++++++++++- 6 files changed, 91 insertions(+), 31 deletions(-) diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index 4aa822ad..c5b00e71 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -8,7 +8,6 @@ __all__ = [ from .callback_events import Events from .callback import Callback -from .progress_callback import ProgressCallback, choose_progress_callback from fastNLP.core.log import logger @@ -35,7 +34,7 @@ class CallbackManager: class_callbacks: Optional[List[Callback]] # 用来保留原始的类callback; callback_fns: dict - def __init__(self, callbacks: Optional[List[Callback]], progress_bar='auto'): + def __init__(self, callbacks: Optional[List[Callback]]): r""" 注意 callback 的调用顺序: 1. 通过函数修饰器 `Trainer.on` 添加的 callback 函数; @@ -46,7 +45,6 @@ class CallbackManager: """ self._need_reproducible_sampler = False - _has_progress_callback = False _callbacks = [] if callbacks is not None: if isinstance(callbacks, Callback): @@ -57,16 +55,7 @@ class CallbackManager: for _callback in callbacks: if not isinstance(_callback, Callback): raise TypeError(f"callbacks must be of Callback type, instead of `{type(_callback)}`") - if isinstance(_callback, ProgressCallback): - _has_progress_callback = True _callbacks += callbacks - if not _has_progress_callback: - # 添加 progress callback - progress_callback = choose_progress_callback(progress_bar=progress_bar) - if progress_callback is None: - logger.info("There is no progress bar, Trainer will not output training progress.") - else: - _callbacks.append(progress_callback) self.callback_fns = defaultdict(list) # 因为理论上用户最多只能通过 'trainer.on_train_begin' 或者 'trainer.callback_manager.on_train_begin' 来调用,即其是没办法 # 直接调用具体的某一个 callback 函数,而不调用其余的同名的 callback 函数的,因此我们只需要记录具体 Event 的时机即可; diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 64d72bd0..a6f82896 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -1,6 +1,6 @@ import json import sys - +from typing import Union __all__ = [ 'choose_progress_callback', @@ -11,11 +11,22 @@ __all__ = [ from .has_monitor_callback import HasMonitorCallback from fastNLP.core.utils import f_rich_progress from fastNLP.core.log import logger +from fastNLP.core.utils.utils import is_notebook + + + +class ProgressCallback(HasMonitorCallback): + def on_train_end(self, trainer): + f_rich_progress.stop() + + @property + def name(self): # progress bar的名称 + return 'auto' -def choose_progress_callback(progress_bar:str): +def choose_progress_callback(progress_bar: Union[str, ProgressCallback]) -> ProgressCallback: if progress_bar == 'auto': - if (sys.stdin and sys.stdin.isatty()): + if not f_rich_progress.dummy_rich: progress_bar = 'rich' else: progress_bar = 'raw' @@ -23,15 +34,12 @@ def choose_progress_callback(progress_bar:str): return RichCallback() elif progress_bar == 'raw': return RawTextCallback() + elif isinstance(progress_bar, ProgressCallback): + return progress_bar else: return None -class ProgressCallback(HasMonitorCallback): - def on_train_end(self, trainer): - f_rich_progress.stop() - - class RichCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, format_json=True): @@ -124,6 +132,10 @@ class RichCallback(ProgressCallback): self.task2id = {} self.loss = 0 + @property + def name(self): # progress bar的名称 + return 'rich' + class RawTextCallback(ProgressCallback): def __init__(self, print_every:int = 1, loss_round_ndigit:int = 6, monitor:str=None, larger_better:bool=True, @@ -184,3 +196,7 @@ class RawTextCallback(ProgressCallback): logger.info(json.dumps(trainer.driver.tensor_to_numeric(results))) else: logger.info(results) + + @property + def name(self): # progress bar的名称 + return 'raw' \ No newline at end of file diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 60703ef5..95379302 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -134,7 +134,7 @@ class Evaluator: self.progress_bar = kwargs.get('progress_bar', 'auto') if self.progress_bar == 'auto': - self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' + self.progress_bar = 'raw' if f_rich_progress.dummy_rich else 'rich' self.driver.barrier() diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index e4cd2817..7eb5bbac 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -20,6 +20,7 @@ from fastNLP.core.controllers.utils.utils import TrainerEventTrigger, _Truncated from fastNLP.core.callbacks import Callback, CallbackManager, Events, EventsList from fastNLP.core.callbacks.callback import _CallbackWrapper from fastNLP.core.callbacks.callback_events import _SingleEventState +from fastNLP.core.callbacks.progress_callback import choose_progress_callback from fastNLP.core.drivers import Driver from fastNLP.core.drivers.utils import choose_driver from fastNLP.core.utils import get_fn_arg_names, match_and_substitute_params, nullcontext @@ -125,14 +126,13 @@ class Trainer(TrainerEventTrigger): set_grad_to_none: 是否在训练过程中在每一次 optimizer 更新后将 grad 置为 None; use_dist_sampler: 表示是否使用分布式的 sampler 。在多卡时,分布式 sampler 将自动决定每张卡上读取的 sample ,使得一个epoch 内所有卡的 sample 加起来为一整个数据集的 sample。默认会根据 driver 是否为分布式进行设置。 - eval_use_dist_sampler: 表示在 Evaluator 中在使用 TorchDDPDriver 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; + evaluate_use_dist_sampler: 表示在 Evaluator 中在使用 分布式 的时候是否将 dataloader 的 sampler 替换为分布式的 sampler;默认为 True; output_from_new_proc: 应当为一个字符串,表示在多进程的 driver 中其它进程的输出流应当被做如何处理;其值应当为以下之一: ["all", "ignore", "only_error"];当该参数的值不是以上值时,该值应当表示一个文件夹的名字,我们会将其他 rank 的输出流重定向到 log 文件中,然后将 log 文件保存在通过该参数值设定的文件夹中;默认为 "only_error"; - progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'],默认为 auto 。progress 的实现是通过 - callback 实现的,若在输入的 callback 中检测到了 ProgressCallback 类型的 callback ,则该参数对 Trainer 无效。 - auto 表示如果检测到当前 terminal 为交互型 则使用 rich,否则使用 raw。 - + progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, + 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 + 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 """ self.model = model self.marker = marker @@ -195,8 +195,20 @@ class Trainer(TrainerEventTrigger): ) self.driver.set_optimizers(optimizers=optimizers) + # 根据 progress_bar 参数选择 ProgressBarCallback + progress_bar_callback = choose_progress_callback(kwargs.get('progress_bar', 'auto')) + if progress_bar_callback is not None: + if callbacks is None: + callbacks = [] + elif not isinstance(callbacks, Sequence): + callbacks = [callbacks] + + callbacks = list(callbacks) + [progress_bar_callback] + else: + rank_zero_call(logger.warning)("No progress bar is provided, there will have no information output " + "during training.") # 初始化 callback manager; - self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) + self.callback_manager = CallbackManager(callbacks) # 添加所有的函数式 callbacks; self._fetch_matched_fn_callbacks() # 添加所有的类 callbacks; @@ -237,6 +249,9 @@ class Trainer(TrainerEventTrigger): self.larger_better = larger_better if metrics is not None and evaluate_dataloaders is not None: check_evaluate_every(evaluate_every) + progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 + if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 + progress_bar = progress_bar.name self.evaluator = Evaluator( model=model, dataloaders=evaluate_dataloaders, @@ -249,8 +264,8 @@ class Trainer(TrainerEventTrigger): output_mapping=output_mapping, fp16=fp16, verbose=0, - use_dist_sampler=kwargs.get("eval_use_dist_sampler", None), - progress_bar=kwargs.get('progress_bar', 'auto') + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), + progress_bar=progress_bar ) if train_fn is not None and not isinstance(train_fn, str): diff --git a/fastNLP/core/utils/rich_progress.py b/fastNLP/core/utils/rich_progress.py index 82747a01..e7b95d9c 100644 --- a/fastNLP/core/utils/rich_progress.py +++ b/fastNLP/core/utils/rich_progress.py @@ -14,6 +14,7 @@ __all__ = [ ] from fastNLP.envs import get_global_rank +from .utils import is_notebook class Singleton(type): @@ -34,6 +35,14 @@ class DummyFRichProgress: # 防止用户通过 DummyFRichProgress.console.print() 这种调用 return None + @property + def dummy_rich(self)->bool: + """ + 当前对象是否是 dummy 的 rich 对象。 + + :return: + """ + return True class FRichProgress(Progress, metaclass=Singleton): """ @@ -147,6 +156,8 @@ class FRichProgress(Progress, metaclass=Singleton): super().stop_task(task_id) super().remove_task(task_id) self.refresh() # 使得bar不残留 + if len(self._tasks) == 0: + super().stop() def start(self) -> None: super().start() @@ -210,6 +221,15 @@ class FRichProgress(Progress, metaclass=Singleton): if refresh: self.refresh() + @property + def dummy_rich(self) -> bool: + """ + 当前对象是否是 dummy 的 rich 对象。 + + :return: + """ + return False + class SpeedColumn(ProgressColumn): """ @@ -226,7 +246,8 @@ class SpeedColumn(ProgressColumn): return Text(str(round(1/speed, 2))+' s/it.', style='progress.data.speed') -if (sys.stdin and sys.stdin.isatty()) and get_global_rank() == 0: +if ((sys.stdin and sys.stdin.isatty()) or is_notebook()) and \ + get_global_rank() == 0: f_rich_progress = FRichProgress().new_progess( "[progress.description]{task.description}", "[progress.percentage]{task.percentage:>3.0f}%", diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index ff3386fe..c3f57bcf 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -696,4 +696,23 @@ def get_class_that_defined_method(method): None) if isinstance(cls, type): return cls - return getattr(method, '__objclass__', None) # handle special descriptor objects \ No newline at end of file + return getattr(method, '__objclass__', None) # handle special descriptor objects + + +def is_notebook(): + """ + 检查当前运行环境是否为 jupyter + + :return: + """ + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + raise ImportError("console") + if "VSCODE_PID" in os.environ: # pragma: no cover + raise ImportError("vscode") + except: + return False + else: # pragma: no cover + return True \ No newline at end of file From 7e40d984041ad3700d1598863906017887517e91 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 25 Apr 2022 16:47:29 +0800 Subject: [PATCH 03/16] =?UTF-8?q?=E6=96=B0=E5=A2=9Etrain=5Finput=5Fmapping?= =?UTF-8?q?=20=E5=92=8C=20evaluate=5Finput=5Fmapping=20=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/trainer.py | 50 ++++++++++++++++++++++++++--- 1 file changed, 46 insertions(+), 4 deletions(-) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 8a888c2e..9a3c30d5 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -103,10 +103,12 @@ class Trainer(TrainerEventTrigger): value;如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;如果 batch 此时是其它 类型,那么我们将会直接报错;如果 input_mapping 是一个函数,那么对于取出的 batch,我们将不会做任何处理,而是直接将其传入该函数里; 注意该参数会被传进 `Evaluator` 中;因此你可以通过该参数来实现将训练数据 batch 移到对应机器上的工作(例如当参数 `device` 为 None 时); + 如果 train 和 evaluate 需要使用不同的 input_mapping, 请使用 train_input_mapping 与 evaluate_input_mapping 设置。 :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; + 如果 train 和 evaluate 需要使用不同的 output_mapping, 请使用 train_output_mapping 与 evaluate_output_mapping 设置。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; @@ -133,6 +135,10 @@ class Trainer(TrainerEventTrigger): progress_bar: 以哪种方式显示 progress ,目前支持[None, 'raw', 'rich', 'auto'] 或者 RichCallback, RawTextCallback对象, 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 + train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 + evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 """ self.model = model self.marker = marker @@ -147,8 +153,18 @@ class Trainer(TrainerEventTrigger): self.evaluate_dataloaders = evaluate_dataloaders self.optimizers = optimizers self.fp16 = fp16 - self.input_mapping = input_mapping - self.output_mapping = output_mapping + + train_input_mapping = kwargs.get('train_input_mapping', None) + train_output_mapping = kwargs.get('train_output_mapping', None) + evaluate_input_mapping = kwargs.get('evaluate_input_mapping', None) + evaluate_output_mapping = kwargs.get('evaluate_output_mapping', None) + + train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping = \ + _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, + evaluate_input_mapping, evaluate_output_mapping) + + self.input_mapping = train_input_mapping + self.output_mapping = train_output_mapping self.evaluate_fn = evaluate_fn self.batch_step_fn = batch_step_fn @@ -185,8 +201,8 @@ class Trainer(TrainerEventTrigger): callbacks=callbacks, metrics=metrics, evaluate_every=evaluate_every, - input_mapping=input_mapping, - output_mapping=output_mapping, + input_mapping=evaluate_input_mapping, + output_mapping=evaluate_output_mapping, model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, @@ -854,6 +870,32 @@ class Trainer(TrainerEventTrigger): self._evaluate_dataloaders = evaluate_dataloaders +def _get_input_output_mapping(input_mapping, output_mapping, train_input_mapping, train_output_mapping, + evaluate_input_mapping, evaluate_output_mapping): + if train_input_mapping is not None and input_mapping is not None: + raise ValueError("Parameter `input_mapping` and `train_input_mapping` cannot be set simultaneously.") + + if evaluate_input_mapping is not None and input_mapping is not None: + raise ValueError("Parameter `input_mapping` and `evaluate_input_mapping` cannot be set simultaneously.") + + if train_output_mapping is not None and output_mapping is not None: + raise ValueError("Parameter `output_mapping` and `train_output_mapping` cannot be set simultaneously.") + + if evaluate_output_mapping is not None and output_mapping is not None: + raise ValueError("Parameter `output_mapping` and `evaluate_output_mapping` cannot be set simultaneously.") + + if train_input_mapping is None: + train_input_mapping = input_mapping + if evaluate_input_mapping is None: + evaluate_input_mapping = input_mapping + + if train_output_mapping is None: + train_output_mapping = output_mapping + if evaluate_output_mapping is None: + evaluate_output_mapping = output_mapping + + return train_input_mapping, train_output_mapping, evaluate_input_mapping, evaluate_output_mapping + From 6f7bbfabcab444ccdc9233a98f794fda2af49eef Mon Sep 17 00:00:00 2001 From: yh_cc Date: Mon, 25 Apr 2022 19:35:49 +0800 Subject: [PATCH 04/16] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E9=83=A8=E5=88=86?= =?UTF-8?q?=E5=85=B3=E4=BA=8Eevaluate=5Fbatch=5Fstep=5Ffn=E7=9A=84?= =?UTF-8?q?=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/callbacks/more_evaluate_callback.py | 2 +- fastNLP/core/controllers/evaluator.py | 60 ++++++++----------- fastNLP/core/controllers/trainer.py | 33 ++++------ 3 files changed, 38 insertions(+), 57 deletions(-) diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index dbb6505f..6c015bdf 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -108,7 +108,7 @@ class MoreEvaluateCallback(HasMonitorCallback): 'metrics': self.metrics, 'driver': self.kwargs.get('driver', trainer.driver), 'device': self.kwargs.get('device', trainer.device), - 'batch_step_fn': self.kwargs.get('batch_step_fn', trainer.evaluate_batch_step_fn), + 'evaluate_batch_step_fn': self.kwargs.get('evaluate_batch_step_fn', trainer.evaluate_batch_step_fn), 'evaluate_fn': self.evaluate_fn, 'input_mapping': self.kwargs.get('input_mapping', trainer.input_mapping), 'output_mapping': self.kwargs.get('output_mapping', trainer.output_mapping), diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 95379302..ada31edb 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -30,22 +30,12 @@ class Evaluator: driver: Driver _evaluate_batch_loop: Loop - def __init__( - self, - model, - dataloaders, - metrics: Optional[Union[Dict, Metric]] = None, - driver: Union[str, Driver] = 'torch', - device: Optional[Union[int, List[int], str]] = None, - batch_step_fn: Optional[callable] = None, - evaluate_fn: Optional[str] = None, - input_mapping: Optional[Union[Callable, Dict]] = None, - output_mapping: Optional[Union[Callable, Dict]] = None, - model_wo_auto_param_call: bool = False, - fp16: bool = False, - verbose: int = 1, - **kwargs - ): + def __init__(self, model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, + driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, + evaluate_batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, + input_mapping: Optional[Union[Callable, Dict]] = None, + output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, + fp16: bool = False, verbose: int = 1, **kwargs): """ :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 @@ -54,13 +44,13 @@ class Evaluator: metric ,torchmetrics,allennlpmetrics等。 :param driver: 使用 driver 。 :param device: 使用的设备。 - :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 - DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 - batch_step_fn 函数。 + :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, + 不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`;(1) 如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有 找到该方法,则使用 `model.forward` 函数;(2) 如果为 str 类型,则尝试从 model 中寻找该方法,找不到则报错。 - :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 + :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中。如果针对 + model 和 metric 需要不同的 mapping,请考虑使用 evaluate_batch_step_fn 参数定制。 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 @@ -96,9 +86,9 @@ class Evaluator: self.device = device self.verbose = verbose - if batch_step_fn is not None: - _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') - self.batch_step_fn = batch_step_fn + if evaluate_batch_step_fn is not None: + _check_valid_parameters_number(evaluate_batch_step_fn, ['evaluator', 'batch'], fn_name='evaluate_batch_step_fn') + self.evaluate_batch_step_fn = evaluate_batch_step_fn self.input_mapping = input_mapping self.output_mapping = output_mapping @@ -106,7 +96,7 @@ class Evaluator: if not isinstance(dataloaders, dict): dataloaders = {None: dataloaders} - self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) + self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=evaluate_batch_step_fn) self.driver.setup() self.driver.barrier() @@ -235,8 +225,8 @@ class Evaluator: @evaluate_batch_loop.setter def evaluate_batch_loop(self, loop: Loop): - if self.batch_step_fn is not None: - logger.warning("`batch_step_fn` was customized in the Evaluator initialization, it will be ignored " + if self.evaluate_batch_step_fn is not None: + logger.warning("`evaluate_batch_step_fn` was customized in the Evaluator initialization, it will be ignored " "when the `evaluate_batch_loop` is also customized.") self._evaluate_batch_loop = loop @@ -249,15 +239,15 @@ class Evaluator: """ self.metrics_wrapper.reset() - def update(self, *args, **kwargs): + def update(self, batch, outputs): """ - 调用所有metric的 update 方法,对当前 batch 的结果进行累积,会根据相应 metric 的参数列表进行匹配传参。 + 自动调用所有 metric 的 update 方法,会根据不同 metric 的参数列表进行匹配传参。 - :param args: - :param kwargs: + :param batch: 一般是来自于 DataLoader 的输出,如果不为 dict 类型的话,该值将被忽略。 + :param outputs: 一般是来自于模型的输出。类别应为 dict 或者 dataclass 类型。 :return: """ - self.metrics_wrapper.update(*args, **kwargs) + self.metrics_wrapper.update(batch, outputs) def get_dataloader_metric(self, dataloader_name: Optional[str] = '') -> Dict: """ @@ -271,7 +261,7 @@ class Evaluator: @property def metrics_wrapper(self): """ - 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 batch_step_fn )中使用,同时也为了支持 + 由于需要保持 Evaluator 中 metrics 对象与用户传入的 metrics 保持完全一致(方便他在 evaluate_batch_step_fn )中使用,同时也为了支持 不同形式的 metric( fastNLP 的 metric/torchmetrics 等),所以 Evaluator 在进行 metric 操作的时候都调用 metrics_wrapper 进行操作。 @@ -283,11 +273,11 @@ class Evaluator: def evaluate_step(self, batch): """ - 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 + 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 。会将返回结果经过 output_mapping 处理后再 返回。 - :param batch: - :return: + :param batch: {evaluate_fn} 函数支持的输入类型 + :return: {evaluate_fn} 函数的输出结果,如果有设置 output_mapping ,将是 output_mapping 之后的结果。 """ outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) outputs = match_and_substitute_params(self.output_mapping, outputs) diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 9a3c30d5..cbec1a01 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -83,10 +83,10 @@ class Trainer(TrainerEventTrigger): :param n_epochs: 训练总共的 epoch 的数量,默认为 20; :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 为 None; - :param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 - `batch`;默认为 None; - :param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 - 两个参数必须为 `evaluator` 和 `batch`;默认为 None; + :param batch_step_fn: 定制每次 train batch 执行的函数。该函数应接受两个参数为 `trainer` 和`batch`,不需要要返回值;可以 + 参考 fastNLP.core.controllers.loops.train_batch_loop.TrainBatchLoop中的batch_step_fn函数。 + :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, + 不需要有返回值;可以参考 fastNLP.core.controllers.loops.evaluate_batch_loop.EvaluateBatchLoop中的batch_step_fn函数。 :param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用模型的哪一个函数,例如是 `train_step` 还是 `forward`; 默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法, 则使用模型默认的前向传播函数。 @@ -136,9 +136,9 @@ class Trainer(TrainerEventTrigger): 默认为 auto , auto 表示如果检测到当前 terminal 为交互型 则使用 RichCallback,否则使用 RawTextCallback对象。如果 需要定制 progress bar 的参数,例如打印频率等,可以传入 RichCallback, RawTextCallback 对象。 train_input_mapping: 与 input_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 - train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 input_mapping 互斥。 + train_output_mapping: 与 output_mapping 一致,但是只用于 train 中。与 output_mapping 互斥。 evaluate_input_mapping: 与 input_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 - evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 input_mapping 互斥。 + evaluate_output_mapping: 与 output_mapping 一致,但是只用于 evaluate 中。与 output_mapping 互斥。 """ self.model = model self.marker = marker @@ -268,21 +268,12 @@ class Trainer(TrainerEventTrigger): progress_bar = kwargs.get('progress_bar', 'auto') # 如果不为 if not (isinstance(progress_bar, str) or progress_bar is None): # 应该是ProgressCallback,获取其名称。 progress_bar = progress_bar.name - self.evaluator = Evaluator( - model=model, - dataloaders=evaluate_dataloaders, - metrics=metrics, - driver=self.driver, - device=device, - batch_step_fn=evaluate_batch_step_fn, - evaluate_fn=evaluate_fn, - input_mapping=input_mapping, - output_mapping=output_mapping, - fp16=fp16, - verbose=0, - use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), - progress_bar=progress_bar - ) + self.evaluator = Evaluator(model=model, dataloaders=evaluate_dataloaders, metrics=metrics, + driver=self.driver, device=device, evaluate_batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, input_mapping=input_mapping, + output_mapping=output_mapping, fp16=fp16, verbose=0, + use_dist_sampler=kwargs.get("evaluate_use_dist_sampler", None), + progress_bar=progress_bar) if train_fn is not None and not isinstance(train_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") From bb10410ccd4c8b959253d1e715b0f23117c30c50 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 25 Apr 2022 12:57:17 +0000 Subject: [PATCH 05/16] =?UTF-8?q?torch=20=E5=8D=95=E5=8D=A1=E7=9A=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../torch_driver/initialize_torch_driver.py | 2 +- .../core/drivers/torch_driver/torch_driver.py | 14 +- .../torch_driver/test_single_device.py | 697 ++++++++++++++++++ tests/core/drivers/torch_driver/test_utils.py | 71 +- 4 files changed, 747 insertions(+), 37 deletions(-) create mode 100644 tests/core/drivers/torch_driver/test_single_device.py diff --git a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py index f149855f..5ee946c4 100644 --- a/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/initialize_torch_driver.py @@ -76,7 +76,7 @@ def initialize_torch_driver(driver: str, device: Optional[Union[str, torch.devic logger.info("Notice you are using `torch_ddp` driver, but your chosen `device` is only one gpu, we will " "still use `TorchDDPDriver` for you, but if you mean using `torch_ddp`, you should " "choose `torch` driver.") - return TorchDDPDriver(model, device, **kwargs) + return TorchDDPDriver(model, [device], **kwargs) else: return TorchDDPDriver(model, device, **kwargs) elif driver == "fairscale": diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index 8e37f550..172a3cf0 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -218,6 +218,8 @@ class TorchDriver(Driver): # 2. 保存模型的状态; if should_save_model: model = self.unwrap_model() + if not os.path.exists(folder): + os.mkdir(folder) if only_state_dict: model_state_dict = {name: param.cpu().detach().clone() for name, param in model.state_dict().items()} # 对于单卡的 driver 来讲,我们实际上(现在)不应该考虑用户在DDP环境下使用单卡模式,从而造成效率损失; @@ -401,7 +403,17 @@ class TorchDriver(Driver): res.sampler = dataloader.batch_sampler.sampler if hasattr(dataloader.batch_sampler.sampler, "shuffle"): res.shuffle = dataloader.batch_sampler.sampler.shuffle - elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): + elif isinstance(dataloader.batch_sampler.sampler, TorchRandomSampler): + res.shuffle = True + else: + res.shuffle = False + # RandomBatchSampler 的情况 + elif hasattr(dataloader.batch_sampler, "batch_sampler"): + batch_sampler = dataloader.batch_sampler.batch_sampler + res.sampler = batch_sampler.sampler + if hasattr(batch_sampler.sampler, "shuffle"): + res.shuffle = dataloader.batch_sampler.sampler.shuffle + elif isinstance(batch_sampler.sampler, TorchRandomSampler): res.shuffle = True else: res.shuffle = False diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py new file mode 100644 index 00000000..4290d02c --- /dev/null +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -0,0 +1,697 @@ +import os +os.environ["FASTNLP_BACKEND"] = "torch" +import pytest +from pathlib import Path + +from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDatset +from tests.helpers.datasets.paddle_data import PaddleNormalDataset +from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 +from fastNLP.core import rank_zero_rm + +import torch +from torch.utils.data import DataLoader, BatchSampler +import paddle + +def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): + """ + 建立一个 batch_samper 为 RandomBatchSampler 的 dataloader + """ + if shuffle: + sampler = torch.utils.data.RandomSampler(dataset) + else: + sampler = torch.utils.data.SequentialSampler(dataset) + dataloader = DataLoader( + dataset=dataset, + batch_sampler=RandomBatchSampler( + BatchSampler( + sampler, batch_size=batch_size, drop_last=drop_last + ), + batch_size=batch_size, + drop_last=drop_last, + ), + ) + + return dataloader + +def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): + """ + 建立一个 samper 为 RandomSampler 的 dataloader + """ + dataloader = DataLoader( + dataset, + sampler=RandomSampler(dataset, shuffle, seed=seed), + drop_last=drop_last, + batch_size=batch_size + ) + return dataloader + +############################################################################ +# +# 测试基类 TorchDrvier 中的一些简单函数 +# +############################################################################ + +class TestTorchDriverFunctions: + """ + 使用 TorchSingleDriver 测试基类的函数 + """ + + @classmethod + def setup_class(self): + model = TorchNormalModel_Classification_1(10, 32) + self.driver = TorchSingleDriver(model, device="cpu") + + def test_check_single_optimizer_legality(self): + """ + 测试传入单个 optimizer 时的表现 + """ + optimizer = torch.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) + + self.driver.set_optimizers(optimizer) + + optimizer = paddle.optimizer.Adam( + parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), + learning_rate=0.01, + ) + # 传入 torch 的 optimize r时,应该报错 ValueError + with pytest.raises(ValueError): + self.driver.set_optimizers(optimizer) + + def test_check_optimizers_legality(self): + """ + 测试传入 optimizer list 的表现 + """ + optimizers = [ + torch.optim.Adam( + params=self.driver.model.parameters(), + lr=0.01 + ) for i in range(10) + ] + + self.driver.set_optimizers(optimizers) + + optimizers += [ + paddle.optimizer.Adam( + parameters=PaddleNormalModel_Classification_1(10, 32).parameters(), + learning_rate=0.01, + ) + ] + + with pytest.raises(ValueError): + self.driver.set_optimizers(optimizers) + + def test_check_dataloader_legality_in_train(self): + """ + 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 + """ + dataloader = DataLoader(TorchNormalDataset()) + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + + # 创建 paddle 的 dataloader + dataloader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) + + def test_check_dataloader_legality_in_test(self): + """ + 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 + """ + # 此时传入的应该是dict + dataloader = { + "train": DataLoader(TorchNormalDataset()), + "test": DataLoader(TorchNormalDataset()) + } + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + # 传入的不是 dict,应该报错 + dataloader = DataLoader(TorchNormalDataset()) + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + # 创建 paddle 的 dataloader + train_loader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + test_loader = paddle.io.DataLoader( + PaddleNormalDataset(), + batch_size=32, shuffle=True + ) + dataloader = {"train": train_loader, "test": test_loader} + with pytest.raises(ValueError): + TorchSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) + + def test_tensor_to_numeric(self): + """ + 测试 tensor_to_numeric 函数 + """ + # 单个张量 + tensor = torch.tensor(3) + res = TorchSingleDriver.tensor_to_numeric(tensor) + assert res == 3 + + tensor = torch.rand((3, 4)) + res = TorchSingleDriver.tensor_to_numeric(tensor) + assert res == tensor.tolist() + + # 张量list + tensor_list = [torch.rand((6, 4, 2)) for i in range(10)] + res = TorchSingleDriver.tensor_to_numeric(tensor_list) + assert isinstance(res, list) + tensor_list = [t.tolist() for t in tensor_list] + assert res == tensor_list + + # 张量tuple + tensor_tuple = tuple([torch.rand((6, 4, 2)) for i in range(10)]) + res = TorchSingleDriver.tensor_to_numeric(tensor_tuple) + assert isinstance(res, tuple) + tensor_tuple = tuple([t.tolist() for t in tensor_tuple]) + assert res == tensor_tuple + + # 张量dict + tensor_dict = { + "tensor": torch.rand((3, 4)), + "list": [torch.rand((6, 4, 2)) for i in range(10)], + "dict":{ + "list": [torch.rand((6, 4, 2)) for i in range(10)], + "tensor": torch.rand((3, 4)) + }, + "int": 2, + "string": "test string" + } + + res = TorchSingleDriver.tensor_to_numeric(tensor_dict) + assert isinstance(res, dict) + assert res["tensor"] == tensor_dict["tensor"].tolist() + assert isinstance(res["list"], list) + for r, d in zip(res["list"], tensor_dict["list"]): + assert r == d.tolist() + assert isinstance(res["int"], int) + assert isinstance(res["string"], str) + assert isinstance(res["dict"], dict) + assert isinstance(res["dict"]["list"], list) + for r, d in zip(res["dict"]["list"], tensor_dict["dict"]["list"]): + assert r == d.tolist() + assert res["dict"]["tensor"] == tensor_dict["dict"]["tensor"].tolist() + + def test_set_model_mode(self): + """ + 测试set_model_mode函数 + """ + self.driver.set_model_mode("train") + assert self.driver.model.training + self.driver.set_model_mode("eval") + assert not self.driver.model.training + # 应该报错 + with pytest.raises(AssertionError): + self.driver.set_model_mode("test") + + def test_move_model_to_device_cpu(self): + """ + 测试move_model_to_device函数 + """ + TorchSingleDriver.move_model_to_device(self.driver.model, "cpu") + assert self.driver.model.linear1.weight.device.type == "cpu" + + def test_move_model_to_device_gpu(self): + """ + 测试move_model_to_device函数 + """ + TorchSingleDriver.move_model_to_device(self.driver.model, "cuda") + assert self.driver.model.linear1.weight.device.type == "cuda" + assert self.driver.model.linear1.weight.device.index == 0 + + def test_worker_init_function(self): + """ + 测试worker_init_function + """ + # 先确保不影响运行 + # TODO:正确性 + TorchSingleDriver.worker_init_function(0) + + def test_set_deterministic_dataloader(self): + """ + 测试set_deterministic_dataloader + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(TorchNormalDataset()) + self.driver.set_deterministic_dataloader(dataloader) + + def test_set_sampler_epoch(self): + """ + 测试set_sampler_epoch + """ + # 先确保不影响运行 + # TODO:正确性 + dataloader = DataLoader(TorchNormalDataset()) + self.driver.set_sampler_epoch(dataloader, 0) + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args(self, batch_size, shuffle, drop_last): + """ + 测试正常情况下 get_dataloader_args 的表现 + """ + dataloader = DataLoader( + TorchNormalDataset(), + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + ) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, TorchNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + if shuffle: + assert isinstance(res.sampler, torch.utils.data.RandomSampler) + else: + assert isinstance(res.sampler, torch.utils.data.SequentialSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randombatchsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 batch_sampler 后 get_dataloader_args 的表现 + """ + dataset = TorchNormalDataset() + dataloader = dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, TorchNormalDataset) + assert isinstance(res.batch_sampler, RandomBatchSampler) + if shuffle: + assert isinstance(res.sampler, torch.utils.data.RandomSampler) + else: + assert isinstance(res.sampler, torch.utils.data.SequentialSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + @pytest.mark.parametrize("batch_size", [16]) + @pytest.mark.parametrize("shuffle", [True, False]) + @pytest.mark.parametrize("drop_last", [True, False]) + def test_get_dataloader_args_with_randomsampler(self, batch_size, shuffle, drop_last): + """ + 测试替换了 sampler 后 get_dataloader_args 的表现 + """ + dataset = TorchNormalDataset() + dataloader = dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last) + res = TorchSingleDriver.get_dataloader_args(dataloader) + + assert isinstance(res.dataset, TorchNormalDataset) + assert isinstance(res.batch_sampler, BatchSampler) + assert isinstance(res.sampler, RandomSampler) + assert res.shuffle == shuffle + assert res.batch_size == batch_size + assert res.drop_last == drop_last + + +############################################################################ +# +# 测试 TorchSingleDrvier 中的一些简单函数 +# +############################################################################ + +class TestSingleDeviceFunction: + """ + 测试其它函数的测试例 + """ + + @classmethod + def setup_class(cls): + model = TorchNormalModel_Classification_1(10, 784) + cls.driver = TorchSingleDriver(model, device="cpu") + + def test_unwrap_model(self): + """ + 测试能否运行 + """ + res = self.driver.unwrap_model() + assert res is self.driver.model + + def test_is_distributed(self): + assert self.driver.is_distributed() == False + + def test_move_data_to_device(self): + """ + 这个函数仅调用了 torch_move_data_to_device ,测试例在 tests/core/utils/test_torch_utils.py 中 + 就不重复测试了 + """ + self.driver.move_data_to_device(torch.rand((32, 64))) + + +############################################################################ +# +# 测试 set_dist_repro_dataloader 函数 +# +############################################################################ + +class TestSetDistReproDataloader: + """ + 专门测试 set_dist_repro_dataloader 函数的类 + """ + def setup_method(self): + self.dataset = TorchNormalDataset(20) + model = TorchNormalModel_Classification_1(10, 32) + self.driver = TorchSingleDriver(model, device="cpu") + + def test_with_reproducible_false(self): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 False 时的表现 + 当dist为字符串时,此时应该返回原来的 dataloader + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert replaced_loader is dataloader + + @pytest.mark.parametrize("shuffle", [True, False]) + def test_with_reproducible_true(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 `reproducible` 为 True 时的表现 + 当dist为字符串时,此时应该返回新的 dataloader,且如果原 sampler 为 torch.utils.data.RandomSampler(shuffle=True), + 只会替换 Sampler 为 RandomSampler;否则会替换 batch_sampler 为 RandomBatchSampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=True) + + assert not (replaced_loader is dataloader) + if shuffle: + # 此时会替换 sampler + assert isinstance(replaced_loader.batch_sampler, torch.utils.data.BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + else: + # 此时会替换 batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.batch_sampler.batch_sampler, BatchSampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现,且 dist 是 ReproducibleBatchSampler + 应该返回新的 dataloader,并将 batch_sampler 替换为 dist 对应的 Sampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=shuffle) + dist = RandomBatchSampler(BatchSampler(self.dataset, batch_size=4, drop_last=False), 4, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler is dist + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dist 不是字符串时的表现 + 应该返回新的 dataloader,并将 batch_sampler.sampler 替换为 dist 对应的 Sampler + """ + dataloader = DataLoader(self.dataset, batch_size=2, shuffle=not shuffle) + dist = RandomSampler(self.dataset, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist=dist, reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is dist + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = dataloader_with_randombatchsampler(self.dataset, 4, shuffle, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.drop_last == dataloader.drop_last + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 参数 dataloader 已经支持断点重训时的表现 + 应该返回新的 dataloader,且其余各项设置和原来相同 + """ + dataloader = dataloader_with_randomsampler(self.dataset, 2, shuffle, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, dist="dist", reproducible=False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 2 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + """ + 测试单卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_consumed_batches = 2 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch) + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新改造 dataloader + new_loader = dataloader_with_randombatchsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.load_state_dict(sampler_states) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size + # 重新构造 dataloader + new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, False) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) + assert len(left_idxes | already_seen_idx) == len(self.dataset) + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ + +def generate_random_driver(features, labels, fp16=False, device="cpu"): + """ + 生成driver + """ + model = TorchNormalModel_Classification_1(labels, features) + opt = torch.optim.Adam(params=model.parameters(), lr=0.01) + driver = TorchSingleDriver(model, device=device, fp16=fp16) + driver.set_optimizers(opt) + driver.setup() + + return driver + +@pytest.fixture +def prepare_test_save_load(): + dataset = TorchArgMaxDatset(10, 40) + dataloader = DataLoader(dataset, batch_size=4) + driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) + return driver1, driver2, dataloader + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +def test_save_and_load_model(prepare_test_save_load, only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + driver1, driver2, dataloader = prepare_test_save_load + + driver1.save_model(path, only_state_dict) + driver2.load_model(path, only_state_dict) + + for batch in dataloader: + batch = driver1.move_data_to_device(batch) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + + assert torch.equal(res1["preds"], res2["preds"]) + finally: + rank_zero_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + dataset = TorchArgMaxDatset(10, 40) + dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) + driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") + + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + # 加载 + # 更改 batch_size + + dataloader = dataloader_with_randombatchsampler(dataset, 2, True, False) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert replaced_loader.batch_sampler.index_list == sampler_states["index_list"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) + +@pytest.mark.parametrize("only_state_dict", ([True, False])) +@pytest.mark.parametrize("fp16", ([True, False])) +def test_save_and_load_with_randomsampler(only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") + dataset = TorchArgMaxDatset(10, 40) + dataloader = dataloader_with_randomsampler(dataset, 4, True, False) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(dataset, 2, True, False) + load_states = driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + batch = driver2.move_data_to_device(batch) + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = driver1.model.evaluate_step(**batch) + res2 = driver2.model.evaluate_step(**batch) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(dataset) + assert len(left_x_batches | already_seen_x_set) == len(dataset) + assert len(left_y_batches) + len(already_seen_y_set) == len(dataset) + assert len(left_y_batches | already_seen_y_set) == len(dataset) + finally: + rank_zero_rm(path) diff --git a/tests/core/drivers/torch_driver/test_utils.py b/tests/core/drivers/torch_driver/test_utils.py index c1e604c9..8f0172e0 100644 --- a/tests/core/drivers/torch_driver/test_utils.py +++ b/tests/core/drivers/torch_driver/test_utils.py @@ -1,35 +1,36 @@ -from torch.utils.data.sampler import SequentialSampler, RandomSampler - -from fastNLP.core.samplers.sampler import ReproduceSampler -from tests.helpers.datasets.normal_data import NormalIterator - - -class TestReproduceSampler: - - def test_sequentialsampler(self): - normal_iterator = NormalIterator(num_of_data=20) - sequential_sampler = SequentialSampler(normal_iterator) - - reproduce_sampler = ReproduceSampler(sequential_sampler) - # iter_seq_sampler = iter(sequential_sampler) - # for each in iter_seq_sampler: - # print(each) - iter_reproduce_sampler = iter(reproduce_sampler) - forward_step = 3 - for _ in range(forward_step): - next(iter_reproduce_sampler) - state = reproduce_sampler.save_state() - assert state["current_batch_idx"] == forward_step - - new_repro_sampler = ReproduceSampler(sequential_sampler) - assert new_repro_sampler.save_state()["current_batch_idx"] == 0 - - new_repro_sampler.load_state(state) - iter_new_repro_sampler = iter(new_repro_sampler) - new_index_list = [] - for each in iter_new_repro_sampler: - new_index_list.append(each) - assert new_index_list == list(range(3, 20)) - - - +import os +import pytest +os.environ["FASTNLP_BACKEND"] = "torch" + +from fastNLP.core.drivers.torch_driver.utils import ( + replace_batch_sampler, + replace_sampler, +) +from fastNLP.core.samplers import RandomBatchSampler, RandomSampler +from torch.utils.data import DataLoader, BatchSampler + +from tests.helpers.datasets.torch_data import TorchNormalDataset + +def test_replace_batch_sampler(): + dataset = TorchNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + batch_sampler = RandomBatchSampler(dataloader.batch_sampler, batch_size=16, drop_last=False) + + replaced_loader = replace_batch_sampler(dataloader, batch_sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, RandomBatchSampler) + assert isinstance(replaced_loader.dataset, TorchNormalDataset) + assert len(replaced_loader.dataset) == len(dataset) + assert replaced_loader.batch_sampler.batch_size == 16 + +def test_replace_sampler(): + dataset = TorchNormalDataset(10) + dataloader = DataLoader(dataset, batch_size=32) + sampler = RandomSampler(dataset) + + replaced_loader = replace_sampler(dataloader, sampler) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) \ No newline at end of file From cf65da13328df3b309eb6eaeaa0ed71ff79139f6 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Mon, 25 Apr 2022 12:58:03 +0000 Subject: [PATCH 06/16] =?UTF-8?q?paddle=20=E5=8D=95=E5=8D=A1=E6=B5=8B?= =?UTF-8?q?=E8=AF=95=E4=BE=8B=E8=B0=83=E6=95=B4?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../paddle_driver/initialize_paddle_driver.py | 9 +- .../drivers/paddle_driver/paddle_driver.py | 16 ++- .../test_initialize_paddle_driver.py | 8 +- .../paddle_driver/test_single_device.py | 46 ++++---- .../test_initialize_torch_driver.py | 103 ++++++++++++++++++ 5 files changed, 142 insertions(+), 40 deletions(-) create mode 100644 tests/core/drivers/torch_driver/test_initialize_torch_driver.py diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index 2cba6388..eac2d4a4 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -47,9 +47,7 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ raise ValueError("Parameter `device` can only be '-1' when it is smaller than 0.") if device >= _could_use_device_num: raise ValueError("The gpu device that parameter `device` specifies is not existed.") - if device != -1: - device = f"gpu:{device}" - else: + if device == -1: device = list(range(_could_use_device_num)) elif isinstance(device, Sequence) and not isinstance(device, str): device = list(set(device)) @@ -61,9 +59,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ elif each >= _could_use_device_num: raise ValueError("When parameter `device` is 'Sequence' type, the value in it should not be bigger than" " the available gpu number.") - if len(device) == 1: - # 传入了 [1] 这样的,视为单卡。 - device = device[0] elif device is not None and not isinstance(device, str): raise ValueError("Parameter `device` is wrong type, please check our documentation for the right use.") @@ -82,6 +77,6 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ logger.warning("Notice you are using `fleet` driver, but your chosen `device` is only one gpu, we will" "still use `PaddleFleetDriver` for you, but if you mean using `PaddleSingleDriver`, you should " "choose `paddle` driver.") - return PaddleFleetDriver(model, device, **kwargs) + return PaddleFleetDriver(model, [device], **kwargs) else: return PaddleFleetDriver(model, device, **kwargs) diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index fe8bf404..ed1aad73 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -19,7 +19,12 @@ from fastNLP.envs import ( rank_zero_call, ) from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler +from fastNLP.core.samplers import ( + ReproducibleBatchSampler, + ReproducibleSampler, + RandomBatchSampler, + RandomSampler, +) if _NEED_IMPORT_PADDLE: import paddle @@ -29,7 +34,7 @@ if _NEED_IMPORT_PADDLE: Dataset, Sampler, BatchSampler, - RandomSampler, + RandomSampler as PaddleRandomSampler, ) from paddle.optimizer import Optimizer @@ -333,6 +338,9 @@ class PaddleDriver(Driver): sampler = dataloader_args.batch_sampler elif isinstance(dataloader_args.sampler, ReproducibleSampler): sampler = dataloader_args.sampler + elif isinstance(dataloader_args.sampler, PaddleRandomSampler): + sampler = RandomSampler(dataloader_args.sampler.data_source) + logger.debug("Replace paddle RandomSampler into fastNLP RandomSampler.") elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or " "`ReproducibleSampler`.") @@ -464,7 +472,7 @@ class PaddleDriver(Driver): res.sampler = dataloader.batch_sampler.sampler if hasattr(dataloader.batch_sampler.sampler, "shuffle"): res.shuffle = dataloader.batch_sampler.sampler.shuffle - elif isinstance(dataloader.batch_sampler.sampler, RandomSampler): + elif isinstance(dataloader.batch_sampler.sampler, PaddleRandomSampler): res.shuffle = True else: res.shuffle = False @@ -474,7 +482,7 @@ class PaddleDriver(Driver): res.sampler = batch_sampler.sampler if hasattr(batch_sampler.sampler, "shuffle"): res.shuffle = dataloader.batch_sampler.sampler.shuffle - elif isinstance(batch_sampler.sampler, RandomSampler): + elif isinstance(batch_sampler.sampler, PaddleRandomSampler): res.shuffle = True else: res.shuffle = False diff --git a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py index 54ef22b6..df96d746 100644 --- a/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py +++ b/tests/core/drivers/paddle_driver/test_initialize_paddle_driver.py @@ -19,7 +19,7 @@ def test_incorrect_driver(): @pytest.mark.parametrize( "device", - ["cpu", "gpu:0", 0, [1]] + ["cpu", "gpu:0", 0] ) @pytest.mark.parametrize( "driver", @@ -27,7 +27,7 @@ def test_incorrect_driver(): ) def test_get_single_device(driver, device): """ - 测试正常情况下初始化PaddleSingleDriver的情况 + 测试正常情况下初始化 PaddleSingleDriver 的情况 """ model = PaddleNormalModel_Classification_1(2, 100) @@ -36,7 +36,7 @@ def test_get_single_device(driver, device): @pytest.mark.parametrize( "device", - [0, 1] + [0, 1, [1]] ) @pytest.mark.parametrize( "driver", @@ -45,7 +45,7 @@ def test_get_single_device(driver, device): @magic_argv_env_context def test_get_fleet_2(driver, device): """ - 测试 fleet 多卡的初始化情况 + 测试 fleet 多卡的初始化情况,但传入了单个 gpu """ model = PaddleNormalModel_Classification_1(64, 10) diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index c80bd609..2aa4e0e6 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -34,7 +34,7 @@ class TestPaddleDriverFunctions: def test_check_single_optimizer_legality(self): """ - 测试传入单个optimizer时的表现 + 测试传入单个 optimizer 时的表现 """ optimizer = paddle.optimizer.Adam( parameters=self.driver.model.parameters(), @@ -50,7 +50,7 @@ class TestPaddleDriverFunctions: def test_check_optimizers_legality(self): """ - 测试传入optimizer list的表现 + 测试传入 optimizer list 的表现 """ optimizers = [ paddle.optimizer.Adam( @@ -70,13 +70,13 @@ class TestPaddleDriverFunctions: def test_check_dataloader_legality_in_train(self): """ - 测试is_train参数为True时,_check_dataloader_legality函数的表现 + 测试 `is_train` 参数为 True 时,_check_dataloader_legality 函数的表现 """ - dataloader = paddle.io.DataLoader(PaddleNormalDataset()) + dataloader = DataLoader(PaddleNormalDataset()) PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) # batch_size 和 batch_sampler 均为 None 的情形 - dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + dataloader = DataLoader(PaddleNormalDataset(), batch_size=None) with pytest.raises(ValueError): PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) @@ -90,29 +90,29 @@ class TestPaddleDriverFunctions: def test_check_dataloader_legality_in_test(self): """ - 测试is_train参数为False时,_check_dataloader_legality函数的表现 + 测试 `is_train` 参数为 False 时,_check_dataloader_legality 函数的表现 """ # 此时传入的应该是dict dataloader = { - "train": paddle.io.DataLoader(PaddleNormalDataset()), - "test":paddle.io.DataLoader(PaddleNormalDataset()) + "train": DataLoader(PaddleNormalDataset()), + "test":DataLoader(PaddleNormalDataset()) } PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) # batch_size 和 batch_sampler 均为 None 的情形 dataloader = { - "train": paddle.io.DataLoader(PaddleNormalDataset()), - "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) + "train": DataLoader(PaddleNormalDataset()), + "test":DataLoader(PaddleNormalDataset(), batch_size=None) } with pytest.raises(ValueError): PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) - # 传入的不是dict,应该报错 - dataloader = paddle.io.DataLoader(PaddleNormalDataset()) + # 传入的不是 dict ,应该报错 + dataloader = DataLoader(PaddleNormalDataset()) with pytest.raises(ValueError): PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) - # 创建torch的dataloader + # 创建 torch 的 dataloader train_loader = torch.utils.data.DataLoader( TorchNormalDataset(), batch_size=32, shuffle=True @@ -127,7 +127,7 @@ class TestPaddleDriverFunctions: def test_tensor_to_numeric(self): """ - 测试tensor_to_numeric函数 + 测试 tensor_to_numeric 函数 """ # 单个张量 tensor = paddle.to_tensor(3) @@ -180,7 +180,7 @@ class TestPaddleDriverFunctions: def test_set_model_mode(self): """ - 测试set_model_mode函数 + 测试 set_model_mode 函数 """ self.driver.set_model_mode("train") assert self.driver.model.training @@ -192,14 +192,14 @@ class TestPaddleDriverFunctions: def test_move_model_to_device_cpu(self): """ - 测试move_model_to_device函数 + 测试 move_model_to_device 函数 """ PaddleSingleDriver.move_model_to_device(self.driver.model, "cpu") assert self.driver.model.linear1.weight.place.is_cpu_place() def test_move_model_to_device_gpu(self): """ - 测试move_model_to_device函数 + 测试 move_model_to_device 函数 """ PaddleSingleDriver.move_model_to_device(self.driver.model, "gpu") assert self.driver.model.linear1.weight.place.is_gpu_place() @@ -207,7 +207,7 @@ class TestPaddleDriverFunctions: def test_worker_init_function(self): """ - 测试worker_init_function + 测试 worker_init_function """ # 先确保不影响运行 # TODO:正确性 @@ -215,7 +215,7 @@ class TestPaddleDriverFunctions: def test_set_deterministic_dataloader(self): """ - 测试set_deterministic_dataloader + 测试 set_deterministic_dataloader """ # 先确保不影响运行 # TODO:正确性 @@ -224,7 +224,7 @@ class TestPaddleDriverFunctions: def test_set_sampler_epoch(self): """ - 测试set_sampler_epoch + 测试 set_sampler_epoch """ # 先确保不影响运行 # TODO:正确性 @@ -336,7 +336,7 @@ class TestSingleDeviceFunction: def test_move_data_to_device(self): """ - 这个函数仅调用了paddle_move_data_to_device,测试例在tests/core/utils/test_paddle_utils.py中 + 这个函数仅调用了 paddle_move_data_to_device ,测试例在 tests/core/utils/test_paddle_utils.py 中 就不重复测试了 """ self.driver.move_data_to_device(paddle.rand((32, 64))) @@ -490,9 +490,6 @@ class TestSetDistReproDataloader: else: sampler_states = replaced_loader.batch_sampler.sampler.state_dict() - # 加载 num_consumed_samples_array,设置正确取出的 batch 数目 - num_consumed_samples_array = sampler_states.pop('num_consumed_samples_array', None) - # 重新加载,应该可以输出剩下的内容,且对于 PaddleNormalDataset 来说,排序后应该是一个 range left_idxes = set() if isinstance(replaced_loader.batch_sampler, RandomBatchSampler): @@ -510,7 +507,6 @@ class TestSetDistReproDataloader: new_loader.batch_sampler.load_state_dict(sampler_states) else: batch_size = replaced_loader.batch_sampler.batch_size - num_consumed_samples = num_consumed_batches * batch_size sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size # 重新构造 dataloader batch_sampler = BatchSampler(replaced_loader.dataset, shuffle=shuffle, batch_size=batch_size) diff --git a/tests/core/drivers/torch_driver/test_initialize_torch_driver.py b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py new file mode 100644 index 00000000..6c47e30e --- /dev/null +++ b/tests/core/drivers/torch_driver/test_initialize_torch_driver.py @@ -0,0 +1,103 @@ +import os +import pytest + +os.environ["FASTNLP_BACKEND"] = "torch" + +from fastNLP.core.drivers import TorchSingleDriver, TorchDDPDriver +from fastNLP.core.drivers.torch_driver.initialize_torch_driver import initialize_torch_driver +from fastNLP.envs import get_gpu_count +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.utils import magic_argv_env_context + +import torch + +def test_incorrect_driver(): + + model = TorchNormalModel_Classification_1(2, 100) + with pytest.raises(ValueError): + driver = initialize_torch_driver("paddle", 0, model) + +@pytest.mark.parametrize( + "device", + ["cpu", "cuda:0", 0, torch.device("cuda:0")] +) +@pytest.mark.parametrize( + "driver", + ["torch"] +) +def test_get_single_device(driver, device): + """ + 测试正常情况下初始化TorchSingleDriver的情况 + """ + + model = TorchNormalModel_Classification_1(2, 100) + driver = initialize_torch_driver(driver, device, model) + assert isinstance(driver, TorchSingleDriver) + +@pytest.mark.parametrize( + "device", + [0, 1] +) +@pytest.mark.parametrize( + "driver", + ["torch_ddp"] +) +@magic_argv_env_context +def test_get_ddp_2(driver, device): + """ + 测试 ddp 多卡的初始化情况,但传入了单个 gpu + """ + + model = TorchNormalModel_Classification_1(64, 10) + driver = initialize_torch_driver(driver, device, model) + + assert isinstance(driver, TorchDDPDriver) + +@pytest.mark.parametrize( + "device", + [[0, 2, 3], -1] +) +@pytest.mark.parametrize( + "driver", + ["torch", "torch_ddp"] +) +@magic_argv_env_context +def test_get_ddp(driver, device): + """ + 测试 ddp 多卡的初始化情况 + """ + + model = TorchNormalModel_Classification_1(64, 10) + driver = initialize_torch_driver(driver, device, model) + + assert isinstance(driver, TorchDDPDriver) + +@pytest.mark.parametrize( + ("driver", "device"), + [("torch_ddp", "cpu")] +) +@magic_argv_env_context +def test_get_ddp_cpu(driver, device): + """ + 测试试图在 cpu 上初始化分布式训练的情况 + """ + model = TorchNormalModel_Classification_1(64, 10) + with pytest.raises(ValueError): + driver = initialize_torch_driver(driver, device, model) + +@pytest.mark.parametrize( + "device", + [-2, [0, torch.cuda.device_count() + 1, 3], [-2], torch.cuda.device_count() + 1] +) +@pytest.mark.parametrize( + "driver", + ["torch", "torch_ddp"] +) +@magic_argv_env_context +def test_device_out_of_range(driver, device): + """ + 测试传入的device超过范围的情况 + """ + model = TorchNormalModel_Classification_1(2, 100) + with pytest.raises(ValueError): + driver = initialize_torch_driver(driver, device, model) \ No newline at end of file From 705deeaea964c0f360576f199911f39473da8043 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 26 Apr 2022 03:35:25 +0000 Subject: [PATCH 07/16] small --- .../core/callbacks/test_checkpoint_callback_torch.py | 4 ++-- .../callbacks/test_load_best_model_callback_torch.py | 4 ++-- tests/core/callbacks/test_more_evaluate_callback.py | 4 ++-- .../controllers/test_trainer_w_evaluator_torch.py | 4 ++-- .../core/drivers/torch_driver/test_single_device.py | 12 ++++++------ tests/helpers/datasets/torch_data.py | 2 +- 6 files changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index c700fa79..ca2a3292 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -16,7 +16,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchArgMaxDatset +from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.log import logger @@ -53,7 +53,7 @@ def model_and_optimizers(request): feature_dimension=ArgMaxDatasetConfig.feature_dimension ) trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) - dataset = TorchArgMaxDatset( + dataset = TorchArgMaxDataset( feature_dimension=ArgMaxDatasetConfig.feature_dimension, data_num=ArgMaxDatasetConfig.data_num, seed=ArgMaxDatasetConfig.seed diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index 31933347..0bc63bd5 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -19,7 +19,7 @@ from fastNLP.core import Evaluator from fastNLP.core.utils.utils import safe_rm from fastNLP.core.drivers.torch_driver import TorchSingleDriver from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchArgMaxDatset +from tests.helpers.datasets.torch_data import TorchArgMaxDataset from tests.helpers.utils import magic_argv_env_context @@ -55,7 +55,7 @@ def model_and_optimizers(request): feature_dimension=ArgMaxDatasetConfig.feature_dimension ) trainer_params.optimizers = optim.SGD(trainer_params.model.parameters(), lr=0.01) - dataset = TorchArgMaxDatset( + dataset = TorchArgMaxDataset( feature_dimension=ArgMaxDatasetConfig.feature_dimension, data_num=ArgMaxDatasetConfig.data_num, seed=ArgMaxDatasetConfig.seed diff --git a/tests/core/callbacks/test_more_evaluate_callback.py b/tests/core/callbacks/test_more_evaluate_callback.py index 1c24ea9a..16ee3e17 100644 --- a/tests/core/callbacks/test_more_evaluate_callback.py +++ b/tests/core/callbacks/test_more_evaluate_callback.py @@ -24,7 +24,7 @@ from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context from fastNLP.core import rank_zero_rm from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchArgMaxDatset +from tests.helpers.datasets.torch_data import TorchArgMaxDataset from torchmetrics import Accuracy from fastNLP.core.metrics import Metric from fastNLP.core.log import logger @@ -64,7 +64,7 @@ def model_and_optimizers(request): feature_dimension=ArgMaxDatasetConfig.feature_dimension ) trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) - dataset = TorchArgMaxDatset( + dataset = TorchArgMaxDataset( feature_dimension=ArgMaxDatasetConfig.feature_dimension, data_num=ArgMaxDatasetConfig.data_num, seed=ArgMaxDatasetConfig.seed diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 2973e417..94f66403 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -11,7 +11,7 @@ from torchmetrics import Accuracy from fastNLP.core.controllers.trainer import Trainer from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDatset +from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification, TorchArgMaxDataset from tests.helpers.callbacks.helper_callbacks import RecordLossCallback, RecordMetricCallback from tests.helpers.utils import magic_argv_env_context @@ -80,7 +80,7 @@ def model_and_optimizers(request): feature_dimension=ArgMaxDatasetConfig.feature_dimension ) trainer_params.optimizers = SGD(trainer_params.model.parameters(), lr=0.001) - dataset = TorchArgMaxDatset( + dataset = TorchArgMaxDataset( feature_dimension=ArgMaxDatasetConfig.feature_dimension, data_num=ArgMaxDatasetConfig.data_num, seed=ArgMaxDatasetConfig.seed diff --git a/tests/core/drivers/torch_driver/test_single_device.py b/tests/core/drivers/torch_driver/test_single_device.py index 4290d02c..b8a8def9 100644 --- a/tests/core/drivers/torch_driver/test_single_device.py +++ b/tests/core/drivers/torch_driver/test_single_device.py @@ -6,7 +6,7 @@ from pathlib import Path from fastNLP.core.drivers.torch_driver.single_device import TorchSingleDriver from fastNLP.core.samplers import RandomBatchSampler, RandomSampler from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 -from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDatset +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset from tests.helpers.datasets.paddle_data import PaddleNormalDataset from tests.helpers.models.paddle_model import PaddleNormalModel_Classification_1 from fastNLP.core import rank_zero_rm @@ -17,7 +17,7 @@ import paddle def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): """ - 建立一个 batch_samper 为 RandomBatchSampler 的 dataloader + 建立一个 batch_sampler 为 RandomBatchSampler 的 dataloader """ if shuffle: sampler = torch.utils.data.RandomSampler(dataset) @@ -38,7 +38,7 @@ def dataloader_with_randombatchsampler(dataset, batch_size, shuffle, drop_last): def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0): """ - 建立一个 samper 为 RandomSampler 的 dataloader + 建立一个 sampler 为 RandomSampler 的 dataloader """ dataloader = DataLoader( dataset, @@ -531,7 +531,7 @@ def generate_random_driver(features, labels, fp16=False, device="cpu"): @pytest.fixture def prepare_test_save_load(): - dataset = TorchArgMaxDatset(10, 40) + dataset = TorchArgMaxDataset(10, 40) dataloader = DataLoader(dataset, batch_size=4) driver1, driver2 = generate_random_driver(10, 10), generate_random_driver(10, 10) return driver1, driver2, dataloader @@ -566,7 +566,7 @@ def test_save_and_load_with_randombatchsampler(only_state_dict, fp16): try: path = "model.ckp" - dataset = TorchArgMaxDatset(10, 40) + dataset = TorchArgMaxDataset(10, 40) dataloader = dataloader_with_randombatchsampler(dataset, 4, True, False) driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") @@ -636,7 +636,7 @@ def test_save_and_load_with_randomsampler(only_state_dict, fp16): path = "model.ckp" driver1, driver2 = generate_random_driver(10, 10, fp16, "cuda"), generate_random_driver(10, 10, False, "cuda") - dataset = TorchArgMaxDatset(10, 40) + dataset = TorchArgMaxDataset(10, 40) dataloader = dataloader_with_randomsampler(dataset, 4, True, False) num_consumed_batches = 2 diff --git a/tests/helpers/datasets/torch_data.py b/tests/helpers/datasets/torch_data.py index 56648adb..9a0af019 100644 --- a/tests/helpers/datasets/torch_data.py +++ b/tests/helpers/datasets/torch_data.py @@ -38,7 +38,7 @@ class TorchNormalDataset_Classification(Dataset): return {"x": self.x[item], "y": self.y[item]} -class TorchArgMaxDatset(Dataset): +class TorchArgMaxDataset(Dataset): def __init__(self, feature_dimension=10, data_num=1000, seed=0): self.num_labels = feature_dimension self.feature_dimension = feature_dimension From 822df062caea2f833d3bb36e8e776b81676441ad Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 26 Apr 2022 05:31:45 +0000 Subject: [PATCH 08/16] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E7=9A=84=E6=B3=A8=E9=87=8A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/drivers/paddle_driver/fleet.py | 235 ++++++++++++++++-- .../drivers/paddle_driver/fleet_launcher.py | 32 ++- .../paddle_driver/initialize_paddle_driver.py | 8 +- .../drivers/paddle_driver/single_device.py | 56 +++++ fastNLP/core/drivers/paddle_driver/utils.py | 24 +- fastNLP/core/utils/paddle_utils.py | 15 +- 6 files changed, 316 insertions(+), 54 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index bde6f37f..a1275bed 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -19,6 +19,7 @@ from fastNLP.core.utils import ( check_user_specific_params, paddle_move_data_to_device, is_in_paddle_dist, + rank_zero_rm ) from fastNLP.core.samplers import ( RandomBatchSampler, @@ -55,20 +56,134 @@ class PaddleFleetDriver(PaddleDriver): fp16: bool = False, **kwargs ): - """ - 采用fleet接口进行并行paddle训练的driver - PaddleFleetDriver 目前考虑支持的三种启动方式: - 1. 用户自己不进行 fleet 的任何操作,直接使用我们的 Trainer,并且只运行一个 main 脚本,这时是由我们自己使用 open_subprocesses 拉起 - 多个进程,然后由 Driver 自己进行初始化 - 2. 其它情况同 1,但是用户自己使用 python -m paddle.distributed.launch 拉起; - 3. 用户自己在外面初始化 Fleet,并且通过 python -m paddle.distributed.launch 拉起; - - 注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动; - - 如果用户自己在外面初始化了 fleet,那么 - parallel_device 为 None; - data_device 为 表示单卡的一个参数; - dist.is_initialized 为 true; + r""" + 通过使用 PaddlePaddle 的 Fleet 框架启动多卡进程的 Driver。 + 需要注意的一点是,由于 PaddlePaddle 框架的特性,如果直接使用在 rank0 拉起其它进程的方法的话,如果不加以任何限制,PaddlePaddle会出现 + 第一次前向传播后卡住或占用所有显卡的现象;为了解决这一问题,我们在引入 FastNLP 时,会使用 `CUDA_VISIBLE_DEVICES` 将设备限制在卡0上, + 而用户如果使用了这一环境变量,我们会将其储存在 `USER_CUDA_VISIBLE_DEVICES` 中,并且通过一定的手段实现了转换(详细的设置请参见: + `fastNLP/envs/set_backend.py`)。在拉起其它进程的时候,我们会如法炮制,将环境限制在对应的设备上。 + + `PaddleFleetDriver` 目前支持的三种启动方式: + 1. 用户自己不进行分布式的任何操作,直接使用我们的 Trainer,这时是由我们自己使用 `FleetLauncher` 拉起多个进程, + 然后 `PaddleFleetDriver` 自己通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 A) + 2. 用户同样不在 Trainer 之外初始化分布式训练,但是用户自己使用 python -m paddle.distributed.launch 拉起来创建多个进程,这时我们仍旧 + 会通过调用 `fleet.init` 来初始化 ddp 的通信组;(情况 B) + 3. 用户自己在外面初始化分布式,并且通过 python -m paddle.distributed.launch 拉起,这时无论是多个进程的拉起和通信组的建立 + 都由用户自己操作,我们只会在 driver.setup 的时候对 `PaddleFleetDriver` 设置一些必要的属性值;(情况 C) + + 注意多机的启动强制要求用户在每一台机器上使用 python -m paddle.distributed.launch 启动;因此我们不会在 `PaddleFleetDriver` 中保存 + 任何当前有多少台机器的信息; + + Part 1:三种启动方式的具体分析: + (1)对于用户运行的脚本中,如果 `driver.setup` 只会被调用一次(意味着用户的启动脚本中只初始化了一个 trainer/evaluator)时, + `PaddleFleetDriver` 在初始化以及 `setup` 函数中会做的事情分别如下所示: + -> 情况 A:这种情况下用户传入的 model 在一定是普通的 model(没有经 `DataParallel` 包裹的model), + 因为 `Parallel` 的使用一定要求 fleet.init 已经被调用用来建立当前的 ddp 通信组;但是这意味着如果 + 用户需要使用 2 张以上的显卡,那么其必然需要使用 paddle.distributed.launch 来启动,意味着就不是情况 A 了; + 这时我们首先会调用 `FleetLauncher.launch` 函数来拉起多个进程,其中进程的数量等于用户传入给 trainer 的使用的 gpu + 的数量(例如 `Trainer` 中的参数是 device=[0, 1, 6, 7],那么我们就会使用第 0、1、6、7 张 gpu 来拉起 4 个进程); + 接着我们会调用 `fleet.init` 来初始化各个进程之间的通信组; + 这里需要注意拉起的新的进程会从前到后完整地运行一遍用户的启动脚本(例如 main.py),因此也都会运行这两个函数,但是需要注意只有进程 0 + 才会去真正地运行 `FleetLauncher.launch`;进程 0 运行到 `fleet.init`,paddle 会阻塞进程 0 继续 + 向前运行,直到其它进程也运行到这里; + 最后我们会设置这个进程对应的 device,然后将模型迁移到对应的机器上,再使用 `DataParallel` 将模型包裹; + 至此,paddle 分布式的环境配置过程全部完成; + + -> 情况 B:注意这种情况我们直接限定了用户是通过 paddle.distributed.launch 拉起,并且没有自己建立分布式的通信组。这时在 + `PaddleFleetDriver` 的初始化和 setup 函数的调用过程中,与情况 A 首要的不同就在于用户在 trainer 中输入的参数 device 不再有效, + 这时每个进程所使用的 gpu 是我们直接通过 `CUDA_VISIBLE_DEVICE` 来配置的;因此,如果用户想要实现使用特定 gpu + 设备的目的,可以通过自己设置环境变量实现(例如 os.environ["CUDA_VISIBLE_DEVICE"] 来实现,我们会通过一定的手段将其保存起来); + 剩下的操作和情况 A 类似; + + -> 情况 C:注意这种情况我们限定了用户是通过 paddle.distributed.launch 拉起,并且 ddp 的通信组也是由自己建立。这时基本上所有的 + 与操作相关的操作都应当由用户自己完成,包括迁移模型到对应 gpu 上以及将模型用 `DataParallel` 包裹等。 + (2)如果 `driver.setup` 函数在脚本中会被调用两次及以上(意味着用户的启动脚本初始化了两个及以上的 trainer/evaluator)时: + 注意这种情况下我们是会保证前后两个 trainer/evaluator 使用的 `PaddleFleetDriver` 以及其初始化方式的一致性,换句话说,如果 trainer1 + 检测到的启动方式是 '情况 A',那么我们会保证 trainer2 检测到的启动方式同样是 '情况A'(即使这需要一些额外的处理);因此这里我们主要讨论 + 我们是通过怎样的操作来保证 trainer2/3/... 检测到的启动方式是和 trainer1 一致的;简单来说,我们是通过使用环境变量来标记每一种不同的 + 启动方式来实现这一点的: + 我们会使用 `FASTNLP_DISTRIBUTED_CHECK` 来标记 '情况 A',使用 `fastnlp_torch_launch_not_ddp` 来标记 '情况 B',意味着我们在 + 使用 '情况 A' 来启动 `PaddleFleetDriver` 时,我们会将 `FASTNLP_DISTRIBUTED_CHECK` 这一字符串注入到环境变量中,而 '情况 B' 时则 + 会将 `fastnlp_torch_launch_not_ddp` 这一字符串注入到环境变量中。因此在 trainer2 的 `PaddleFleetDriver` 的初始化和 setup 过程中, + 如果检测到这些特殊的环境变量,我们就会将启动方式变更为其对应的启动方式,即使其它的参数特征属于另外的启动方式。 + + Part 2:对应的代码细节: + 1. 如何判断当前的各进程之间的通信组已经被建立(fleet 已经被初始化); + parallel_helper._is_parallel_ctx_initialized(); + 2. 如何判断不同的进程是否是由 `python -m paddle.distributed.launch` 拉起还是由我们的 `FleetLauncher.launch()` + 函数拉起; + 我们会在用户脚本 `import fastNLP` 的时候检测当前的环境变量中是否有 'PADDLE_RANK_IN_NODE'、'PADDLE_TRAINER_ID' + 以及没有 `FASTNLP_DISTRIBUTED_CHECK`, + 如果满足条件,则我们会向环境变量中注入特殊的值 'FASTNLP_BACKEND_LAUNCH' 来标记用户是否使用了 `python -m paddle.distributed.launch` + 来拉起多个进程; + 3. 整体的处理判断流程: + ___________________________________ + |进入 PaddleFleetDriver 的 __init__ 函数| + ——————————————————————————————————— + ↓ + ___________________________________________________ + | 判断不同的进程是否是由 paddle.distributed.launch 拉起 | + |(或者我们自己的 FleetLauncher 函数拉起) | --------------> + ———————————————————————————————————————————————————  | + ↓ 是由 paddle.distributed.launch 拉起 | 我们自己的 FleetLauncher 函数拉起多个进程 +  _____________________________            |  + ←←←←← | 检测用户是否自己初始化了 fleet |              | + ↓ —————————————————————————————                  ↓ + ↓ ↓ 是 ________ + ↓ ______ | 情况 A | + ↓ 否 |情况 C| ————————— + ↓ ——————— + ↓ + ↓ ______ + ↓ -----------> |情况 B| +   ——————— + 4. 为了完成全部的建立分布式所需要的操作,三种情况都需要做的事情,以及每件事情的职责归属: + + 情况 A | 情况 B | 情况 C + ________________________________________________________________________________________________________ + 配置 fleet 所 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch + 需要的环境变量 | | | + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + 开启多个进程 | FleetLauncher.launch | paddle.distributed.launch| paddle.distributed.launch + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + 调用 fleet.init函数 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | 用户自己调用 + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + 设置 PaddleFleetDriver | | | + 的 world_size 和 | PaddleFleetDriver.setup | PaddleFleetDriver.setup | PaddleFleetDriver.setup + global_rank 属性 | | | + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + + Part 3:其它的处理细节: + 1. 环境变量; + fastNLP 的 `PaddleFleetDriver` 运行时所需要的环境变量分为两种,一种是 paddle fleet 运行所需要的环境变量;另一种是 fastNLP 自己 + 的环境变量。前者的配置情况如上表所示;而后者中的大多数环境变量则是在用户 import fastNLP 时就设置好了; + 2. parallel_device, model_device 和 data_device 的关系; + parallel_device 为 `PaddleFleetDriver` 的参数,model_device 和 data_device 都为 driver 的属性; + 其中 data_device 仅当情况 C 时由用户自己指定;如果其不为 None,那么在模型 forward 的时候,我们就会将数据迁移到 data_device 上; + model_device 永远都为单独的一个 torch.device; + + 情况 A | 情况 B | 情况 C + ________________________________________________________________________________________________________ + parallel_device | 由用户传入trainer的参数 | | + | device 决定,必须是一个list, | 为 CUDA_VISIBLE_DEVICES | 为 CUDA_VISIBLE_DEVICES + | 其中每一个对象都是 int | | + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + model_device | parallel_device[local_rank] | parallel_device | None + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + data_device | model_device | model_device | 由用户传入 trainer 的参数 + | | | data_device 决定 + ———————————————————————————————————————————————————————————————————————————————————————————————————————— + + 3. _DDPWrappingModel 的作用; + 因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DataParallel` 的forward 函数来帮助 + 我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DataParallel` 的 forward 方法, + 然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 forward 函数,还是 + `train_step`、`evaluate_step`、`test_step` 方法。 + + 4. 当某一个进程出现 exception 后,`PaddleFleetDriver` 的处理; + + 不管是什么情况,`PaddleFleetDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, + driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; """ super(PaddleFleetDriver, self).__init__(model, fp16=fp16, **kwargs) @@ -78,6 +193,7 @@ class PaddleFleetDriver(PaddleDriver): "when your value of parameter `device` is `None` in your `Trainer` instance.") # 如果用户自己初始化了 paddle 的分布式训练那么一定是通过 launch 拉起的 + # 这个参数会在 initialize_paddle_drvier 中设置。 self.is_pull_by_paddle_run = is_pull_by_paddle_run self.parallel_device = parallel_device # 在初始化时,如果发现 is_pull_by_paddle_run ,则将 parallel_device 设置成当前进程的gpu @@ -98,7 +214,7 @@ class PaddleFleetDriver(PaddleDriver): self.outside_fleet = True # 用户只有将模型上传到对应机器上后才能用 DataParallel 包裹,因此如果用户在外面初始化了 Fleet,那么在 PaddleFleetDriver 中 - # 我们就直接将 model_device 置为 None; + # 我们就直接将 model_device 置为 None; self._model_device = None # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; @@ -119,9 +235,12 @@ class PaddleFleetDriver(PaddleDriver): self.world_size = None self.global_rank = 0 + self.gloo_rendezvous_dir = None + # 分布式环境的其它参数设置 self._fleet_kwargs = kwargs.get("paddle_fleet_kwargs", {}) check_user_specific_params(self._fleet_kwargs, DataParallel.__init__) + # fleet.init 中对于分布式策略的设置,详情可以参考 PaddlePaddle 的官方文档 self.strategy = self._fleet_kwargs.get("strategy", fleet.DistributedStrategy()) self.is_collective = self._fleet_kwargs.get("is_collective", True) if not self.is_collective: @@ -145,7 +264,10 @@ class PaddleFleetDriver(PaddleDriver): def setup(self): """ - 在主进程拉起其它子进程,将主进程作为rank 0 + 根据不同的情况进行不同的设置。 + 1、如果是通过 paddle.distributed.launch 方法启动时,则根据已经设置好的环境获取 + 分布式的属性。 + 2、否则,调用 FleetLauncher 类启动子进程 """ if self._has_setup: return @@ -174,7 +296,7 @@ class PaddleFleetDriver(PaddleDriver): # 此时 parallel_helper._is_parallel_ctx_initialized() 一定为 False # parallel_device 是 list, if not parallel_helper._is_parallel_ctx_initialized(): - # 没有初始化分布式环境,且是主进程 + # 拉起子进程并设置相应的属性 self.init_fleet_and_set() # 用户在这个 trainer 前面又初始化了一个 trainer,并且使用的是 PaddleFleetDriver; else: @@ -216,12 +338,13 @@ class PaddleFleetDriver(PaddleDriver): # 是 rank0 的话,则拉起其它子进程 launcher = FleetLauncher(self.parallel_device, self.output_from_new_proc) launcher.launch() + self.gloo_rendezvous_dir = launcher.gloo_rendezvous_dir # 设置参数和初始化分布式环境 fleet.init(self.role_maker, self.is_collective, self.strategy) self.global_rank = int(os.getenv("PADDLE_TRAINER_ID")) self.world_size = int(os.getenv("PADDLE_TRAINERS_NUM")) - # 正常情况下不会Assert出问题,但还是保险一下 + # 正常情况下不会 Assert 出问题,但还是保险一下 assert self.global_rank is not None assert self.world_size is not None assert self.world_size == len(self.parallel_device) @@ -235,10 +358,19 @@ class PaddleFleetDriver(PaddleDriver): self.global_rank = paddledist.get_rank() def barrier(self): + r""" + 用于在多进程工作时同步各进程的工作进度,运行快的进程运行到这里会等待运行慢的进程,只有所有进程都运行到此函数时,所有的进程才会继续运行; + 仅在多分布式训练场景中有使用。 + + 注意,该函数的行为会受到 FASTNLP_NO_SYNC 的影响。仅当 FASTNLP_NO_SYNC 在 os.environ 中不存在,或小于 1 时才真的执行 barrier 。 + """ if int(os.environ.get(FASTNLP_NO_SYNC, 0)) < 1: # 当 FASTNLP_NO_SYNC 小于 1 时实际执行 paddledist.barrier() def configure_fleet(self): + """ + 将模型用 DataParallel 和自定义的类型包裹起来 + """ if not self._has_fleetwrapped and not isinstance(self.model, DataParallel): self.model = DataParallel( _FleetWrappingModel(self.model), @@ -247,8 +379,14 @@ class PaddleFleetDriver(PaddleDriver): self._has_fleetwrapped = True def on_exception(self): - if os.path.exists(self.gloo_rendezvous_dir): - shutil.rmtree(self.gloo_rendezvous_dir) + """ + 该函数用于在训练或者预测过程中出现错误时正确地关掉其它的进程,这一点是通过在多进程 driver 调用 open_subprocess 的时候将每一个进程 + 的 pid 记录下来,然后在出现错误后,由出现错误的进程手动地将其它进程 kill 掉; + + 因此,每一个多进程 driver 如果想要该函数能够正确地执行,其需要在自己的 open_subprocess(开启多进程的函数)中正确地记录每一个进程的 + pid 的信息; + """ + rank_zero_rm(self.gloo_rendezvous_dir) super().on_exception() @property @@ -282,6 +420,17 @@ class PaddleFleetDriver(PaddleDriver): return self.model_device def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + """ + 通过调用 `fn` 来实现训练时的前向传播过程; + 注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 + 函数; + + :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; + :param fn: 调用该函数进行一次计算。 + :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call + 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; + :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); + """ if self._has_fleetwrapped: return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, wo_auto_param_call=self.wo_auto_param_call) @@ -292,6 +441,27 @@ class PaddleFleetDriver(PaddleDriver): return fn(batch) def get_model_call_fn(self, fn: str) -> Tuple: + """ + 该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; + 该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; + + 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; + 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 + `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 + `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 + `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; + + 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: + 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` + 函数,然后给出 warning; + 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; + 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 + forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 + 可能需要额外标记最初传入 driver 的模型是哪种形式的; + + :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; + :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; + """ model = self.unwrap_model() if self._has_fleetwrapped: if hasattr(model, fn): @@ -316,7 +486,25 @@ class PaddleFleetDriver(PaddleDriver): return self.model, model.forward def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, RandomBatchSampler]], - reproducible: bool = False, sampler_or_batch_sampler=None): + reproducible: bool = False): + r""" + 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 + + :param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 + :param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader + 切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 + 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 + 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; + 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; + 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; + 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; + + :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 + 可以可以加载。 + :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, + 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 + dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 + """ # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ "FastNLP does not support `IteratorDataset` now." @@ -429,10 +617,7 @@ class PaddleFleetDriver(PaddleDriver): @staticmethod def _check_optimizer_legality(optimizers): - """ - paddle存在设置分布式optimizers的函数,返回值为fleet.meta_optimizers.HybridParallelOptimizer - 重写是为了防止单卡下也传入了分布式的优化器 - """ + # paddle 存在设置分布式 optimizers 的函数,返回值为 fleet.meta_optimizers.HybridParallelOptimizer DistribuedOptimizer = fleet.meta_optimizers.HybridParallelOptimizer for each_optimizer in optimizers: if not isinstance(each_optimizer, (Optimizer, DistribuedOptimizer)): diff --git a/fastNLP/core/drivers/paddle_driver/fleet_launcher.py b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py index 66eccfca..471679a7 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet_launcher.py +++ b/fastNLP/core/drivers/paddle_driver/fleet_launcher.py @@ -20,7 +20,7 @@ from .utils import ( # 记录各个进程信息 class SubTrainer(object): """ - 和fastnlp的Triainer没有关系,仅用于统计节点内不同训练的一些信息 + 用于统计节点内不同训练进程的信息,和 fastnlp 的 Triainer 没有关系 """ def __init__(self, endpoint=None, rank=None): self.devices = [] @@ -30,8 +30,8 @@ class SubTrainer(object): class FleetLauncher: """ - 复原了 paddle 的 launch_collective 函数,将其集成到一个类里 - 仅支持单机多卡的启动 + 复原了 paddle 的 launch_collective 函数,将其简化后集成到一个类里 + 仅支持每个机器单卡的情况。 """ def __init__( self, @@ -45,17 +45,26 @@ class FleetLauncher: self.setup() def setup(self): - + """ + 进行初始化设置的函数,根据传入的设备找到分布式训练使用的端口号 + """ self.set_endpoints() self.sub_trainers = self.get_process_info() - def launch(self) -> int: + def launch(self): + """ + 用于启动分布式进程。 + 首先设置 PaddlePaddle 分布式训练需要设置的环境变量,然后建立新的子进程 + """ # 设置环境变量 self.global_envs = self.get_global_env() self.open_subprocess() reset_seed() def open_subprocess(self): + """ + 从 sub_trainers 中获取各个 rank 的信息,并且使用 subprocess.Popen 建立新的子进程。 + """ if __main__.__spec__ is None: # Script called as `python a/b/c.py` @@ -77,6 +86,7 @@ class FleetLauncher: current_env = copy.copy(self.global_envs) for idx, t in enumerate(self.sub_trainers): + # 根据不同的 rank 设置环境变量 proc_env = { # global_rank "PADDLE_TRAINER_ID": f"{t.rank}", @@ -108,6 +118,14 @@ class FleetLauncher: os.environ.update(current_env) def get_global_env(self): + """ + 设置分布式训练需要的全局变量,包括: + 1、GLOO 相关的设置 + 2、`PADDLE_TRAINERS_NUM` :所有的进程数目 + 3、`PADDLE_TRAINER_ENDPOINTS` :使用的所有地址及其端口 + 4、`PADDLE_WORLD_DEVICE_IDS` :使用的所有设备 + 5、FASTNLP_DISTRIBUTED_CHECK:通过 fastNLP 建立子进程的标志,保存分布式训练使用的设备 + """ global_envs = copy.copy(os.environ.copy()) self.gloo_rendezvous_dir = tempfile.mkdtemp() @@ -137,7 +155,7 @@ class FleetLauncher: def set_endpoints(self): """ - Reference to `get_cluster_from_args` + 寻找用户设置的端口或是空闲端口用于分布式训练,参考了 PaddlePaddle 中的 `get_cluster_from_args` 函数 """ self.node_ip = "127.0.0.1" @@ -157,7 +175,7 @@ class FleetLauncher: def get_process_info(self): """ - Reference to `get_cluster` + 获取各个训练进程的设备、rank 和端口信息,参考 PaddlePaddle 的 `get_cluster` 函数。 """ sub_trainers = [] assert len(self.endpoints) >= len( diff --git a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py index eac2d4a4..9a9d4198 100644 --- a/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/initialize_paddle_driver.py @@ -17,14 +17,16 @@ def initialize_paddle_driver(driver: str, device: Optional[Union[str, int, List[ model: paddle.nn.Layer, **kwargs) -> PaddleDriver: r""" 用来根据参数 `driver` 和 `device` 来确定并且初始化一个具体的 `Driver` 实例然后返回回去; - 注意如果输入的 `device` 如果和 `driver` 对应不上就直接报错; + 1、如果检测到当前进程为用户通过 `python -m paddle.distributed.launch xxx.py` 方式拉起的,则将 + 设备自动设置为用户指定的设备(由于我们在引入 fastNLP 进行了特殊的设置,因此可以通过 `CUDA_VISIBLE_DEVICES` 获取) + 2、如果检测到输入的 `driver` 是 `paddle` 但 `device` 包含了多个设备,那么我们会给出警告并且自动返回多卡的 Driver + 3、如果检测到输入的 `driver` 是 `fleet` 但 `device` 仅有一个设备,那么我们会给出警告但仍旧返回多卡的 Driver :param driver: 该参数的值应为以下之一:["paddle", "fleet"]; :param device: 该参数的格式与 `Trainer` 对参数 `device` 的要求一致; :param model: 训练或者评测的具体的模型; - :return: 返回一个元组,元组的第一个值是具体的基于 pytorch 的 `Driver` 实例,元组的第二个值是该 driver 的名字(用于检测一个脚本中 - 先后 driver 的次序的正确问题); + :return: 返回构造的 `Driver` 实例。 """ if is_in_paddle_launch_dist(): if device is not None: diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index e47360ee..f140ad69 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -31,6 +31,9 @@ __all__ = [ ] class PaddleSingleDriver(PaddleDriver): + """ + 支持 paddle cpu 或单卡 gpu 训练的 driver + """ def __init__(self, model, device: Union[str, int], fp16: Optional[bool] = False, **kwargs): if isinstance(model, DataParallel): raise ValueError("`paddle.DataParallel` is not supported in `PaddleSingleDriver`") @@ -59,18 +62,53 @@ class PaddleSingleDriver(PaddleDriver): self.world_size = 1 def setup(self): + r""" + 该函数用来初始化训练环境,用于设置当前训练的设备,并将模型迁移到对应设备上。 + """ device = self.model_device device = get_device_from_visible(device, output_type=str) paddle.device.set_device(device) self.model.to(device) def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + """ + 通过调用 `fn` 来实现训练时的前向传播过程; + 注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 + 函数; + + :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; + :param fn: 调用该函数进行一次计算。 + :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call + 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; + :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); + """ if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(fn, batch, signature_fn=signature_fn) else: return fn(batch) def get_model_call_fn(self, fn: str) -> Tuple: + """ + 该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; + 该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; + + 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; + 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 + `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 + `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 + `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; + + 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: + 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` + 函数,然后给出 warning; + 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; + 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 + forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 + 可能需要额外标记最初传入 driver 的模型是哪种形式的; + + :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; + :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; + """ if hasattr(self.model, fn): fn = getattr(self.model, fn) if not callable(fn): @@ -95,6 +133,24 @@ class PaddleSingleDriver(PaddleDriver): def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): + r""" + 根据输入的 dataloader 得到一个 支持分布式 (distributed) 与 可复现的 (reproducible) 的 dataloader。 + + :param dataloader: 根据 dataloader 设置其对应的分布式版本以及可复现版本 + :param dist: 应当为一个字符串,其值应当为以下之一:[None, "dist", "unrepeatdist"];为 None 时,表示不需要考虑当前 dataloader + 切换为分布式状态;为 'dist' 时,表示该 dataloader 应该保证每个 gpu 上返回的 batch 的数量是一样多的,允许出现少量 sample ,在 + 不同 gpu 上出现重复;为 'unrepeatdist' 时,表示该 dataloader 应该保证所有 gpu 上迭代出来的数据合并起来应该刚好等于原始的 + 数据,允许不同 gpu 上 batch 的数量不一致。其中 trainer 中 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "dist"; + 否则为 None ,evaluator 中的 kwargs 的参数 `use_dist_sampler` 为 True 时,该值为 "unrepeatdist",否则为 None; + 注意当 dist 为 ReproducibleSampler, ReproducibleBatchSampler 时,是断点重训加载时 driver.load 函数在调用; + 当 dist 为 str 或者 None 时,是 trainer 在初始化时调用该函数; + + :param reproducible: 如果为 False ,不要做任何考虑;如果为 True ,需要保证返回的 dataloader 可以保存当前的迭代状态,使得 + 可以可以加载。 + :return: 应当返回一个被替换 sampler 后的新的 dataloader 对象 (注意此处一定需要返回一个新的 dataloader 对象) ;此外, + 如果传入的 dataloader 中是 ReproducibleSampler 或者 ReproducibleBatchSampler 需要重新初始化一个放入返回的 + dataloader 中。如果 dist 为空,且 reproducible 为 False,可直接返回原对象。 + """ # 暂时不支持iterableDataset assert dataloader.dataset_kind != _DatasetKind.ITER, \ diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 48598a34..6cd7b252 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -69,7 +69,6 @@ def paddle_seed_everything(seed: Optional[int] = None, workers: bool = False) -> os.environ[FASTNLP_SEED_WORKERS] = f"{int(workers)}" return seed - def reset_seed() -> None: """ fleet 会开启多个进程,因此当用户在脚本中指定 seed_everything 时,在开启多个脚本后,会在每个脚本内重新 @@ -80,16 +79,10 @@ def reset_seed() -> None: if seed is not None: paddle_seed_everything(int(seed), workers=bool(int(workers))) -class ForwardState(IntEnum): - TRAIN = 0 - VALIDATE = 1 - TEST = 2 - PREDICT = 3 - class _FleetWrappingModel(Layer): """ - 参考_DDPWrappingModel,paddle的分布式训练也需要用paddle.nn.DataParallel进行包装,采用和 - pytorch相似的处理方式 + 参考 _DDPWrappingModel , paddle 的分布式训练也需要用 paddle.nn.DataParallel 进行包装,采用和 + pytorch 相似的处理方式 """ def __init__(self, model: 'nn.Layer'): super(_FleetWrappingModel, self).__init__() @@ -109,7 +102,6 @@ class _FleetWrappingModel(Layer): class DummyGradScaler: """ 用于仿造的GradScaler对象,防止重复写大量的if判断 - """ def __init__(self, *args, **kwargs): pass @@ -152,6 +144,9 @@ def _build_fp16_env(dummy=False): return auto_cast, GradScaler def find_free_ports(num): + """ + 在空闲的端口中找到 num 个端口 + """ def __free_port(): with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: s.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, @@ -178,18 +173,11 @@ def find_free_ports(num): return None -def get_host_name_ip(): - try: - host_name = socket.gethostname() - host_ip = socket.gethostbyname(host_name) - return host_name, host_ip - except: - return None - def get_device_from_visible(device: Union[str, int], output_type=int): """ 在有 CUDA_VISIBLE_DEVICES 的情况下,获取对应的设备。 如 CUDA_VISIBLE_DEVICES=2,3 ,device=3 ,则返回1。 + :param device: 未转化的设备名 :param output_type: 返回值的类型 :return: 转化后的设备id diff --git a/fastNLP/core/utils/paddle_utils.py b/fastNLP/core/utils/paddle_utils.py index 1f461e0f..e65cd735 100644 --- a/fastNLP/core/utils/paddle_utils.py +++ b/fastNLP/core/utils/paddle_utils.py @@ -22,6 +22,13 @@ from .utils import apply_to_collection def paddle_to(data, device: Union[str, int]): + """ + 将 `data` 迁移到指定的 `device` 上 + + :param data: 要迁移的张量 + :param device: 目标设备,可以是 `str` 或 `int` + :return: 迁移后的张量 + """ if device == "cpu": return data.cpu() @@ -31,6 +38,9 @@ def paddle_to(data, device: Union[str, int]): def get_paddle_gpu_str(device: Union[str, int]): """ 获得 `gpu:x` 类型的设备名 + + :param device: 设备编号或设备名 + :return: 返回对应的 `gpu:x` 格式的设备名 """ if isinstance(device, str): return device.replace("cuda", "gpu") @@ -38,7 +48,10 @@ def get_paddle_gpu_str(device: Union[str, int]): def get_paddle_device_id(device: Union[str, int]): """ - 获得 gpu 的设备id,注意不要传入 `cpu` 。 + 获得 gpu 的设备id + + :param: device: 设备编号或设备名 + :return: 设备对应的编号 """ if isinstance(device, int): return device From d04f49a835a023d10377284172d25843b607a68c Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 26 Apr 2022 05:33:17 +0000 Subject: [PATCH 09/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9BucketedBatchSampler=20?= =?UTF-8?q?batch=5Fidx=5Fin=5Fepoch=20=E7=9A=84=E8=AE=A1=E7=AE=97=E6=96=B9?= =?UTF-8?q?=E5=BC=8F=EF=BC=8C=E4=BD=BF=E5=85=B6=E5=9C=A8=E5=88=86=E5=B8=83?= =?UTF-8?q?=E5=BC=8F=E6=9D=A1=E4=BB=B6=E4=B8=8B=E5=8F=AF=E4=BB=A5=E6=AD=A3?= =?UTF-8?q?=E7=A1=AE=E5=9C=B0=E5=8F=8D=E5=BA=94=E8=BF=AD=E4=BB=A3=E6=AC=A1?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/samplers/reproducible_batch_sampler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/fastNLP/core/samplers/reproducible_batch_sampler.py b/fastNLP/core/samplers/reproducible_batch_sampler.py index e8acc645..2bbf409f 100644 --- a/fastNLP/core/samplers/reproducible_batch_sampler.py +++ b/fastNLP/core/samplers/reproducible_batch_sampler.py @@ -416,7 +416,7 @@ class BucketedBatchSampler(ReproducibleBatchSampler): @property def batch_idx_in_epoch(self): if self.drop_last: - return len(self.dataset) // self.batch_size - (len(self.dataset) - self.num_consumed_samples) // self.batch_size + return len(self.dataset) // self.num_replicas // self.batch_size - self.num_left_samples // self.batch_size else: - return (len(self.dataset) + self.batch_size - 1) // self.batch_size - \ - (len(self.dataset) - self.num_consumed_samples + self.batch_size - 1) // self.batch_size \ No newline at end of file + return (len(self.dataset) // self.num_replicas + self.batch_size - 1) // self.batch_size - \ + (self.num_left_samples + self.batch_size - 1) // self.batch_size \ No newline at end of file From df109316e5e0d1fae8e52df7da336ddfc1f1b508 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 26 Apr 2022 05:33:37 +0000 Subject: [PATCH 10/16] small --- tests/core/drivers/paddle_driver/test_fleet.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/core/drivers/paddle_driver/test_fleet.py b/tests/core/drivers/paddle_driver/test_fleet.py index 52739f53..34c80888 100644 --- a/tests/core/drivers/paddle_driver/test_fleet.py +++ b/tests/core/drivers/paddle_driver/test_fleet.py @@ -527,7 +527,7 @@ class TestSaveLoad: @classmethod def setup_class(cls): # 不在这里 setup 的话会报错 - cls.driver = generate_driver(10, 10) + cls.driver = generate_driver(10, 10, device=[0,1]) def setup_method(self): self.dataset = PaddleRandomMaxDataset(20, 10) @@ -633,7 +633,7 @@ class TestSaveLoad: batch_sampler=BucketedBatchSampler( self.dataset, length=[10 for i in range(len(self.dataset))], - batch_size=4, + batch_size=2, ) ) dataloader.batch_sampler.set_distributed( From f319b5bce149ba52e23e5b10c65be4b296585550 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 26 Apr 2022 05:49:32 +0000 Subject: [PATCH 11/16] =?UTF-8?q?torch=20ddp=20=E7=9A=84=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/core/drivers/torch_driver/test_ddp.py | 788 ++++++++++++++++++++ 1 file changed, 788 insertions(+) create mode 100644 tests/core/drivers/torch_driver/test_ddp.py diff --git a/tests/core/drivers/torch_driver/test_ddp.py b/tests/core/drivers/torch_driver/test_ddp.py new file mode 100644 index 00000000..0e91fe77 --- /dev/null +++ b/tests/core/drivers/torch_driver/test_ddp.py @@ -0,0 +1,788 @@ +import pytest +import os +from pathlib import Path + +os.environ["FASTNLP_BACKEND"] = "torch" +from fastNLP.core.drivers.torch_driver.ddp import TorchDDPDriver +from fastNLP.core.samplers import ( + RandomSampler, + UnrepeatedSampler, + BucketedBatchSampler, + UnrepeatedRandomSampler, + UnrepeatedSequentialSampler, +) +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.datasets.torch_data import TorchNormalDataset, TorchArgMaxDataset +from tests.helpers.utils import magic_argv_env_context +from fastNLP.core import rank_zero_rm + +import torch +import torch.distributed as dist +from torch.utils.data import DataLoader, BatchSampler + +def generate_driver(num_labels, feature_dimension, device=[0,1], fp16=False, output_from_new_proc="only_error"): + torch_model = TorchNormalModel_Classification_1(num_labels, feature_dimension) + torch_opt = torch.optim.Adam(params=torch_model.parameters(), lr=0.01) + device = [torch.device(i) for i in device] + driver = TorchDDPDriver( + model=torch_model, + parallel_device=device, + fp16=fp16, + output_from_new_proc=output_from_new_proc + ) + driver.set_optimizers(torch_opt) + driver.setup() + + return driver + +def dataloader_with_bucketedbatchsampler(dataset, length, batch_size, shuffle, drop_last): + """ + 建立一个 batch_sampler 为 BucketedBatchSampler 的 dataloader + """ + dataloader = DataLoader( + dataset=dataset, + batch_sampler=BucketedBatchSampler( + dataset, + length, + batch_size, + shuffle=shuffle, + drop_last=drop_last, + ), + ) + + return dataloader + +def dataloader_with_randomsampler(dataset, batch_size, shuffle, drop_last, seed=0, unrepeated=False): + """ + 建立一个 sampler 为 RandomSampler 的 dataloader + """ + if unrepeated: + sampler = UnrepeatedRandomSampler(dataset, shuffle, seed) + else: + sampler = RandomSampler(dataset, shuffle, seed=seed) + dataloader = DataLoader( + dataset, + sampler=sampler, + drop_last=drop_last, + batch_size=batch_size + ) + return dataloader + +############################################################################ +# +# 测试 TorchDDPDriver 的一些函数 +# +############################################################################ + +class TestDDPDriverFunction: + """ + 测试 TorchDDPDriver 一些简单函数的测试类,基本都是测试能否运行、是否存在 import 错误等问题 + """ + + @classmethod + def setup_class(cls): + cls.driver = generate_driver(10, 10) + + @magic_argv_env_context + def test_multi_drivers(self): + """ + 测试使用了多个 TorchDDPDriver 的情况。 + """ + + driver2 = generate_driver(20, 10) + + with pytest.raises(RuntimeError): + # 设备设置不同,应该报错 + driver3 = generate_driver(20, 3, device=[0,1,2]) + assert False + dist.barrier() + + @magic_argv_env_context + def test_move_data_to_device(self): + """ + 这个函数仅调用了torch_move_data_to_device,测试例在tests/core/utils/test_torch_utils.py中 + 就不重复测试了 + """ + self.driver.move_data_to_device(torch.rand((32, 64))) + + dist.barrier() + + @magic_argv_env_context + def test_is_distributed(self): + """ + 测试 is_distributed 函数 + """ + assert self.driver.is_distributed() == True + dist.barrier() + + @magic_argv_env_context + def test_get_no_sync_context(self): + """ + 测试 get_no_sync_context 函数 + """ + res = self.driver.get_model_no_sync_context() + dist.barrier() + + @magic_argv_env_context + def test_is_global_zero(self): + """ + 测试 is_global_zero 函数 + """ + self.driver.is_global_zero() + dist.barrier() + + @magic_argv_env_context + def test_unwrap_model(self): + """ + 测试 unwrap_model 函数 + """ + self.driver.unwrap_model() + dist.barrier() + + @magic_argv_env_context + def test_get_local_rank(self): + """ + 测试 get_local_rank 函数 + """ + self.driver.get_local_rank() + dist.barrier() + + @magic_argv_env_context + def test_all_gather(self): + """ + 测试 all_gather 函数 + 详细的测试在 test_dist_utils.py 中完成 + """ + obj = { + "rank": self.driver.global_rank + } + obj_list = self.driver.all_gather(obj, group=None) + for i, res in enumerate(obj_list): + assert res["rank"] == i + + @magic_argv_env_context + @pytest.mark.parametrize("src_rank", ([0, 1])) + def test_broadcast_object(self, src_rank): + """ + 测试 broadcast_object 函数 + 详细的函数在 test_dist_utils.py 中完成 + """ + if self.driver.global_rank == src_rank: + obj = { + "rank": self.driver.global_rank + } + else: + obj = None + res = self.driver.broadcast_object(obj, src=src_rank) + assert res["rank"] == src_rank + +############################################################################ +# +# 测试 set_dist_repro_dataloader 函数 +# +############################################################################ + +class TestSetDistReproDataloader: + + @classmethod + def setup_class(cls): + cls.device = [0, 1] + cls.driver = generate_driver(10, 10, device=cls.device) + + def setup_method(self): + self.dataset = TorchNormalDataset(40) + + """ + 传入的 `dist` 参数为具体的 ReproducibleSampler 或 ReproducibleBatchSampler 的情况 + 此时对应 driver.load 中的情况 + """ + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 BucketedBatchSampler 时的表现 + 此时应该将 batch_sampler 替换为 dist 对应的 BucketedBatchSampler + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + batch_sampler = BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, batch_sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler is batch_sampler + self.check_distributed_sampler(replaced_loader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 RandomSampler 时的表现 + 此时应该将 batch_sampler.sampler 替换为 dist 对应的 RandomSampler + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=not shuffle) + sampler = RandomSampler(self.dataset, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, sampler, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.sampler is sampler + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + dist.barrier() + + """ + 传入的参数 `dist` 为 None 的情况,这种情况出现在 trainer 和 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` + 参数为 False。此时函数会根据 `reproducible` 的设置进行不同的处理。 + 当 `reproducible` 为 False 时,需要根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定 + 是否重新实例化 dataloader + """ + + @magic_argv_env_context + def test_with_dist_none_reproducible_true(self): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 True 时的表现 + 当用户在 driver 之外初始化了分布式环境时,fastnlp 不支持进行断点重训,此时应该报错 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=True) + with pytest.raises(RuntimeError): + # 应当抛出 RuntimeError + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, True) + + dist.barrier() + + @magic_argv_env_context + # @pytest.mark.parametrize("shuffle", ([True, False])) + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 BucketedBatchSampler + 时的表现 + 此时传入的 dataloader 的 batch_sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 batch_sampler + 和原 dataloader 相同 + """ + dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank, + pad=True + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + self.check_distributed_sampler(dataloader.batch_sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 有 RandomSampler 时的表现 + 此时传入的 dataloader 的 batch_sampler.sampler 应该已经执行了 set_distributed,产生一个新的 dataloader,其 + batch_sampler.sampler 和原 dataloader 相同 + """ + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank + ) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.drop_last == False + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + self.check_set_dist_repro_dataloader(dataloader, replaced_loader, shuffle) + + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_none_reproducible_false_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 None、reproducible 为 False 、dataloader 为一般情况时的表现 + 此时直接返回原来的 dataloader,不做任何处理。 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, None, False) + + assert replaced_loader is dataloader + dist.barrier() + + """ + 传入的参数 `dist` 为 'dist' 的情况,这种情况出现在 trainer 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 batch_sampler 或 sampler 是否为 Reproducible 来决定如何重新实例化 dataloader + """ + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_reproducible_batch_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler 为 ReproducibleBatchSampler + 的表现 + 此时应该返回一个新的 dataloader,其batch_sampler 和原 dataloader 相同,且应该正确地设置了分布式相关的属性 + """ + dataloader = DataLoader( + dataset=self.dataset, + batch_sampler=BucketedBatchSampler(self.dataset, self.dataset._data, batch_size=4, shuffle=shuffle) + ) + dataloader = dataloader_with_bucketedbatchsampler(self.dataset, self.dataset._data, 4, shuffle, False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler) + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + 此时应该返回一个新的 dataloader,其 batch_sampler.sampler 和原 dataloader 相同,且应该正确地设置了分布式相关 + 的属性 + """ + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_dist_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'dist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,并替换其 batch_sampler.sampler 为 RandomSampler,且应该正确设置了分布式相关 + 的属性 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "dist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.batch_size == dataloader.batch_sampler.batch_size + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + """ + 传入的参数 `dist` 为 'unrepeatdist' 的情况,这种情况出现在 evaluator 的初始化过程中,用户指定了 `use_dist_sampler` 参数 + 为 True。此时函数会根据 dataloader 的 sampler 是否为 Unrepeated 和 Reproducible 来决定如何重新实例化 dataloader + """ + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_reproducible_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 ReproducibleSampler + 的表现 + 此时应该返回一个新的 dataloader,且将原来的 Sampler 替换为 UnrepeatedRandomSampler,且正确地设置了分布式相关 + 的属性 + """ + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=False) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.batch_sampler.sampler.shuffle == shuffle + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_unrepreated_sampler(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader.batch_sampler.sampler 为 UnrepeatedSampler + 的表现 + 此时应该返回一个新的 dataloader,且重新实例化了原来的 Sampler + """ + dataloader = dataloader_with_randomsampler(self.dataset, 4, shuffle, False, unrepeated=True) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedRandomSampler) + assert not (replaced_loader.batch_sampler.sampler is dataloader.batch_sampler.sampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + @magic_argv_env_context + @pytest.mark.parametrize("shuffle", ([True, False])) + def test_with_dist_unrepeat_dataloader_normal(self, shuffle): + """ + 测试 set_dist_repro_dataloader 中 dist 为 'unrepeatdist'、dataloader 为一般情况的表现 + 此时应该返回一个新的 dataloader,且将 sampler 替换为 UnrepeatedSequentialSampler,并正确地设置了分布式相关 + 的属性 + """ + dataloader = DataLoader(self.dataset, batch_size=4, shuffle=shuffle) + replaced_loader = self.driver.set_dist_repro_dataloader(dataloader, "unrepeatdist", False) + + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler, BatchSampler) + assert not (replaced_loader.batch_sampler is dataloader.batch_sampler) + assert isinstance(replaced_loader.batch_sampler.sampler, UnrepeatedSequentialSampler) + assert replaced_loader.batch_sampler.batch_size == 4 + assert replaced_loader.drop_last == dataloader.drop_last + self.check_distributed_sampler(replaced_loader.batch_sampler.sampler) + dist.barrier() + + def check_distributed_sampler(self, sampler): + """ + 测试替换得到的 sampler 或 batch_sampler 的分布式设置是否正确 + """ + assert sampler.num_replicas == dist.get_world_size() + assert sampler.rank == dist.get_rank() + if not isinstance(sampler, UnrepeatedSampler): + assert sampler.pad == True + + def check_set_dist_repro_dataloader(self, dataloader, replaced_loader, shuffle): + """ + 测试多卡下 set_dist_repro_dataloader 函数的执行结果是否正确 + """ + # 迭代两个 batch + num_replicas = len(self.device) + num_consumed_batches = 2 + already_seen_idx = set() + for idx, batch in enumerate(replaced_loader): + if idx >= num_consumed_batches: + break + already_seen_idx.update(batch) + dist.barrier() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + sampler_states = replaced_loader.batch_sampler.state_dict() + else: + sampler_states = replaced_loader.batch_sampler.sampler.state_dict() + + # 重新加载,应该可以输出剩下的内容,且对于 TorchNormalDataset 来说,排序后应该是一个 range + left_idxes = set() + if isinstance(replaced_loader.batch_sampler, BucketedBatchSampler): + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新改造 dataloader + new_loader = dataloader_with_bucketedbatchsampler( + replaced_loader.dataset, + length=replaced_loader.dataset._data, + batch_size=batch_size, + shuffle=shuffle, + drop_last=False, + ) + new_loader.batch_sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank, + pad=True + ) + new_loader.batch_sampler.load_state_dict(sampler_states) + else: + batch_size = replaced_loader.batch_sampler.batch_size + sampler_states["num_consumed_samples"] = num_consumed_batches * batch_size * num_replicas + # 重新构造 dataloader + new_loader = dataloader_with_randomsampler(replaced_loader.dataset, batch_size, shuffle, drop_last=False) + new_loader.batch_sampler.sampler.set_distributed( + num_replicas=self.driver.world_size, + rank=self.driver.global_rank + ) + new_loader.batch_sampler.sampler.load_state_dict(sampler_states) + for idx, batch in enumerate(new_loader): + left_idxes.update(batch) + + assert len(left_idxes) + len(already_seen_idx) == len(self.dataset) / num_replicas + assert len(left_idxes | already_seen_idx) == len(self.dataset) / num_replicas + + +############################################################################ +# +# 测试 save 和 load 相关的功能 +# +############################################################################ +class TestSaveLoad: + """ + 测试多卡情况下 save 和 load 相关函数的表现 + """ + + @classmethod + def setup_class(cls): + # 不在这里 setup 的话会报错 + cls.driver = generate_driver(10, 10) + + def setup_method(self): + self.dataset = TorchArgMaxDataset(10, 20) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + def test_save_and_load_model(self, only_state_dict): + """ + 测试 save_model 和 load_model 函数 + """ + try: + path = "model" + + dataloader = DataLoader(self.dataset, batch_size=2) + self.driver1, self.driver2 = generate_driver(10, 10), generate_driver(10, 10) + + self.driver1.save_model(path, only_state_dict) + + # 同步 + dist.barrier() + self.driver2.load_model(path, only_state_dict) + + for idx, batch in enumerate(dataloader): + batch = self.driver1.move_data_to_device(batch) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model.module.model.evaluate_step, + # Driver.model -> DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + + assert torch.equal(res1["preds"], res2["preds"]) + finally: + rank_zero_rm(path) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_bucketedbatchsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 sampler 之后的情况 + """ + + try: + path = "model.ckp" + num_replicas = len(device) + + self.driver1, self.driver2 = generate_driver(10, 10, device=device, fp16=fp16), \ + generate_driver(10, 10, device=device, fp16=False) + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=4, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + # 加载 + # 更改 batch_size + dataloader = dataloader_with_bucketedbatchsampler( + self.dataset, + length=[10 for i in range(len(self.dataset))], + batch_size=2, + shuffle=True, + drop_last=False + ) + dataloader.batch_sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 batch_sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert replaced_loader.batch_sampler is dataloader.batch_sampler + assert isinstance(replaced_loader.batch_sampler, BucketedBatchSampler) + assert replaced_loader.batch_sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.num_consumed_samples == num_consumed_batches * 4 * num_replicas + + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model.module.model.evaluate_step, + # Driver.model -> DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + finally: + rank_zero_rm(path) + + @magic_argv_env_context + @pytest.mark.parametrize("only_state_dict", ([True, False])) + @pytest.mark.parametrize("fp16", ([True, False])) + @pytest.mark.parametrize("device", ([[0,1]])) + def test_save_and_load_with_randomsampler(self, device, only_state_dict, fp16): + """ + 测试save和load函数,主要测试 dataloader 被替换了 batch_sampler 的情况 + """ + + try: + path = "model.ckp" + + num_replicas = len(device) + + self.driver1 = generate_driver(10, 10, device=device, fp16=fp16) + self.driver2 = generate_driver(10, 10, device=device, fp16=False) + + dataloader = dataloader_with_randomsampler(self.dataset, 4, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=self.driver1.world_size, + rank=self.driver1.global_rank, + pad=True + ) + num_consumed_batches = 2 + + already_seen_x_set = set() + already_seen_y_set = set() + for idx, batch in enumerate(dataloader): + if idx >= num_consumed_batches: + break + already_seen_x_set.update(batch["x"]) + already_seen_y_set.update(batch["y"]) + + # 同步 + dist.barrier() + + # 保存状态 + sampler_states = dataloader.batch_sampler.sampler.state_dict() + save_states = {"num_consumed_batches": num_consumed_batches} + if only_state_dict: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True) + else: + self.driver1.save(Path(path), save_states, dataloader, only_state_dict, should_save_model=True, input_spec=[torch.ones((16, 10))]) + # 加载 + # 更改 batch_size + dataloader = dataloader_with_randomsampler(self.dataset, 2, True, False, unrepeated=False) + dataloader.batch_sampler.sampler.set_distributed( + num_replicas=self.driver2.world_size, + rank=self.driver2.global_rank, + pad=True + ) + load_states = self.driver2.load(Path(path), dataloader, only_state_dict, should_load_model=True) + replaced_loader = load_states.pop("dataloader") + + # 1. 检查 optimizer 的状态 + # TODO optimizer 的 state_dict 总是为空 + + # 2. 检查 sampler 是否被正确地加载和替换 + assert not (replaced_loader is dataloader) + assert isinstance(replaced_loader.batch_sampler.sampler, RandomSampler) + assert replaced_loader.batch_sampler.sampler.seed == sampler_states["seed"] + assert replaced_loader.batch_sampler.sampler.epoch == sampler_states["epoch"] + assert replaced_loader.batch_sampler.sampler.num_consumed_samples == 4 * num_consumed_batches * num_replicas + assert len(replaced_loader.batch_sampler.sampler.dataset) == sampler_states["length"] + assert replaced_loader.batch_sampler.sampler.shuffle == sampler_states["shuffle"] + # 3. 检查 fp16 是否被加载 + if fp16: + assert isinstance(self.driver2.grad_scaler, torch.cuda.amp.GradScaler) + + # 4. 检查 model 的参数是否正确 + # 5. 检查 batch_idx + start_batch = load_states.pop('batch_idx_in_epoch') + assert start_batch == 2 * num_consumed_batches + left_x_batches = set() + left_y_batches = set() + for idx, batch in enumerate(replaced_loader): + + left_x_batches.update(batch["x"]) + left_y_batches.update(batch["y"]) + res1 = self.driver1.model( + batch, + fastnlp_fn=self.driver1.model.module.model.evaluate_step, + # Driver.model -> DataParallel.module -> _FleetWrappingModel.model + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + res2 = self.driver2.model( + batch, + fastnlp_fn=self.driver2.model.module.model.evaluate_step, + fastnlp_signature_fn=None, + wo_auto_param_call=False, + ) + assert torch.equal(res1["preds"], res2["preds"]) + + assert len(left_x_batches) + len(already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_x_batches | already_seen_x_set) == len(self.dataset) / num_replicas + assert len(left_y_batches) + len(already_seen_y_set) == len(self.dataset) / num_replicas + assert len(left_y_batches | already_seen_y_set) == len(self.dataset) / num_replicas + + finally: + rank_zero_rm(path) \ No newline at end of file From 0043e3e89abad9313d26d50b950c72b8444291b0 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 28 Apr 2022 15:12:21 +0800 Subject: [PATCH 12/16] =?UTF-8?q?=E4=BF=AE=E5=A4=8DLoadeBestModelCallback?= =?UTF-8?q?=E7=9A=84bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../core/callbacks/load_best_model_callback.py | 2 ++ fastNLP/core/controllers/evaluator.py | 15 +++++---------- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 0caf22d1..91bdb084 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -95,6 +95,8 @@ class LoadBestModelCallback(HasMonitorCallback): self.buffer.seek(0) trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) + trainer.driver.barrier() + if self.delete_after_after: if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: # 只需要 rank 0 执行删除。 diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index ada31edb..c6bd0f82 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -20,13 +20,6 @@ from fastNLP.core.log import logger class Evaluator: - """ - 1. 我们目前不直接提供每一个 metric 对应一个或者特殊的多个 dataloader 的功能,默认就是所有 metric 处理所有 dataloader,如果用户有这种 - 需求,请使用多个 Tester 进行操作; - 2. Trainer 的 validate dataloader 只允许传进去一个,而 Tester 则可以多个;因为 Trainer 涉及到保存 topk 模型的逻辑,而 Tester - 则只需要给出评测的结果即可; - - """ driver: Driver _evaluate_batch_loop: Loop @@ -37,11 +30,12 @@ class Evaluator: output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, fp16: bool = False, verbose: int = 1, **kwargs): """ + 用于对数据进行评测。 :param model: 待测试的模型,如果传入的 driver 为 Driver 实例,该参数将被忽略。 - :param dataloaders: 待评测的数据集。 + :param dataloaders: 待评测的数据集。如果为多个,请使用 dict 传入。 :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 - metric ,torchmetrics,allennlpmetrics等。 + metric ,torchmetrics,allennlpmetrics 等。 :param driver: 使用 driver 。 :param device: 使用的设备。 :param evaluate_batch_step_fn: 定制每次 evaluate batch 执行的函数。该函数应接受的两个参数为 `evaluator` 和 `batch`, @@ -59,7 +53,8 @@ class Evaluator: :param verbose: 是否打印 evaluate 的结果。 :param kwargs: bool model_use_eval_mode: 是否在 evaluate 的时候将 model 的状态设置成 eval 状态。在 eval 状态下,model 的dropout - 与 batch normalization 将会关闭。默认为True。 + 与 batch normalization 将会关闭。默认为True。如果为 False,fastNLP 不会对 model 的 evaluate 状态做任何设置。无论 + 该值是什么,fastNLP 都会在 evaluate 接受后将 model 的状态设置为 train 。 TODO 还没完成。 Union[bool] auto_tensor_conversion_for_metric: 是否自动将输出中的 tensor 适配到 metrics 支持的。例如 model 输出是 paddlepaddle 的 tensor ,但是想利用 torchmetrics 的metric对象, From f74b9b6bec391a7dac82b43970ca237b0bbaec8b Mon Sep 17 00:00:00 2001 From: yh_cc Date: Thu, 28 Apr 2022 16:27:17 +0800 Subject: [PATCH 13/16] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=89=80=E6=9C=89?= =?UTF-8?q?=E7=9A=84=20validate=20=E4=B8=BA=20evaluate=20;=20=E7=A7=BB?= =?UTF-8?q?=E5=8A=A8=20callback.on=5Ftrain=5Fend()?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 36 ++++++++++++++++--- fastNLP/core/callbacks/callback_events.py | 4 +-- fastNLP/core/callbacks/callback_manager.py | 4 +-- fastNLP/core/callbacks/checkpoint_callback.py | 2 +- fastNLP/core/callbacks/early_stop_callback.py | 8 ++--- .../core/callbacks/has_monitor_callback.py | 2 +- .../callbacks/load_best_model_callback.py | 25 ++++--------- .../core/callbacks/more_evaluate_callback.py | 10 +++--- fastNLP/core/callbacks/progress_callback.py | 7 ++-- .../controllers/loops/train_batch_loop.py | 2 +- fastNLP/core/controllers/trainer.py | 31 ++++++++-------- fastNLP/core/controllers/utils/utils.py | 16 ++++----- .../drivers/jittor_driver/jittor_driver.py | 2 +- fastNLP/core/log/logger.py | 13 +++++++ tests/helpers/callbacks/helper_callbacks.py | 10 +++--- 15 files changed, 99 insertions(+), 73 deletions(-) diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 1d3d1f11..982df7da 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -12,6 +12,34 @@ from fastNLP.core.callbacks.callback_events import _SingleEventState class Callback: r""" 实际使用的 callback 类,不管是我们 fastNLP 默认提供的一些 callback 类,还是用户自己定制的 callback 类,都应该继承该基类; + callback 调用时机顺序大概如下 + Trainer.__init__(): + on_after_trainer_initialized() + Trainer.run(): + if num_eval_sanity_batch>0: + on_sanity_check_begin() # 如果设置了num_eval_sanity_batch + on_sanity_check_end() + try: + on_train_begin() + while cur_epoch_idx < n_epochs: + on_train_epoch_begin() + while batch_idx_in_epoch<=num_batches_per_epoch: + on_fetch_data_begin() + on_fetch_data_end() + on_train_batch_begin() + on_before_backward() + on_after_backward() + on_before_zero_grad() # 实际调用受到 accumulation_steps 影响 + on_after_zero_grad() # 实际调用受到 accumulation_steps 影响 + on_before_optimizers_step() # 实际调用受到 accumulation_steps 影响 + on_after_optimizers_step() # 实际调用受到 accumulation_steps 影响 + on_train_batch_end() + on_train_epoch_end() + except BaseException: + self.on_exception() + finally: + on_train_end() + 其它 callback 例如 on_evaluate_begin()/on_evaluate_end()将 """ def on_after_trainer_initialized(self, trainer, driver): @@ -221,9 +249,9 @@ class Callback: """ pass - def on_validate_begin(self, trainer): + def on_evaluate_begin(self, trainer): """ - 在将要进行 validate 时调用。如果是设置的以 step 数量 或 自定义地 决定 validate 的频率,该接口是在 on_train_batch_end 之后 + 在将要进行 evaluate 时调用。如果是设置的以 step 数量 或 自定义地 决定 evaluate 的频率,该接口是在 on_train_batch_end 之后 进行调用。如果是以 epoch 数量决定调用,该接口是在 on_train_epoch_end 之后调用。 :param trainer: @@ -231,9 +259,9 @@ class Callback: """ pass - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): """ - 结束 validate 时调用,并把 validate 的结果传入。 + 结束 evaluate 时调用,并把 evaluate 的结果传入。 :param trainer: :param results: Evaluate 的结果,一般是个 dict 。 diff --git a/fastNLP/core/callbacks/callback_events.py b/fastNLP/core/callbacks/callback_events.py index ef972b35..3f3691e3 100644 --- a/fastNLP/core/callbacks/callback_events.py +++ b/fastNLP/core/callbacks/callback_events.py @@ -96,8 +96,8 @@ class Events(EventEnum): on_after_optimizers_step = "on_after_optimizers_step" on_before_zero_grad = "on_before_zero_grad" on_after_zero_grad = "on_after_zero_grad" - on_validate_begin = "on_validate_begin" - on_validate_end = "on_validate_end" + on_evaluate_begin = "on_evaluate_begin" + on_evaluate_end = "on_evaluate_end" class EventsList: diff --git a/fastNLP/core/callbacks/callback_manager.py b/fastNLP/core/callbacks/callback_manager.py index c5b00e71..90d2e1b1 100644 --- a/fastNLP/core/callbacks/callback_manager.py +++ b/fastNLP/core/callbacks/callback_manager.py @@ -281,9 +281,9 @@ class CallbackManager: pass @_transfer - def on_validate_begin(self, trainer): + def on_evaluate_begin(self, trainer): pass @_transfer - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): pass diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index e12873d3..0f4ed04d 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -114,7 +114,7 @@ class CheckpointCallback(Callback): if self.topk_saver.topk_queue and trainer.evaluator is None: logger.warning(f"You set `topk={self.topk}`, but `evaluate_dataloaders` is not set in Trainer.") - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): # 如果发生了保存,则返回的 folder 不为 None folder = self.topk_saver.save_topk(trainer, results) diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index 0923eb00..1e867866 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -16,13 +16,13 @@ class EarlyStopCallback(HasMonitorCallback): 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 - :param patience: 多少次 validate 不没有提升就停止。 + :param patience: 多少次 evaluate 不没有提升就停止。 """ super(EarlyStopCallback, self).__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=True) self.wait = 0 self.patience = patience - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): monitor_value = self.get_monitor_value(results) if monitor_value is None: return @@ -32,13 +32,13 @@ class EarlyStopCallback(HasMonitorCallback): self.wait += 1 def on_fetch_data_begin(self, trainer): - # 当是 step validate 的时候,下一步执行的就是这个, 所以在这里检查。 + # 当是 step evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 if self.wait >= self.patience: raise EarlyStopException(f"After {self.wait} validations, no improvement for " f"metric `{self._real_monitor}`") def on_train_epoch_begin(self, trainer): - # 当是 epoch validate 的时候,下一步执行的就是这个, 所以在这里检查。 + # 当是 epoch evaluate 的时候,下一步执行的就是这个, 所以在这里检查。 if self.wait >= self.patience: raise EarlyStopException(f"After {self.wait} validations, no improvement for " f"metric `{self._real_monitor}`(best value: {self.monitor_value})") diff --git a/fastNLP/core/callbacks/has_monitor_callback.py b/fastNLP/core/callbacks/has_monitor_callback.py index b13f9dd6..52214ff0 100644 --- a/fastNLP/core/callbacks/has_monitor_callback.py +++ b/fastNLP/core/callbacks/has_monitor_callback.py @@ -216,6 +216,6 @@ class ExecuteOnceBetterMonitor(HasMonitorCallback): _check_valid_parameters_number(execute_fn, expected_params=[], fn_name='execute_fn') self.execute_fn = execute_fn - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): if self.is_better_results(results): self.execute_fn() \ No newline at end of file diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index 91bdb084..5addd2e2 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -76,7 +76,7 @@ class LoadBestModelCallback(HasMonitorCallback): super().on_after_trainer_initialized(trainer, driver) - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): if self.is_better_results(results, keep_if_better=True): if self.real_save_folder: trainer.save_model(folder=self.real_save_folder, only_state_dict=self.only_state_dict, @@ -95,27 +95,14 @@ class LoadBestModelCallback(HasMonitorCallback): self.buffer.seek(0) trainer.load_model(folder=self.buffer, only_state_dict=self.only_state_dict) - trainer.driver.barrier() + self._delete_after_after(trainer) + def _delete_after_after(self, trainer): + trainer.driver.barrier() if self.delete_after_after: - if self.real_save_folder and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) == 0: - # 只需要 rank 0 执行删除。 - logger.info(f"Deleting {self.real_save_folder}...") - shutil.rmtree(self.real_save_folder) - try: - # 如果是 emtpy 的,就会被删除掉 - os.rmdir(self.save_folder) - except: - pass - elif hasattr(self, 'buffer'): - self.buffer.close() - del self.buffer - - def on_exception(self, trainer, exception): - if self.delete_after_after: - if self.real_save_folder: # 这里,谁处异常,谁删除 + if self.real_save_folder: logger.info(f"Deleting {self.real_save_folder}...") - shutil.rmtree(self.real_save_folder) + shutil.rmtree(self.real_save_folder, ignore_errors=True) try: # 如果是 emtpy 的,就会被删除掉 os.rmdir(self.save_folder) diff --git a/fastNLP/core/callbacks/more_evaluate_callback.py b/fastNLP/core/callbacks/more_evaluate_callback.py index 6c015bdf..b5800134 100644 --- a/fastNLP/core/callbacks/more_evaluate_callback.py +++ b/fastNLP/core/callbacks/more_evaluate_callback.py @@ -31,8 +31,8 @@ class MoreEvaluateCallback(HasMonitorCallback): :param dataloaders: 需要评估的数据 :param metrics: 使用的 metrics 。 - :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch validate 一次;(2) 为正整数则表示每隔几个 batch - evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 validate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 + :param evaluate_every: 可以为负数、正数和函数;(1) 为负整数时表示每隔几个 epoch evaluate 一次;(2) 为正整数则表示每隔几个 batch + evaluate 一次;(3) 为函数时表示用户自己传入的用于控制 evaluate 的频率的函数,该函数的应该接受 trainer 对象作为参数,并返回 一个 bool 值,返回为 True 说明需要进行 evaluate ;将在每个 batch 结束后调用该函数判断是否需要 evaluate 。 :param watch_monitor: 这个值用来表示监控的 Trainer 中的 evaluate 结果的,当该值不为 None ,evaluate_every 失效。本参数的 意义是,当检测到 Trainer 中 evaluate results 的 {watch_monitor} 的结果更好时,则进行一次 evaluate 。该参数有两种 @@ -128,7 +128,7 @@ class MoreEvaluateCallback(HasMonitorCallback): results = self.evaluator.run(num_eval_batch_per_dl=self.num_eval_sanity_batch) self.topk_saver.get_monitor_value(results) - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): if self.is_better_results(results, keep_if_better=True): results = self.evaluator.run() self.topk_saver.save_topk(trainer, results) @@ -137,8 +137,8 @@ class MoreEvaluateCallback(HasMonitorCallback): if self.watch_monitor is not None: return if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: - validate_every = -self.evaluate_every - if trainer.cur_epoch_idx % validate_every == 0: + evaluate_every = -self.evaluate_every + if trainer.cur_epoch_idx % evaluate_every == 0: results = self.evaluator.run() self.topk_saver.save_topk(trainer, results) diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index a6f82896..bacdea48 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -100,7 +100,7 @@ class RichCallback(ProgressCallback): self.progress_bar.update(self.task2id['epoch'], description=f'Epoch:{trainer.cur_epoch_idx}', advance=self.epoch_bar_update_advance, refresh=True) - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): if len(results)==0: return rule_style = '' @@ -122,9 +122,6 @@ class RichCallback(ProgressCallback): else: self.progress_bar.print(results) - def on_exception(self, trainer, exception): - self.clear_tasks() - def clear_tasks(self): for key, taskid in self.task2id.items(): self.progress_bar.destroy_task(taskid) @@ -178,7 +175,7 @@ class RawTextCallback(ProgressCallback): f'finished {round(trainer.global_forward_batches/trainer.total_batches*100, 2)}%.' logger.info(text) - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): if len(results)==0: return base_text = f'Eval. results on Epoch:{trainer.cur_epoch_idx}, Batch:{trainer.batch_idx_in_epoch}' diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index cfb54111..ef05e0c4 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -43,7 +43,7 @@ class TrainBatchLoop(Loop): trainer.check_batch_step_fn() trainer.on_train_batch_end() - trainer.step_validate() + trainer.step_evaluate() trainer.batch_idx_in_epoch = 0 @staticmethod diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index cbec1a01..307901b1 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -339,11 +339,11 @@ class Trainer(TrainerEventTrigger): self.num_batches_per_epoch = len(self.dataloader) self.total_batches = self.num_batches_per_epoch * self.n_epochs self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch - self.on_train_begin() - self.driver.barrier() - self.driver.zero_grad(self.set_grad_to_none) try: + self.on_train_begin() + self.driver.barrier() + self.driver.zero_grad(self.set_grad_to_none) while self.cur_epoch_idx < self.n_epochs: # 这个是防止在 Trainer.load 之后还没结束当前 epoch 又继续 save self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch @@ -356,10 +356,8 @@ class Trainer(TrainerEventTrigger): self.cur_epoch_idx += 1 self.on_train_epoch_end() self.driver.barrier() - self.epoch_validate() + self.epoch_evaluate() self.driver.barrier() - self.on_train_end() - self.driver.barrier() except EarlyStopException as e: logger.info(f"Catch early stop exception: {e.msg}.") @@ -373,17 +371,20 @@ class Trainer(TrainerEventTrigger): self.driver.on_exception() self.on_exception(e) raise e + finally: + self.on_train_end() + self.driver.barrier() def _set_num_eval_batch_per_dl(self, num_eval_batch_per_dl): - def _validate_fn(trainer: Trainer, validate_fn: Callable) -> None: - trainer.on_validate_begin() - _validate_res: dict = validate_fn() - trainer.on_validate_end(_validate_res) + def _evaluate_fn(trainer: Trainer, evaluate_fn: Callable) -> None: + trainer.on_evaluate_begin() + _evaluate_res: dict = evaluate_fn() + trainer.on_evaluate_end(_evaluate_res) if self.evaluator is not None: - self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) + self.run_evaluate = partial(_evaluate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) - def step_validate(self): + def step_evaluate(self): """ 在每个 batch 结束后调用,根据设置执行 evaluate 。 @@ -396,7 +397,7 @@ class Trainer(TrainerEventTrigger): elif self.evaluate_every > 0 and self.global_forward_batches % self.evaluate_every == 0: self.run_evaluate() - def epoch_validate(self): + def epoch_evaluate(self): """ 在每个 epoch 结束后调用,根据设置执行 evaluate 。 @@ -404,8 +405,8 @@ class Trainer(TrainerEventTrigger): """ if self.evaluator is not None: if isinstance(self.evaluate_every, int) and self.evaluate_every < 0: - validate_every = -self.evaluate_every - if self.cur_epoch_idx % validate_every == 0: + evaluate_every = -self.evaluate_every + if self.cur_epoch_idx % evaluate_every == 0: self.run_evaluate() def add_callback_fn(self, event: Optional[Union[Events, EventsList]], fn: Callable): diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index cc7a1b66..a2b2d5ae 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -81,12 +81,12 @@ class TrainerEventTrigger: def on_after_zero_grad(self, optimizers): self.callback_manager.on_after_zero_grad(self, optimizers) - def on_validate_begin(self): - self.callback_manager.on_validate_begin(self) + def on_evaluate_begin(self): + self.callback_manager.on_evaluate_begin(self) - def on_validate_end(self, results): + def on_evaluate_end(self, results): self.trainer_state.save_on_this_step = True - self.callback_manager.on_validate_end(self, results) + self.callback_manager.on_evaluate_end(self, results) class _TruncatedDataLoader: @@ -126,8 +126,8 @@ class _TruncatedDataLoader: return getattr(self.dataloader, item) -def check_evaluate_every(validate_every): - if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): +def check_evaluate_every(evaluate_every): + if not callable(evaluate_every) and (not isinstance(evaluate_every, int) or evaluate_every == 0): raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") - if callable(validate_every): - _check_valid_parameters_number(validate_every, expected_params=['trainer']) + if callable(evaluate_every): + _check_valid_parameters_number(evaluate_every, expected_params=['trainer']) diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 84e3f002..bcebc6d0 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -63,7 +63,7 @@ class JittorDriver(Driver): def check_evaluator_mode(self, mode: str): model = self.unwrap_model() - if mode == "validate": + if mode == "evaluate": if not hasattr(model, "evaluate_step"): if hasattr(model, "test_step"): logger.warning_once( diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 086089ea..bdfc299f 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -173,6 +173,19 @@ class FastNLPLogger(logging.Logger, metaclass=LoggerSingleton): kwargs["extra"] = extra return kwargs + def setLevel(self, level) -> None: + """ + 设置当前 logger 以及其 handler 的 log 级别 + + :param level: + :return: + """ + if isinstance(level, str): + level = level.upper() + super().setLevel(level) + for handler in self.handlers: + handler.setLevel(level) + def _get_level(level): if not isinstance(level, int): diff --git a/tests/helpers/callbacks/helper_callbacks.py b/tests/helpers/callbacks/helper_callbacks.py index c3a9d4da..4fd5b654 100644 --- a/tests/helpers/callbacks/helper_callbacks.py +++ b/tests/helpers/callbacks/helper_callbacks.py @@ -38,7 +38,7 @@ class RecordMetricCallback(Callback): self.metric_threshold = metric_threshold self.metric_begin_value = None - def on_validate_end(self, trainer, results): + def on_evaluate_end(self, trainer, results): self.metric = results[self.monitor] if self.metric_begin_value is None: self.metric_begin_value = self.metric @@ -113,11 +113,11 @@ class RecordTrainerEventTriggerCallback(Callback): def on_after_zero_grad(self, trainer, optimizers): print("on_after_zero_grad") - def on_validate_begin(self, trainer): - print("on_validate_begin") + def on_evaluate_begin(self, trainer): + print("on_evaluate_begin") - def on_validate_end(self, trainer, results): - print("on_validate_end") + def on_evaluate_end(self, trainer, results): + print("on_evaluate_end") From ba8971fcf39483ad796f59bfc8cc45b7b1fe3129 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Fri, 29 Apr 2022 18:19:07 +0800 Subject: [PATCH 14/16] little change --- fastNLP/core/controllers/evaluator.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index c6bd0f82..4dba8a4c 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -98,7 +98,7 @@ class Evaluator: self.separator = kwargs.get('separator', '#') self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) - use_dist_sampler = kwargs.get("use_dist_sampler", driver.is_distributed()) + use_dist_sampler = kwargs.get("use_dist_sampler", self.driver.is_distributed()) if use_dist_sampler: self._dist_sampler = "unrepeatdist" else: From 29fb454c2e2e78dcfdd7850b86534b9ae1af91b8 Mon Sep 17 00:00:00 2001 From: lxr-tech <1838593642@qq.com> Date: Fri, 29 Apr 2022 22:22:28 +0800 Subject: [PATCH 15/16] modify fastnlp_tutorial_0.py --- tutorials/fastnlp_tutorial_0.ipynb | 1009 ++++++++++++++++++++++++++++ 1 file changed, 1009 insertions(+) create mode 100644 tutorials/fastnlp_tutorial_0.ipynb diff --git a/tutorials/fastnlp_tutorial_0.ipynb b/tutorials/fastnlp_tutorial_0.ipynb new file mode 100644 index 00000000..01913ac0 --- /dev/null +++ b/tutorials/fastnlp_tutorial_0.ipynb @@ -0,0 +1,1009 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "aec0fde7", + "metadata": {}, + "source": [ + "# T0. trainer 和 evaluator 的基本使用\n", + "\n", + "  1   trainer 和 evaluator 的基本关系\n", + " \n", + "    1.1   trainer 和 evaluater 的初始化\n", + "\n", + "    1.2   driver 的含义与使用要求\n", + "\n", + "    1.3   trainer 内部初始化 evaluater\n", + "\n", + "  2   使用 trainer 训练模型\n", + "\n", + "    2.1   argmax 模型实例\n", + "\n", + "    2.2   trainer 的参数匹配\n", + "\n", + "    2.3   trainer 的实际使用 \n", + "\n", + "  3   使用 evaluator 评测模型\n", + " \n", + "    3.1   trainer 外部初始化的 evaluator\n", + "\n", + "    3.2   trainer 内部初始化的 evaluator " + ] + }, + { + "cell_type": "markdown", + "id": "09ea669a", + "metadata": {}, + "source": [ + "## 1. trainer 和 evaluator 的基本关系\n", + "\n", + "### 1.1 trainer 和 evaluator 的初始化\n", + "\n", + "在`fastNLP 0.8`中,**`Trainer`模块和`Evaluator`模块分别表示“训练器”和“评测器”**\n", + "\n", + "  对应于之前的`fastNLP`版本中的`Trainer`模块和`Tester`模块,其定义方法如下所示\n", + "\n", + "在`fastNLP 0.8`中,需要注意,在同个`python`脚本中先使用`Trainer`训练,然后使用`Evaluator`评测\n", + "\n", + "  非常关键的问题在于**如何正确设置二者的`driver`**。这就引入了另一个问题:什么是 `driver`?\n", + "\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver=\"torch\",\n", + "\tdevice=0,\n", + "\t...\n", + ")\n", + "...\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} \n", + " ...\n", + " driver=trainer.driver,\n", + "\tdevice=None,\n", + " ...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "3c11fe1a", + "metadata": {}, + "source": [ + "### 1.2 driver 的含义与使用要求\n", + "\n", + "在`fastNLP 0.8`中,**`driver`**这一概念被用来表示**控制具体训练的各个步骤的最终执行部分**\n", + "\n", + "  例如神经网络前向、后向传播的具体执行、网络参数的优化和数据在设备间的迁移等\n", + "\n", + "在`fastNLP 0.8`中,**`Trainer`和`Evaluator`都依赖于具体的`driver`来完成整体的工作流程**\n", + "\n", + "  具体`driver`与`Trainer`以及`Evaluator`之间的关系请参考`fastNLP 0.8`的框架设计\n", + "\n", + "注:在同一脚本中,`Trainer`和`Evaluator`使用的`driver`应当保持一致\n", + "\n", + "  一个不能违背的原则在于:**不要将多卡的`driver`前使用单卡的`driver`**(???),这样使用可能会带来很多意想不到的错误。" + ] + }, + { + "cell_type": "markdown", + "id": "2cac4a1a", + "metadata": {}, + "source": [ + "### 1.3 Trainer 内部初始化 Evaluator\n", + "\n", + "在`fastNLP 0.8`中,如果在**初始化`Trainer`时**,**传入参数`evaluator_dataloaders`和`metrics`**\n", + "\n", + "  则在`Trainer`内部,也会初始化单独的`Evaluator`来帮助训练过程中对验证集的评测\n", + "\n", + "```python\n", + "trainer = Trainer(\n", + " model=model,\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + "\t...\n", + "\tdriver=\"torch\",\n", + "\tdevice=0,\n", + "\t...\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + "\t...\n", + ")\n", + "```" + ] + }, + { + "cell_type": "markdown", + "id": "0c9c7dda", + "metadata": {}, + "source": [ + "## 2. 使用 trainer 训练模型" + ] + }, + { + "cell_type": "markdown", + "id": "524ac200", + "metadata": {}, + "source": [ + "### 2.1 argmax 模型实例\n", + "\n", + "本节将通过训练`argmax`模型,简单介绍如何`Trainer`模块的使用方式\n", + "\n", + "  使用`pytorch`定义`argmax`模型,输入一组固定维度的向量,输出其中数值最大的数的索引\n", + "\n", + "  除了添加`pytorch`要求的`forward`方法外,还需要添加 **`train_step`** 和 **`evaluate_step`** 这两个方法" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5314482b", + "metadata": { + "pycharm": { + "is_executing": true + } + }, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "\n", + "class ArgMaxModel(nn.Module):\n", + " def __init__(self, num_labels, feature_dimension):\n", + " super(ArgMaxModel, self).__init__()\n", + " self.num_labels = num_labels\n", + "\n", + " self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)\n", + " self.ac1 = nn.ReLU()\n", + " self.linear2 = nn.Linear(in_features=10, out_features=10)\n", + " self.ac2 = nn.ReLU()\n", + " self.output = nn.Linear(in_features=10, out_features=num_labels)\n", + " self.loss_fn = nn.CrossEntropyLoss()\n", + "\n", + " def forward(self, x):\n", + " x = self.ac1(self.linear1(x))\n", + " x = self.ac2(self.linear2(x))\n", + " x = self.output(x)\n", + " return x\n", + "\n", + " def train_step(self, x, y):\n", + " x = self(x)\n", + " return {\"loss\": self.loss_fn(x, y)}\n", + "\n", + " def evaluate_step(self, x, y):\n", + " x = self(x)\n", + " x = torch.max(x, dim=-1)[1]\n", + " return {\"pred\": x, \"target\": y}" + ] + }, + { + "cell_type": "markdown", + "id": "ca897322", + "metadata": {}, + "source": [ + "在`fastNLP 0.8`中,**函数`train_step`是`Trainer`中参数`train_fn`的默认值**\n", + "\n", + "  由于,在`Trainer`训练时,**`Trainer`通过参数`_train_fn_`对应的模型方法获得当前数据批次的损失值**\n", + "\n", + "  因此,在`Trainer`训练时,`Trainer`首先会寻找模型是否定义了`train_step`这一方法\n", + "\n", + "    如果没有找到,那么`Trainer`会默认使用模型的`forward`函数来进行训练的前向传播过程\n", + "\n", + "注:在`fastNLP 0.8`中,`Trainer`要求模型通过`train_step`来返回一个字典,将损失值作为`loss`的键值\n", + "\n", + "  此外,这里也可以通过传入`Trainer`的参数`output_mapping`来实现高度化的定制,具体请见这一note(???)\n", + "\n", + "同样,在`fastNLP 0.8`中,**函数`evaluate_step`是`Evaluator`中参数`evaluate_fn`的默认值**\n", + "\n", + "  在`Evaluator`测试时,**`Evaluator`通过参数`evaluate_fn`对应的模型方法获得当前数据批次的评测结果**\n", + "\n", + "  从用户角度,模型通过`evaluate_step`方法来返回一个字典,内容与传入`Evaluator`的`metrics`一致\n", + "\n", + "" + ] + }, + { + "cell_type": "markdown", + "id": "fb3272eb", + "metadata": {}, + "source": [ + "### 2.2 trainer 的参数匹配\n", + "\n", + "`fastNLP 0.8`中的参数匹配涉及到两个方面,一是在模型训练或者评测的前向传播过程中,如果从`dataloader`中出来一个`batch`的数据是一个字典,那么我们会查看模型的`train_step`和`evaluate_step`方法的参数签名,然后对于每一个参数,我们会根据其名字从 batch 这一字典中选择出对应的数据传入进去。例如在接下来的定义`Dataset`的部分,注意`ArgMaxDatset`的`__getitem__`方法,您可以通过在`Trainer`和`Evaluator`中设置参数 `model_wo_auto_param_call`来关闭这一行为。当您关闭了这一行为后,我们会将`batch`直接传给您的`train_step`、`evaluate_step`或者 `forward`函数。\n", + "\n", + "二是在传入`Trainer`或者`Evaluator metrics`后,我们会在需要评测的时间点主动调用`metrics`来对`evaluate_dataloaders`进行评测,这一功能主要就是通过对`metrics`的`update`方法和一个`batch`的数据进行参数评测实现的。首先需要明确的是一个 metric 的计算通常分为 `update` 和 `get_metric`两步,其中`update`表示更新一个`batch`的评测数据,`get_metric` 表示根据已经得到的评测数据计算出最终的评测值,例如对于 `Accuracy`来说,其在`update`的时候会更新一个`batch`计算正确的数量 right_num 和计算错误的数量 total_num,最终在 `get_metric` 时返回评测值`right_num / total_num`。\n", + "\n", + "因为`fastNLP 0.8`的`metrics`是自动计算的(只需要传给`Trainer`或者`Evaluator`),因此其一定依赖于参数匹配。对于从`evaluate_dataloader`中生成的一个`batch`的数据,我们会查看传给 `Trainer`(最终是传给`Evaluator`)和`Evaluator`的每一个`metric`,然后查看其`update`函数的函数签名,然后根据每一个参数的名字从`batch`字典中选择出对应的数据传入进去。" + ] + }, + { + "cell_type": "markdown", + "id": "f62b7bb1", + "metadata": {}, + "source": [ + "### 2.3 trainer的实际使用\n", + "\n", + "接下来我们创建用于训练的 dataset,其接受三个参数:数据维度、数据量和随机数种子,生成指定数量的维度为 `feature_dimension` 向量,而每一个向量的标签就是该向量中最大值的索引。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "fe612e61", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import Dataset\n", + "\n", + "class ArgMaxDatset(Dataset):\n", + " def __init__(self, feature_dimension, data_num=1000, seed=0):\n", + " self.num_labels = feature_dimension\n", + " self.feature_dimension = feature_dimension\n", + " self.data_num = data_num\n", + " self.seed = seed\n", + "\n", + " g = torch.Generator()\n", + " g.manual_seed(1000)\n", + " self.x = torch.randint(low=-100, high=100, size=[data_num, feature_dimension], generator=g).float()\n", + " self.y = torch.max(self.x, dim=-1)[1]\n", + "\n", + " def __len__(self):\n", + " return self.data_num\n", + "\n", + " def __getitem__(self, item):\n", + " return {\"x\": self.x[item], \"y\": self.y[item]}" + ] + }, + { + "cell_type": "markdown", + "id": "2cb96332", + "metadata": {}, + "source": [ + "现在准备好数据和模型。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "76172ef8", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.utils.data import DataLoader\n", + "\n", + "train_dataset = ArgMaxDatset(feature_dimension=10, data_num=1000)\n", + "evaluate_dataset = ArgMaxDatset(feature_dimension=10, data_num=100)\n", + "\n", + "train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)\n", + "evaluate_dataloader = DataLoader(evaluate_dataset, batch_size=8)\n", + "\n", + "# num_labels 设置为 10,与 feature_dimension 保持一致,因为我们是预测十个位置中哪一个的概率最大。\n", + "model = ArgMaxModel(num_labels=10, feature_dimension=10)" + ] + }, + { + "cell_type": "markdown", + "id": "4e7d25ee", + "metadata": {}, + "source": [ + "将优化器也定义好。" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "dc28a2d9", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from torch.optim import SGD\n", + "\n", + "optimizer = SGD(model.parameters(), lr=0.001)" + ] + }, + { + "cell_type": "markdown", + "id": "4f1fba81", + "metadata": {}, + "source": [ + "现在万事俱备,开始使用 Trainer 进行训练!" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b51b7a2d", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "['__annotations__',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_check_callback_called_legality',\n", + " '_check_train_batch_loop_legality',\n", + " '_custom_callbacks',\n", + " '_driver',\n", + " '_evaluate_dataloaders',\n", + " '_fetch_matched_fn_callbacks',\n", + " '_set_num_eval_batch_per_dl',\n", + " '_train_batch_loop',\n", + " '_train_dataloader',\n", + " '_train_step',\n", + " '_train_step_signature_fn',\n", + " 'accumulation_steps',\n", + " 'add_callback_fn',\n", + " 'backward',\n", + " 'batch_idx_in_epoch',\n", + " 'batch_step_fn',\n", + " 'callback_manager',\n", + " 'check_batch_step_fn',\n", + " 'cur_epoch_idx',\n", + " 'data_device',\n", + " 'dataloader',\n", + " 'device',\n", + " 'driver',\n", + " 'driver_name',\n", + " 'epoch_validate',\n", + " 'evaluate_batch_step_fn',\n", + " 'evaluate_dataloaders',\n", + " 'evaluate_every',\n", + " 'evaluate_fn',\n", + " 'evaluator',\n", + " 'extract_loss_from_outputs',\n", + " 'fp16',\n", + " 'get_no_sync_context',\n", + " 'global_forward_batches',\n", + " 'has_checked_train_batch_loop',\n", + " 'input_mapping',\n", + " 'kwargs',\n", + " 'larger_better',\n", + " 'load',\n", + " 'load_model',\n", + " 'marker',\n", + " 'metrics',\n", + " 'model',\n", + " 'model_device',\n", + " 'monitor',\n", + " 'move_data_to_device',\n", + " 'n_epochs',\n", + " 'num_batches_per_epoch',\n", + " 'on',\n", + " 'on_after_backward',\n", + " 'on_after_optimizers_step',\n", + " 'on_after_trainer_initialized',\n", + " 'on_after_zero_grad',\n", + " 'on_before_backward',\n", + " 'on_before_optimizers_step',\n", + " 'on_before_zero_grad',\n", + " 'on_exception',\n", + " 'on_fetch_data_begin',\n", + " 'on_fetch_data_end',\n", + " 'on_load_checkpoint',\n", + " 'on_load_model',\n", + " 'on_sanity_check_begin',\n", + " 'on_sanity_check_end',\n", + " 'on_save_checkpoint',\n", + " 'on_save_model',\n", + " 'on_train_batch_begin',\n", + " 'on_train_batch_end',\n", + " 'on_train_begin',\n", + " 'on_train_end',\n", + " 'on_train_epoch_begin',\n", + " 'on_train_epoch_end',\n", + " 'on_validate_begin',\n", + " 'on_validate_end',\n", + " 'optimizers',\n", + " 'output_mapping',\n", + " 'run',\n", + " 'save',\n", + " 'save_model',\n", + " 'set_grad_to_none',\n", + " 'state',\n", + " 'step',\n", + " 'step_validate',\n", + " 'total_batches',\n", + " 'train_batch_loop',\n", + " 'train_dataloader',\n", + " 'train_fn',\n", + " 'train_step',\n", + " 'trainer_state',\n", + " 'zero_grad']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from fastNLP import Trainer\n", + "\n", + "# 定义一个 Trainer\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=\"torch\", # 使用 pytorch 进行训练\n", + " device=0, # 使用 GPU:0\n", + " train_dataloader=train_dataloader,\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 训练 40 个 epoch\n", + " progress_bar=\"rich\"\n", + ")\n", + "dir(trainer)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "f8fe9c32", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "FullArgSpec(args=['self', 'num_train_batch_per_epoch', 'num_eval_batch_per_dl', 'num_eval_sanity_batch', 'resume_from', 'resume_training', 'catch_KeyboardInterrupt'], varargs=None, varkw=None, defaults=(-1, -1, 2, None, True, None), kwonlyargs=[], kwonlydefaults=None, annotations={'num_train_batch_per_epoch': , 'num_eval_batch_per_dl': , 'num_eval_sanity_batch': , 'resume_from': , 'resume_training': })\n" + ] + } + ], + "source": [ + "import inspect \n", + "\n", + "print(inspect.getfullargspec(trainer.run))" + ] + }, + { + "cell_type": "markdown", + "id": "6e202d6e", + "metadata": {}, + "source": [ + "没有问题,那么开始真正的训练!" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "ba047ead", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "markdown", + "id": "eb8ca6cf", + "metadata": {}, + "source": [ + "## 3. 使用 evaluator 评测模型" + ] + }, + { + "cell_type": "markdown", + "id": "c16c5fa4", + "metadata": {}, + "source": [ + "模型训练好了我们开始使用 Evaluator 进行评测,查看效果怎么样吧。" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "1c6b6b36", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "from fastNLP import Evaluator\n", + "from fastNLP.core.metrics import Accuracy\n", + "\n", + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 使用 trainer 已经启动的 driver;\n", + " device=None,\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 注意这里一定得是一个字典;\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "257061df", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['__annotations__',\n", + " '__class__',\n", + " '__delattr__',\n", + " '__dict__',\n", + " '__dir__',\n", + " '__doc__',\n", + " '__eq__',\n", + " '__format__',\n", + " '__ge__',\n", + " '__getattribute__',\n", + " '__gt__',\n", + " '__hash__',\n", + " '__init__',\n", + " '__init_subclass__',\n", + " '__le__',\n", + " '__lt__',\n", + " '__module__',\n", + " '__ne__',\n", + " '__new__',\n", + " '__reduce__',\n", + " '__reduce_ex__',\n", + " '__repr__',\n", + " '__setattr__',\n", + " '__sizeof__',\n", + " '__str__',\n", + " '__subclasshook__',\n", + " '__weakref__',\n", + " '_dist_sampler',\n", + " '_evaluate_batch_loop',\n", + " '_evaluate_step',\n", + " '_evaluate_step_signature_fn',\n", + " '_metric_wrapper',\n", + " '_metrics',\n", + " 'dataloaders',\n", + " 'device',\n", + " 'driver',\n", + " 'evaluate_batch_loop',\n", + " 'evaluate_batch_step_fn',\n", + " 'evaluate_fn',\n", + " 'evaluate_step',\n", + " 'finally_progress_bar',\n", + " 'get_dataloader_metric',\n", + " 'input_mapping',\n", + " 'metrics',\n", + " 'metrics_wrapper',\n", + " 'model',\n", + " 'model_use_eval_mode',\n", + " 'move_data_to_device',\n", + " 'output_mapping',\n", + " 'progress_bar',\n", + " 'remove_progress_bar',\n", + " 'reset',\n", + " 'run',\n", + " 'separator',\n", + " 'start_progress_bar',\n", + " 'update',\n", + " 'update_progress_bar',\n", + " 'verbose']" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dir(evaluator)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "f7cb0165", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.3}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.3\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.3}" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "markdown", + "id": "dd9f68fa", + "metadata": {}, + "source": [ + "## 4. 在 trainer 中加入 metric 来自动评测;" + ] + }, + { + "cell_type": "markdown", + "id": "ca97c9a4", + "metadata": {}, + "source": [ + "现在我们尝试在训练过程中进行评测。" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "183c7d19", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [], + "source": [ + "# 重新定义一个 Trainer\n", + "\n", + "trainer = Trainer(\n", + " model=model,\n", + " driver=trainer.driver, # 因为我们是在同一脚本中,因此这里的 driver 同样需要重用;\n", + " train_dataloader=train_dataloader,\n", + " evaluate_dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()},\n", + " optimizers=optimizer,\n", + " n_epochs=10, # 训练 40 个 epoch;\n", + " evaluate_every=-1, # 表示每一个 epoch 的结束会进行 evaluate;\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "714cc404", + "metadata": {}, + "source": [ + "再次训练。" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "2e4daa2c", + "metadata": { + "pycharm": { + "is_executing": false + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "trainer.run()" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "eabda5eb", + "metadata": {}, + "outputs": [], + "source": [ + "evaluator = Evaluator(\n", + " model=model,\n", + " driver=trainer.driver, # 使用 trainer 已经启动的 driver;\n", + " dataloaders=evaluate_dataloader,\n", + " metrics={'acc': Accuracy()} # 注意这里一定得是一个字典;\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "a310d157", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
{'acc#acc': 0.5}\n",
+       "
\n" + ], + "text/plain": [ + "\u001b[1m{\u001b[0m\u001b[32m'acc#acc'\u001b[0m: \u001b[1;36m0.5\u001b[0m\u001b[1m}\u001b[0m\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/plain": [ + "{'acc#acc': 0.5}" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "evaluator.run()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f1ef78f0", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.4" + }, + "pycharm": { + "stem_cell": { + "cell_type": "raw", + "metadata": { + "collapsed": false + }, + "source": [] + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From d4dd85ed40ce2f4c48b51787f37872962cbc1804 Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 29 Apr 2022 22:24:46 +0800 Subject: [PATCH 16/16] =?UTF-8?q?=E6=96=B0=E8=AE=BE=E8=AE=A1collator?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/collators/new_collator.py | 181 ++++++++++++++ fastNLP/core/collators/padders/__init__.py | 0 fastNLP/core/collators/padders/exceptions.py | 44 ++++ fastNLP/core/collators/padders/get_padder.py | 193 +++++++++++++++ .../core/collators/padders/numpy_padder.py | 72 ++++++ fastNLP/core/collators/padders/padder.py | 21 ++ fastNLP/core/collators/padders/raw_padder.py | 48 ++++ .../core/collators/padders/torch_padder.py | 157 ++++++++++++ fastNLP/core/collators/padders/torch_utils.py | 20 ++ fastNLP/core/collators/padders/utils.py | 173 ++++++++++++++ fastNLP/core/collators/utils.py | 103 ++++++++ tests/core/collators/__init__.py | 0 tests/core/collators/padders/__init__.py | 0 .../core/collators/padders/test_get_padder.py | 139 +++++++++++ .../collators/padders/test_numpy_padder.py | 81 +++++++ .../core/collators/padders/test_raw_padder.py | 29 +++ .../collators/padders/test_torch_padder.py | 105 ++++++++ tests/core/collators/padders/test_utils.py | 90 +++++++ tests/core/collators/test_new_collator.py | 225 ++++++++++++++++++ tests/core/collators/test_utils.py | 37 +++ 20 files changed, 1718 insertions(+) create mode 100644 fastNLP/core/collators/new_collator.py create mode 100644 fastNLP/core/collators/padders/__init__.py create mode 100644 fastNLP/core/collators/padders/exceptions.py create mode 100644 fastNLP/core/collators/padders/get_padder.py create mode 100644 fastNLP/core/collators/padders/numpy_padder.py create mode 100644 fastNLP/core/collators/padders/padder.py create mode 100644 fastNLP/core/collators/padders/raw_padder.py create mode 100644 fastNLP/core/collators/padders/torch_padder.py create mode 100644 fastNLP/core/collators/padders/torch_utils.py create mode 100644 fastNLP/core/collators/padders/utils.py create mode 100644 fastNLP/core/collators/utils.py create mode 100644 tests/core/collators/__init__.py create mode 100644 tests/core/collators/padders/__init__.py create mode 100644 tests/core/collators/padders/test_get_padder.py create mode 100644 tests/core/collators/padders/test_numpy_padder.py create mode 100644 tests/core/collators/padders/test_raw_padder.py create mode 100644 tests/core/collators/padders/test_torch_padder.py create mode 100644 tests/core/collators/padders/test_utils.py create mode 100644 tests/core/collators/test_new_collator.py create mode 100644 tests/core/collators/test_utils.py diff --git a/fastNLP/core/collators/new_collator.py b/fastNLP/core/collators/new_collator.py new file mode 100644 index 00000000..869a60a7 --- /dev/null +++ b/fastNLP/core/collators/new_collator.py @@ -0,0 +1,181 @@ +from typing import List, Union, Dict, Callable, Sequence, Mapping + +from fastNLP.core.log import logger +from .padders.get_padder import get_padder + +import re + +from .utils import unpack_batch_mapping, unpack_batch_nested_mapping, pack_batch_nested_mapping, unpack_batch_sequence, \ + pack_batch_sequence, NESTED_DICT_SEPARATOR + +sequence_idx_str = re.compile(r'^_\d+$') # 形如_0, _1 +SUPPORTED_BACKENDS = ['torch', 'jittor', 'paddle', 'numpy', 'raw', None] + + +class Collator: + def __init__(self, backend='torch'): + """ + 用于 pad 数据的对象。会自动将所有能够 pad (由 fastNLP 根据数据判定能否 pad )的数据都进行 pad 操作,默认 pad 的值为 0。 + 可使用 set_pad() 函数调整。如果有些 field 不想输出,可以使用 set_ignore() 函数进行设置。 + + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], + 若为 None ,则不进行 padding 。 + """ + self.unpack_batch_func = None + self.pack_batch_func = None + self.ignore_fields = set() + self.padders = {} + self.input_fields = {} + self.batch_data_type = None # 只能是 d ,s ,l 三种,分别对应输入的batch的每个sample为 dict, single,list。 + self.set_backend(backend) + + def __call__(self, batch)->Union[List, Dict]: + """ + batch可能存在三种可能性 + List[Dict], List[List], List[Sample] + + 第一步:使用 unpack_batch_func 将相同 field 的内容打包到一个 list 中。 + 第二步:使用每个 field 各自的 padder 进行 pad 。 + 第三步:根据 batch 中每个 sample 的类型,返回也保证为该类型。 + + 第一次调用会根据当前 batch 数据决定使用哪个 unpack_batch_func ,这个函数的作用是把不同 sample 的同一个 field 的放入到一个 + list 中;同时也会决定 pack_batch_func,这个函数的作用是在返回 pad 好的 batch 之前,将 batch 恢复为 输入时一个 sample + 的类别。 + 第一次调用会根据当前 field 决定对应的 Padder 。 + + """ + if self.unpack_batch_func is None: + # 决定使用哪个unpack_batch_func,让它都 return 回 dict 类型 + if self.batch_data_type is None: + if isinstance(batch[0], Mapping): + self.batch_data_type = 'd' + elif isinstance(batch[0], Sequence): # 这里存在误判的风险 + self.batch_data_type = 'l' + else: + self.batch_data_type = 's' + logger.debug(f"Since batch[0] has type:{type(batch[0])}, so the batch_data_type " + f"is {self.batch_data_type}") + if self.batch_data_type == 's': + self.unpack_batch_func = lambda x:{'_single': x} # 不需要做任何调整 + self.pack_batch_func = lambda x:x['_single'] + elif self.batch_data_type == 'l': + self.unpack_batch_func = unpack_batch_sequence + self.pack_batch_func = pack_batch_sequence + elif self.batch_data_type == 'd': + if any([isinstance(v, Mapping) for v in batch[0].values()]): # 可能存在 nested 的dict。{'a': {'b': xx}}->{'a@@b': value} + self.unpack_batch_func = unpack_batch_nested_mapping + self.pack_batch_func = pack_batch_nested_mapping + else: + self.unpack_batch_func = unpack_batch_mapping + self.pack_batch_func = lambda x:x + + unpack_batch:Dict = self.unpack_batch_func(batch) # 将各自 field 组成 batch 形式。 + + pad_batch = {} + if len(self.padders)==0: # 第一次运行,准备 padder + for key in unpack_batch.keys(): + if key not in self.input_fields and key not in self.ignore_fields: + self.input_fields[key] = {'pad_val': 0, 'dtype': None, 'backend': self.backend} + + for field_name, setting in self.input_fields.items(): + pad_fn = setting.get('pad_fn', None) + if callable(pad_fn): + padder = pad_fn + else: + batch_field = unpack_batch.get(field_name) + padder = get_padder(batch_field=batch_field, pad_val=setting['pad_val'], + dtype=setting['dtype'], backend=setting['backend'], + field_name=field_name) + self.padders[field_name] = padder + if self.batch_data_type == 'l': + self.padders = dict(sorted(self.padders.items(), key=lambda x:int(x[0][1:]))) # sort, 这样 _0, _1 能够保持顺序 + + for key, padder in self.padders.items(): + batch = unpack_batch.get(key) + pad_batch[key] = padder(batch) + + return self.pack_batch_func(pad_batch) # 根据情况恢复成与输入一致的类型 + + def set_pad(self, field_name:str, pad_val:Union[int, float, None]=0, dtype=None, backend=None, + pad_fn:Callable=None) -> "Collator": + """ + 如果需要对某个 field 的内容进行特殊的调整,请使用这个函数。 + + :param field_name: 需要调整的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。如果该 field 在数据中没 + 有找到,则报错;如果 __getitem__ 返回的是就是整体内容,请使用 "_single" 。 + :param pad_val: 这个 field 的默认 pad 值。如果设置为 None,则表示该 field 不需要 pad , fastNLP 默认只会对可以 pad 的 + field 进行 pad,所以如果对应 field 本身就不是可以 pad 的形式,可以不需要主动设置为 None 。 + :param dtype: 对于需要 pad 的 field ,该 field 的数据 dtype 应该是什么。 + :param backend: 可选[None, 'numpy', 'torch', 'paddle', 'jittor'],分别代表,输出为 list, numpy.ndarray, torch.Tensor, + paddle.Tensor, jittor.Var 类型。若 pad_val 为 None ,该值只能为 None 或 numpy 。 + :param pad_fn: 指定当前 field 的 pad 函数,传入该函数则 pad_val, dtype, backend 等参数失效。pad_fn 的输入为当前 field 的 + batch 形式。 Collator 将自动 unbatch 数据,然后将各个 field 组成各自的 batch 。pad_func 的输入即为 field 的 batch + 形式,输出将被直接作为结果输出。 + :return: 返回 Collator 自身 + """ + self.padders.clear() # 重新生成 + + if self.batch_data_type is not None: + if self.batch_data_type == 's': + logger.debug("Set as single field mode.") + self.input_fields.clear() + elif self.batch_data_type == 'd': + assert sequence_idx_str.match(field_name) is None, f"Field name:{field_name} will be recognized as list " \ + f"index, but other field is set as dict mode." + elif self.batch_data_type == 'l': + assert sequence_idx_str.match(field_name) is not None, f"Other field is set as list mode. But the new " \ + f"field name is {field_name}" + + if field_name == '_single': + self.batch_data_type = 's' + elif sequence_idx_str.match(field_name): + self.batch_data_type = 'l' + else: + self.batch_data_type = 'd' + + if field_name in self.ignore_fields: + logger.warning(f"Field:{field_name} has been set as ignored before. It will not be ignored afterwards.") + if backend is None: + backend = self.backend + else: + assert backend in SUPPORTED_BACKENDS + + self.input_fields[field_name] = {'pad_val': pad_val, 'dtype': dtype, 'backend': backend, 'pad_fn': pad_fn} + + return self + + def set_backend(self, backend:str): + """ + 设置可以 pad 的 field 默认 pad 为什么类型的 tensor + + :param backend: 对于可以 pad 的 field,使用哪种 tensor,支持 ['torch','jittor','paddle','numpy','raw',None], + 若为 None ,则不进行 padding 。 + :return: + """ + assert backend in SUPPORTED_BACKENDS + self.padders.clear() + self.backend = backend + + def set_ignore(self, *field_names) -> "Collator": + """ + 如果有的内容不希望输出,可以在此处进行设置,被设置的 field 将在 batch 的输出中被忽略。 + Ex:: + collator.set_ignore('field1', 'field2') + + :param field_names: 需要忽略的 field 的名称。如果 Dataset 的 __getitem__ 方法返回的是 dict 类型的,则可以直接使用对应的 + field 的 key 来表示,如果是 nested 的 dict,可以使用 @@ 来连接不同层次的 key,例如 {'a': {'b': 1}} 中的使用 a@@b; + 如果 __getitem__ 返回的是 Sequence 类型的,则可以使用 '_0', '_1' 表示序列中第 0 或 1 个元素。 + :return: 返回 Collator 自身 + """ + for field_name in field_names: + if field_name in self.input_fields: + self.input_fields.pop(field_name) + logger.warning(f"Field:{field_name} has been set as input before. It will be ignored afterwards.") + self.padders.pop(field_name, None) # 如果由的话,将它的 padder 扔掉。 + self.ignore_fields.add(field_name) + + return self + + diff --git a/fastNLP/core/collators/padders/__init__.py b/fastNLP/core/collators/padders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/fastNLP/core/collators/padders/exceptions.py b/fastNLP/core/collators/padders/exceptions.py new file mode 100644 index 00000000..8b08683d --- /dev/null +++ b/fastNLP/core/collators/padders/exceptions.py @@ -0,0 +1,44 @@ +__all__ = [ + 'InconsistencyError', + 'EleDtypeUnsupportedError', + 'EleDtypeDtypeConversionError', + 'DtypeUnsupportedError', + "DtypeError" +] + + +class InconsistencyError(BaseException): + """ + 当一个 batch 中的数据存在 shape,dtype 之类的不一致时的报错。 + + """ + def __init__(self, msg, *args): + super(InconsistencyError, self).__init__(msg, *args) + + +class DtypeError(BaseException): + def __init__(self, msg, *args): + super(DtypeError, self).__init__(msg, *args) + self.msg = msg + + +class EleDtypeUnsupportedError(DtypeError): + """ + 当 batch 中的 element 的类别本身无法 pad 的时候报错。 + 例如要求 str 类型的数据进行 padding 。 + + """ + + +class EleDtypeDtypeConversionError(DtypeError): + """ + 当 batch 中的 element 的类别无法转换为 dtype 类型时报错。 + + """ + + +class DtypeUnsupportedError(DtypeError): + """ + 当当前 backend 不支持这种类型的 dtype 时报错。 + + """ \ No newline at end of file diff --git a/fastNLP/core/collators/padders/get_padder.py b/fastNLP/core/collators/padders/get_padder.py new file mode 100644 index 00000000..051a0ffc --- /dev/null +++ b/fastNLP/core/collators/padders/get_padder.py @@ -0,0 +1,193 @@ + +from typing import Dict + + + +from typing import Sequence, Any, Union, Dict +from abc import ABC + +from fastNLP.core.log import logger + + +from .padder import Padder, NullPadder +from .numpy_padder import NumpyNumberPadder, NumpySequencePadder, NumpyTensorPadder +from .torch_padder import TorchNumberPadder, TorchSequencePadder, TorchTensorPadder +from .raw_padder import RawNumberPadder, RawSequencePadder +from .exceptions import * + + +def get_padder(batch_field:Sequence[Any], pad_val, dtype, backend, field_name)->Padder: + """ + 根据 参数 与 batch_field ,返回适合于当前 batch_field 的 padder 。 + + :param batch_field: 将某 field 的内容组合成一个 batch 传入。 + :param pad_val: + :param backend: + :param dtype: + :param field_name: 方便报错的。 + :return: + """ + logger.debug(f"The content in the field:`{field_name}` is:\n", str(batch_field)) + if pad_val is None: + logger.debug(f"The pad_val for field:{field_name} is None, not padding this field.") + return NullPadder() + if backend is None: + logger.debug(f"The backend for field:{field_name} is None, not padding this field.") + return NullPadder() + + # 首先判断当前 field 是否是必须要 pad ,根据用户设置的 pad_val、dtype 等判断。 + must_pad = False + if pad_val != 0 or dtype is not None: + must_pad = True + + catalog = _get_element_shape_dtype(batch_field) # 首先获取数据的基本信息。 + + # 根据 catalog 来判定当前是否可以进行 pad 。 + # 首先检查是否所有的 key 是一样长的,表明深度是一致的 + depths = set(map(len, catalog.keys())) + num_depth = len(depths) + if num_depth != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various depths({depths}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + # 再检查所有的元素 shape 是否一致? + shape_lens = set([len(v[0]) for v in catalog.values()]) + num_shape = len(shape_lens) + if num_shape != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various shape length({shape_lens}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + # 再检查所有的元素 type 是否一致 + ele_dtypes = set([v[1] for v in catalog.values()]) + num_eletypes = len(ele_dtypes) + if num_eletypes != 1: + msg = f'Field:`{field_name}` cannot pad, since it has various types({ele_dtypes}) of data. To view more ' \ + f"information please set logger's level to DEBUG." + if must_pad: + raise InconsistencyError(msg) + logger.debug(msg) + return NullPadder() + + depth = depths.pop() + shape_len = shape_lens.pop() + ele_dtype = ele_dtypes.pop() + + # 需要由 padder 自己决定是否能够 pad 。 + try: + if depth == 1 and shape_len == 0: # 形如 [0, 1, 2] 或 [True, False, True] + if backend == 'raw': + return RawNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'numpy': + return NumpyNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchNumberPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if depth > 1 and shape_len == 0: # 形如 [[0, 1], [2]] 这种 + if backend == 'raw': + return RawSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'numpy': + return NumpySequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchSequencePadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if depth == 1 and shape_len != 0: + if backend == 'numpy': + return NumpyTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + elif backend == 'torch': + return TorchTensorPadder(ele_dtype=ele_dtype, pad_val=pad_val, dtype=dtype) + + if shape_len != 0 and depth>1: + msg = "Does not support pad tensor under nested list. If you need this, please report." + if must_pad: + raise RuntimeError(msg) + logger.debug(msg) + return NullPadder() + + except DtypeError as e: + msg = f"Fail to get padder for field:{field_name}. " + e.msg + " To view more " \ + "information please set logger's level to DEBUG." + if must_pad: + raise type(e)(msg=msg) + logger.debug(msg) + return NullPadder() + + except BaseException as e: + raise e + + return NullPadder() + + +class HasShapeDtype(ABC): + """ + 检测拥有 shape 和 dtype 属性的对象。一般就是 np.ndarray 或者各类 tensor 。 + + """ + + @classmethod + def __subclasshook__(cls, subclass: Any) -> Union[bool, Any]: + if cls is HasShapeDtype: + if hasattr(subclass, 'shape') and hasattr(subclass, 'dtype'): + return True + return False + return NotImplemented + + +def _get_element_shape_dtype(content, parent=None, catalog=None)->Dict: + """ + 获取对象的中 element 的基本信息,用于判断是否可以 padding。 + + :param content: + :param tuple parent: + :param dict catalog: 记录元素信息的 dict。其中的 index 记录的是每一个元素的 拓扑 结构。 + 例如: [1, 2, 3] -> {(0,): ((), ), (1,): ((), ), (2,): ((), )} + 例如: [1, [2, 3], 4] -> {(0,): ((), ), (1, 0): ((), ), (1, 1): ((), ), (2,): ((), )} + 例如: [[1, 2], [3], [4, 5]] -> {(0, 0): ((), ), (0, 1): ((), ), (1, 0): ((), ), (2, 0): ((), ), (2, 1): ((), )} + 例如: [torch.ones(3, 4), torch.ones(3, 4), torch.ones(3, 4)] + -> {(0,): (torch.Size([3, 4]), torch.float32), (1,): (torch.Size([3, 4]), torch.float32), (2,): (torch.Size([3, 4]), torch.float32)} + + :return: + """ + if catalog is None: + catalog = {} + + if parent is None: + parent = () + + if isinstance(content, HasShapeDtype): # 各类 tensor 或者 np.ndarray + shape = content.shape + dtype = content.dtype + catalog[parent] = (shape, dtype) + elif isinstance(content, (tuple, list)): + for i, c in enumerate(content): + _get_element_shape_dtype(c, parent=parent + (i,), catalog=catalog) + else: # 包括 int/float/bool/dict 以及 其它无法pad 的等 + catalog[parent] = ((), type(content)) # () 表示 shape 的长度为 0,后面表示其类别 + return catalog + + + + +""" +from numbers import Number + +issubclass(type(3), Number) # True +issubclass(type(3.1), Number) # True +issubclass(type('3'), Number) # False +issubclass(type(True), Number) # True +issubclass(type(np.zeros(3)[0]), Number) # True +isinstance(np.zeros(3, dtype=float).dtype, np.dtype) # True +isinstance(np.zeros(3, dtype=int).dtype, np.dtype) # True +isinstance(np.zeros(3, dtype=str).dtype, np.dtype) # True, 需要通过和来判定 +is_torch_tensor_dtype() # 可以通过isinstance(torch.zeros(3).dtype, torch.dtype) +""" + + + diff --git a/fastNLP/core/collators/padders/numpy_padder.py b/fastNLP/core/collators/padders/numpy_padder.py new file mode 100644 index 00000000..0298fd86 --- /dev/null +++ b/fastNLP/core/collators/padders/numpy_padder.py @@ -0,0 +1,72 @@ +__all__ = [ + 'NumpyNumberPadder', + 'NumpySequencePadder', +] + +from numbers import Number +from abc import ABC +from typing import Any, Union +import numpy as np + +from .padder import Padder +from .utils import get_padded_numpy_array, is_number_or_numpy_number +from .exceptions import * + + +def _get_dtype(ele_dtype, dtype, class_name): + if not is_number_or_numpy_number(ele_dtype): + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"or numpy numbers but get `{ele_dtype}`.") + + if dtype is None: + dtype = ele_dtype + else: + if not is_number_or_numpy_number(dtype): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or numpy numbers but get `{dtype}`.") + dtype = dtype + return dtype + + +class NumpyNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return np.array(batch_field, dtype=dtype) + + +class NumpySequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val) + + +class NumpyTensorPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + """ + pad 类似于 [np.array([3, 4], np.array([1])] 的 field + + :param ele_dtype: + :param pad_val: + :param dtype: + """ + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + shapes = [field.shape for field in batch_field] + max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] + array = np.full(max_shape, fill_value=pad_val, dtype=dtype) + for i, field in enumerate(batch_field): + slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) + array[slices] = field + return array + diff --git a/fastNLP/core/collators/padders/padder.py b/fastNLP/core/collators/padders/padder.py new file mode 100644 index 00000000..486574af --- /dev/null +++ b/fastNLP/core/collators/padders/padder.py @@ -0,0 +1,21 @@ + +class Padder: + def __init__(self, pad_val, dtype): + self.pad_val = pad_val + self.dtype = dtype + + def __call__(self, batch_field): + return self.pad(batch_field=batch_field, pad_val=self.pad_val, dtype=self.dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + raise NotImplementedError() + + +class NullPadder(Padder): + def __init__(self, ele_dtype=None, pad_val=None, dtype=None): + super().__init__(pad_val=pad_val, dtype=dtype) + + def __call__(self, batch_field): + # 直接返回,不调用 pad() 方法加快速度。 + return batch_field diff --git a/fastNLP/core/collators/padders/raw_padder.py b/fastNLP/core/collators/padders/raw_padder.py new file mode 100644 index 00000000..66393b40 --- /dev/null +++ b/fastNLP/core/collators/padders/raw_padder.py @@ -0,0 +1,48 @@ + + +from .padder import Padder +from .utils import get_padded_nest_list, is_number, get_padded_numpy_array +from .exceptions import * + + +def _get_dtype(ele_dtype, dtype, class_name): + if is_number(ele_dtype): + if dtype is None: + dtype = ele_dtype + elif not is_number(dtype): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` can only be None but " + f"get `{dtype}`.") + else: + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"but get `{ele_dtype}`.") + return dtype + + +class RawNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + def __call__(self, batch_field): + return batch_field + + @staticmethod + def pad(batch_field, pad_val, dtype): + raise NotImplementedError() + + +class RawSequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + """ + + :param batch_field: + :param pad_val: + :param dtype: 该参数无意义。 + :return: + """ + return get_padded_numpy_array(batch_field, dtype=dtype, pad_val=pad_val).tolist() diff --git a/fastNLP/core/collators/padders/torch_padder.py b/fastNLP/core/collators/padders/torch_padder.py new file mode 100644 index 00000000..a6768435 --- /dev/null +++ b/fastNLP/core/collators/padders/torch_padder.py @@ -0,0 +1,157 @@ + +from inspect import isclass +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + numpy_to_torch_dtype_dict = { + np.bool_: torch.bool, + np.uint8: torch.uint8, + np.int8: torch.int8, + np.int16: torch.int16, + np.int32: torch.int32, + np.int64: torch.int64, + np.float16: torch.float16, + np.float32: torch.float32, + np.float64: torch.float32, # 这里都统一为到 float32 吧,这是由于 numpy 大部分时候都默认 float64 了 + np.complex64: torch.complex64, + np.complex128: torch.complex128 + } + number_to_torch_dtype_dict = { + float: torch.float32, # 因为 torch.tensor([1], dtype=float)是torch.float64 + int: torch.int64, + bool: torch.bool + } + +from .padder import Padder +from .utils import is_number_or_numpy_number, is_number, is_numpy_number_dtype, get_shape, is_numpy_generic_class +from .exceptions import * + + +def is_torch_tensor(dtype): + if not isclass(dtype) and isinstance(dtype, torch.dtype): + return True + return False + + +def _get_dtype(ele_dtype, dtype, class_name): + if not (is_number_or_numpy_number(ele_dtype) or is_torch_tensor(ele_dtype)): + raise EleDtypeUnsupportedError(f"`{class_name}` only supports padding python numbers " + f"or numpy numbers or torch.Tensor but get `{ele_dtype}`.") + + if dtype is not None: + if not (is_torch_tensor(dtype) or is_number(dtype)): + raise DtypeUnsupportedError(f"The dtype of `{class_name}` only supports python numbers " + f"or torch.dtype but get `{dtype}`.") + dtype = number_to_torch_dtype_dict.get(dtype, dtype) + else: + if (is_number(ele_dtype) or is_torch_tensor(ele_dtype)): + ele_dtype = number_to_torch_dtype_dict.get(ele_dtype, ele_dtype) + dtype = ele_dtype + elif is_numpy_number_dtype(ele_dtype): # 存在一个转换的问题了 + dtype = numpy_to_torch_dtype_dict.get(ele_dtype.type) + elif is_numpy_generic_class(ele_dtype): + dtype = numpy_to_torch_dtype_dict.get(ele_dtype) + + return dtype + + +class TorchNumberPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + # 仅当 ele_dtype 是 python number/ numpy number 或者 tensor + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + return torch.tensor(batch_field, dtype=dtype) + + +class TorchSequencePadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + tensor = get_padded_torch_tensor(batch_field, dtype=dtype, pad_val=pad_val) + return tensor + + +class TorchTensorPadder(Padder): + def __init__(self, ele_dtype, pad_val=0, dtype=None): + """ + 目前仅支持 [torch.tensor([3, 2], torch.tensor([1])] 类似的 + + :param ele_dtype: + :param pad_val: + :param dtype: + """ + dtype = _get_dtype(ele_dtype, dtype, class_name=self.__class__.__name__) + super().__init__(pad_val=pad_val, dtype=dtype) + + @staticmethod + def pad(batch_field, pad_val, dtype): + shapes = [field.shape for field in batch_field] + max_shape = [len(batch_field)] + [max(*_) for _ in zip(*shapes)] + if isinstance(dtype, np.dtype): + print(dtype) + tensor = torch.full(max_shape, fill_value=pad_val, dtype=dtype) + for i, field in enumerate(batch_field): + slices = (i, ) + tuple(slice(0, s) for s in shapes[i]) + if isinstance(field, np.ndarray): + field = torch.from_numpy(field) + tensor[slices] = field + return tensor + + +def fill_tensor(batch_field, padded_batch, dtype): + """ + 将 batch_field 中的值填入到 tensor 中。 + + :param batch_field: 需要填充进入 array 中的内容 + :param padded_batch: 待填充的 tensor + :param dtype: 数据的类别 + + :return: + """ + if padded_batch.ndim == 2: + for i, content_i in enumerate(batch_field): + padded_batch[i, :len(content_i)] = torch.tensor(content_i, dtype=dtype) + elif padded_batch.ndim == 3: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + padded_batch[i, j, :len(content_ii)] = torch.tensor(content_ii, dtype=dtype) + elif padded_batch.ndim == 4: + try: # 应该是图像,所以直接应该就 ok 了。 + padded_batch = np.array(batch_field) + except: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + for k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = torch.tensor(content_iii, dtype=dtype) + elif padded_batch.ndim == 1: + padded_batch[:] = torch.tensor(batch_field, dtype=dtype) + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_torch_tensor(batch_field, dtype=None, pad_val=0): + """ + 例如: + [[1,2], [3]] -> torch.LongTensor([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + tensor = torch.full(shapes, dtype=dtype, fill_value=pad_val) + tensor = fill_tensor(batch_field, tensor, dtype=dtype) + return tensor diff --git a/fastNLP/core/collators/padders/torch_utils.py b/fastNLP/core/collators/padders/torch_utils.py new file mode 100644 index 00000000..a47bea0e --- /dev/null +++ b/fastNLP/core/collators/padders/torch_utils.py @@ -0,0 +1,20 @@ + + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + +def is_torch_tensor_dtype(dtype) -> bool: + """ + 返回当前 dtype 是否是 torch 的 dtype 类型 + + + :param dtype: 应该是通过类似与 torch.ones(3).dtype 方式获得结果 + :return: + """ + try: + return isinstance(dtype, torch.dtype) + except: + return False diff --git a/fastNLP/core/collators/padders/utils.py b/fastNLP/core/collators/padders/utils.py new file mode 100644 index 00000000..f6240219 --- /dev/null +++ b/fastNLP/core/collators/padders/utils.py @@ -0,0 +1,173 @@ + +from typing import Sequence, List +from numbers import Number +import re +from inspect import isclass + +import numpy as np +np_str_obj_array_pattern = re.compile(r'[SaUO]') + + +def get_shape(batch_field:List, shape=None): + """ + 给定 field 返回这个 field pad 完成之后的 shape 。 + 例如: [[1, 2, 3], [3]] -> [2, 3] + [[[1], [2], [3, 4]], [[2, 3, 4]]] -> [2, 3, 3] + + :param batch_field: list,第 0 维一般为 batch 维度。 + :param shape: 无需传入。 + :return: + """ + if shape is None: + shape = [] + if isinstance(batch_field, Sequence): + num_ele = len(batch_field) + _shape = shape + [num_ele] + try: + shapes = [] + if isinstance(batch_field[0], Sequence): + for _field in batch_field: + shapes.append(get_shape(_field, _shape)) + max_shape = [max(_) for _ in zip(*shapes)] + return max_shape + except IndexError: # 空的shape + pass + return _shape # 说明是一个空的 sequence + else: + return shape + + +def fill_array(batch_field:List, padded_batch:np.ndarray): + """ + 将 batch_field 中的值填入到 array 中。 + + :param batch_field: 需要填充进入 array 中的内容 + :param padded_batch: 待填充的 np.ndarray + :return: + """ + if padded_batch.ndim == 2: + for i, content_i in enumerate(batch_field): + padded_batch[i, :len(content_i)] = content_i + elif padded_batch.ndim == 3: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + padded_batch[i, j, :len(content_ii)] = content_ii + elif padded_batch.ndim == 4: + try: # 应该是图像,所以直接应该就 ok 了。 + padded_batch = np.array(batch_field) + except: + for i, content_i in enumerate(batch_field): + for j, content_ii in enumerate(content_i): + for k, content_iii in enumerate(content_ii): + padded_batch[i, j, k, :len(content_iii)] = content_iii + elif padded_batch.ndim == 1: + padded_batch[:] = batch_field + else: + raise RuntimeError("fastNLP does not support padding for more than 3 dimensions. If you need this, please " + "report.") + return padded_batch + + +def get_padded_numpy_array(batch_field: List, dtype=None, pad_val=0) -> np.ndarray: + """ + 例如: + [[1,2], [3]] -> np.array([[1, 2], [3, 0]]) + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param dtype: 目标类别是什么 + :param pad_val: pad 的 value + :return: + """ + shapes = get_shape(batch_field) + array = np.full(shapes, dtype=dtype, fill_value=pad_val) + array = fill_array(batch_field, array) + return array + + +def get_padded_nest_list(batch_field: List, pad_val=0) -> List: + """ + 例如: + [[1,2], [3]] -> [[1, 2], [3, 0]] + + :param batch_field: 需要 pad 的对象。需要保证应该是可以进行 pad 的。支持 1d(多为句子长度)/2d(多为文本序列)/3d(多为字符序列) + /4d(多为图片)。 + :param pad_val: pad 的 value + :return: + """ + + array = get_padded_numpy_array(batch_field, pad_val=pad_val, dtype=None).tolist() + return array + + +def is_number_or_numpy_number(dtype): + """ + 判断 dtype 是否是数字类型,或者 numpy 的数字类型。 + is_number_or_numpy_number(type(3)) # True + is_number_or_numpy_number(type(3.1)) # True + is_number_or_numpy_number(type('3')) # False + is_number_or_numpy_number(type(True)) # True + is_number_or_numpy_number(type(np.zeros(3)[0])) # True + is_number_or_numpy_number(np.zeros(3, dtype=float).dtype) # True + is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) # True + is_number_or_numpy_number(np.zeros(3, dtype=str).dtype) # False + is_number_or_numpy_number(np.array([1, [2]]).dtype) # False + + :param dtype: + :return: + """ + if is_number(dtype): + return True + else: + if isclass(dtype): + return is_numpy_generic_class(dtype) + elif isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: + return True + return False + + +def is_numpy_number_dtype(dtype): + if not isclass(dtype) and isinstance(dtype, np.dtype) and np_str_obj_array_pattern.search(dtype.str) is None: + return True + return False + + +def is_numpy_generic_class(dtype): + """ + 形如 np.int64,或者 np.zeros(1).dtype.type 的值 + + :param dtype: + :return: + """ + if isclass(dtype) and issubclass(dtype, np.generic): + return True + return False + + +def is_number(dtype): + try: + if dtype in (float, int, complex, bool) and not is_numpy_generic_class(dtype) \ + and not is_numpy_number_dtype(dtype): + return True + except: + return False + + + +if __name__ == '__main__': + # a = [[[1]], [1, 2, 3], [3]] + # a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + # b = get_padded_nest_list(a) + # print(type(b[0])) + # print(b) + # import torch + print(is_number_or_numpy_number(type(3))) # True + print(is_number_or_numpy_number(type(3.1))) # True + print(is_number_or_numpy_number(type('3'))) # False + print(is_number_or_numpy_number(type(True))) # True + print(is_number_or_numpy_number(type(np.zeros(3)[0]))) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=float).dtype)) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=int).dtype)) # True + print(is_number_or_numpy_number(np.zeros(3, dtype=str).dtype)) # False + print(is_number_or_numpy_number(np.array([1, [2]]).dtype)) # False + diff --git a/fastNLP/core/collators/utils.py b/fastNLP/core/collators/utils.py new file mode 100644 index 00000000..9a397c66 --- /dev/null +++ b/fastNLP/core/collators/utils.py @@ -0,0 +1,103 @@ +from collections import defaultdict +from functools import reduce +from typing import Sequence, Mapping, Dict + +NESTED_DICT_SEPARATOR = '@@' + + +def unpack_batch_mapping(batch:Sequence[Mapping])->Dict: + """ + 将 Sequence[Mapping] 转为 Dict 。例如 [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] -> {'a': [[1, 2], [3]], 'b': [1, 2]} + + :param batch: + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for key, value in sample.items(): + dict_batch[key].append(value) + return dict_batch + + +def unpack_batch_nested_mapping(batch:Sequence[Mapping], _parent='')->Dict: + """ + 将 nested 的 dict 中的内容展开到一个 flat dict 中 + + :param batch: + :param _parent: 内部使用 + :return: + """ + dict_batch = defaultdict(list) + if _parent != '': + _parent += NESTED_DICT_SEPARATOR + for sample in batch: + for key, value in sample.items(): + if isinstance(value, Mapping): + _dict_batch = _unpack_batch_nested_mapping(value, _parent=_parent + key) + for key, value in _dict_batch.items(): + dict_batch[key].append(value) + else: + dict_batch[_parent + key].append(value) + return dict_batch + + +def _unpack_batch_nested_mapping(value, _parent)->Dict: + _dict = {} + _parent += NESTED_DICT_SEPARATOR + for k, v in value.items(): + if isinstance(v, Mapping): + __dict = _unpack_batch_nested_mapping(v, _parent=_parent + k) + _dict.update(__dict) + else: + _dict[_parent + k] = v + return _dict + + +def pack_batch_nested_mapping(batch:Mapping) -> Dict: + """ + 需要恢复出 nested 的 dict 原来的样式 + + :param batch: + :return: + """ + dicts = [] + + for key, value in batch.items(): + keys = key.split(NESTED_DICT_SEPARATOR) + d = {keys[-1]: value} + for key in keys[:-1:][::-1]: + d = {key: d} + dicts.append(d) + return reduce(_merge_dict, dicts) + + +def _merge_dict(a, b, path=None): + "merges b into a" + if path is None: path = [] + for key in b: + if key in a: + if isinstance(a[key], dict) and isinstance(b[key], dict): + _merge_dict(a[key], b[key], path + [str(key)]) + else: + raise Exception('Conflict at %s' % '.'.join(path + [str(key)])) + else: + a[key] = b[key] + return a + + +def unpack_batch_sequence(batch:Sequence[Sequence])->Dict: + """ + 将 Sequence[Sequence] 转为 Mapping 。例如 [[[1, 2], 2], [[3], 2]] -> {'_0': [[1, 2], [3]], '_1': [1, 2]} + + :param batch: + :return: + """ + dict_batch = defaultdict(list) + for sample in batch: + for i, content in enumerate(sample): + dict_batch[f'_{i}'].append(content) + return dict_batch + + +def pack_batch_sequence(batch:Mapping)->Sequence: + return list(batch.values()) \ No newline at end of file diff --git a/tests/core/collators/__init__.py b/tests/core/collators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/collators/padders/__init__.py b/tests/core/collators/padders/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/core/collators/padders/test_get_padder.py b/tests/core/collators/padders/test_get_padder.py new file mode 100644 index 00000000..38fd4733 --- /dev/null +++ b/tests/core/collators/padders/test_get_padder.py @@ -0,0 +1,139 @@ +import pytest +import numpy as np + +from fastNLP.core.collators.padders.get_padder import get_padder, InconsistencyError, DtypeError, \ + _get_element_shape_dtype +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + + +def test_get_element_shape_dtype(): + catalog = _get_element_shape_dtype([[1], [2, 3], [3], 2]) + catalog = _get_element_shape_dtype([['1'], [2, 3]]) + catalog = _get_element_shape_dtype([['1'], [2, 3]]) + catalog = _get_element_shape_dtype([['1'], ['2', '3']]) + catalog = _get_element_shape_dtype([np.zeros(3), np.zeros((2, 1))]) + + +@pytest.mark.parametrize('backend', ['raw', None, 'numpy', 'torch', 'jittor', 'paddle']) +def test_get_padder_run(backend): + if not _NEED_IMPORT_TORCH and backend == 'torch': + pytest.skip("No torch") + if not _NEED_IMPORT_PADDLE and backend == 'paddle': + pytest.skip("No paddle") + if not _NEED_IMPORT_PADDLE and backend == 'jittor': + pytest.skip("No jittor") + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + if backend is not None: + # 不能 pad + batch_field = [[1], [2, 3], [3], 2] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') + + # 不能 pad + batch_field = [['2'], ['2'], ['2', '2']] + with pytest.raises(DtypeError) as exec_info: + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') + + batch_field = [np.zeros(3), np.zeros((3, 1))] + with pytest.raises(InconsistencyError) as exec_info: + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + padder = get_padder(batch_field, pad_val=None, backend=backend, dtype=int, field_name='test') # no pad + + batch_field = [np.zeros((3, 1)), np.zeros((4, 1))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + +def test_raw_padder(): + backend = 'raw' + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert pad_batch == batch_field + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert np.shape(pad_batch) == (3, 3) + + batch_field = [[[1]], [[2, 2], [2]], [[3], [3], [3]]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert np.shape(pad_batch) == (3, 3, 2) + + +def test_numpy_padder(): + backend = 'numpy' + target_type = np.ndarray + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert (pad_batch == np.array(batch_field)).sum()==len(batch_field) + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==3 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,3))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==9 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert np.shape(pad_batch) == (3, 3, 3) + assert (pad_batch == np.zeros(np.shape(pad_batch))).sum()==12 + + batch_field = [np.ones((3,3)), np.ones((2,3)), np.ones((1,))] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + + +def test_torch_padder(): + if not _NEED_IMPORT_TORCH: + pytest.skip("No torch.") + import torch + backend = 'torch' + target_type = torch.Tensor + batch_field = [1, 2, 3] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert (pad_batch == torch.LongTensor(batch_field)).sum()==len(batch_field) + + batch_field = [[1], [2, 2], [3, 3, 3]] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==3 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,3))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==9 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,0))] + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + pad_batch = padder(batch_field) + assert isinstance(pad_batch, target_type) + assert pad_batch.shape == (3, 3, 3) + assert (pad_batch == torch.zeros(pad_batch.shape)).sum()==12 + + batch_field = [torch.ones((3,3)), torch.ones((2,3)), torch.ones((1,))] + with pytest.raises(InconsistencyError): + padder = get_padder(batch_field, pad_val=0, backend=backend, dtype=int, field_name='test') + diff --git a/tests/core/collators/padders/test_numpy_padder.py b/tests/core/collators/padders/test_numpy_padder.py new file mode 100644 index 00000000..42665857 --- /dev/null +++ b/tests/core/collators/padders/test_numpy_padder.py @@ -0,0 +1,81 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.numpy_padder import NumpyTensorPadder, NumpySequencePadder, NumpyNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + + +class TestNumpyNumberPadder: + def test_run(self): + padder = NumpyNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + assert isinstance(a, np.ndarray) + assert (padder(a) == np.array(a)).sum() == 3 + + +class TestNumpySequencePadder: + def test_run(self): + padder = NumpySequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (2, 3) + b = np.array([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = NumpySequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpySequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + if _NEED_IMPORT_TORCH: + import torch + with pytest.raises(DtypeError): + padder = NumpySequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + + +class TestNumpyTensorPadder: + def test_run(self): + padder = NumpyTensorPadder(ele_dtype=np.zeros(3).dtype, dtype=int, pad_val=-1) + a = [np.zeros(3), np.zeros(2), np.zeros(0)] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3) + b = np.array([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 1))] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3, 2) + b = np.array([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = np.shape(a) + assert isinstance(a, np.ndarray) + assert shape == (3, 3, 2) + b = np.array([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = NumpyTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + if _NEED_IMPORT_TORCH: + import torch + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = NumpyTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + + + diff --git a/tests/core/collators/padders/test_raw_padder.py b/tests/core/collators/padders/test_raw_padder.py new file mode 100644 index 00000000..41a9de64 --- /dev/null +++ b/tests/core/collators/padders/test_raw_padder.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.raw_padder import RawNumberPadder, RawSequencePadder +from fastNLP.core.collators.padders.exceptions import DtypeError + + +class TestRawNumberPadder: + def test_run(self): + padder = RawNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + assert padder(a) == a + + +class TestRawSequencePadder: + def test_run(self): + padder = RawSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = np.shape(a) + assert shape == (2, 3) + b = np.array([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + with pytest.raises(DtypeError): + padder = RawSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = RawSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) \ No newline at end of file diff --git a/tests/core/collators/padders/test_torch_padder.py b/tests/core/collators/padders/test_torch_padder.py new file mode 100644 index 00000000..85240b3c --- /dev/null +++ b/tests/core/collators/padders/test_torch_padder.py @@ -0,0 +1,105 @@ +import numpy as np +import pytest + +from fastNLP.core.collators.padders.torch_padder import TorchTensorPadder, TorchSequencePadder, TorchNumberPadder +from fastNLP.core.collators.padders.exceptions import DtypeError +from fastNLP.envs.imports import _NEED_IMPORT_TORCH + +if _NEED_IMPORT_TORCH: + import torch + + +class TestTorchNumberPadder: + def test_run(self): + padder = TorchNumberPadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [1, 2, 3] + t_a = padder(a) + assert isinstance(t_a, torch.Tensor) + assert (t_a == torch.LongTensor(a)).sum() == 3 + + +class TestTorchSequencePadder: + def test_run(self): + padder = TorchSequencePadder(ele_dtype=int, dtype=int, pad_val=-1) + a = [[1, 2, 3], [3]] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (2, 3) + b = torch.LongTensor([[1, 2, 3], [3, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + def test_dtype_check(self): + padder = TorchSequencePadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = TorchSequencePadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = TorchSequencePadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = TorchSequencePadder(ele_dtype=np.int8, dtype=None, pad_val=-1) + a = padder([[1], [2, 322]]) + assert (a>67).sum()==0 # 因为int8的范围为-67 - 66 + padder = TorchSequencePadder(ele_dtype=np.zeros(2).dtype, dtype=None, pad_val=-1) + + + +class TestTorchTensorPadder: + def test_run(self): + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + a = [torch.zeros(3), torch.zeros(2), torch.zeros(0)] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3) + b = torch.LongTensor([[0, 0, 0], [0, 0, -1], [-1, -1, -1]]) + assert (a == b).sum().item() == shape[0]*shape[1] + + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 2))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, 0], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 1))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[0, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=int, pad_val=-1) + a = [torch.zeros((3, 2)), torch.zeros((2, 2)), torch.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.LongTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + padder = TorchTensorPadder(ele_dtype=torch.zeros(3).dtype, dtype=None, pad_val=-1) + a = [np.zeros((3, 2)), np.zeros((2, 2)), np.zeros((1, 0))] + a = padder(a) + shape = a.shape + assert isinstance(a, torch.Tensor) + assert tuple(shape) == (3, 3, 2) + b = torch.FloatTensor([[[0, 0], [0, 0], [0, 0]], + [[0, 0], [0, 0], [-1, -1]], + [[-1, -1], [-1, -1], [-1, -1]]]) + assert (a == b).sum().item() == shape[0]*shape[1]*shape[2] + + def test_dtype_check(self): + padder = TorchTensorPadder(ele_dtype=np.zeros(3, dtype=np.int8).dtype, dtype=int, pad_val=-1) + with pytest.raises(DtypeError): + padder = TorchTensorPadder(ele_dtype=str, dtype=int, pad_val=-1) + padder = TorchTensorPadder(ele_dtype=torch.long, dtype=int, pad_val=-1) + padder = TorchTensorPadder(ele_dtype=int, dtype=torch.long, pad_val=-1) + + + diff --git a/tests/core/collators/padders/test_utils.py b/tests/core/collators/padders/test_utils.py new file mode 100644 index 00000000..4cc70400 --- /dev/null +++ b/tests/core/collators/padders/test_utils.py @@ -0,0 +1,90 @@ +import pytest +import numpy as np + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH +from fastNLP.core.collators.padders.utils import get_shape, get_padded_numpy_array, \ + get_padded_nest_list, is_number_or_numpy_number, is_numpy_number_dtype, is_number + + +def test_get_shape(): + a = [[1, 2, 3], [3]] + assert get_shape(a) == [2, 3] + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + assert get_shape(a) == [2, 3, 3] + + a = [[[1], [2], [3, 4]], [[]]] + assert get_shape(a) == [2, 3, 2] + + +def test_get_padded_numpy_array(): + a = [[1, 2, 3], [3]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3) + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3, 3) + + a = [[[1], [2], [3, 4]], [[]]] + a = get_padded_numpy_array(a, dtype=int, pad_val=-1) + assert a.shape == (2, 3, 2) + + +def test_get_padded_nest_list(): + a = [[1, 2, 3], [3]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3) + + a = [[[1], [2], [3, 4]], [[2, 3, 4]]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3, 3) + + a = [[[1], [2], [3, 4]], [[]]] + a = get_padded_nest_list(a, pad_val=-1) + assert np.shape(a) == (2, 3, 2) + + +def test_is_number_or_numpy_number(): + assert is_number_or_numpy_number(type(3)) is True + assert is_number_or_numpy_number(type(3.1)) is True + assert is_number_or_numpy_number(type(True)) is True + assert is_number_or_numpy_number(type('3')) is False + assert is_number_or_numpy_number(np.zeros(3).dtype) is True + assert is_number_or_numpy_number(np.zeros(3, dtype=int).dtype) is True + assert is_number_or_numpy_number(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_number_or_numpy_number(dtype) is False + + +def test_is_number(): + assert is_number(type(3)) is True + assert is_number(type(3.1)) is True + assert is_number(type(True)) is True + assert is_number(type('3')) is False + assert is_number(np.zeros(3).dtype) is False + assert is_number(np.zeros(3, dtype=int).dtype) is False + assert is_number(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_number(dtype) is False + + +def test_is_numpy_number(): + assert is_numpy_number_dtype(type(3)) is False + assert is_numpy_number_dtype(type(3.1)) is False + assert is_numpy_number_dtype(type(True)) is False + assert is_numpy_number_dtype(type('3')) is False + assert is_numpy_number_dtype(np.zeros(3).dtype) is True + assert is_numpy_number_dtype(np.zeros(3, dtype=int).dtype) is True + assert is_numpy_number_dtype(np.zeros(3, dtype=object).dtype) is False + + if _NEED_IMPORT_TORCH: + import torch + dtype = torch.ones(3).dtype + assert is_numpy_number_dtype(dtype) is False \ No newline at end of file diff --git a/tests/core/collators/test_new_collator.py b/tests/core/collators/test_new_collator.py new file mode 100644 index 00000000..5fc82c91 --- /dev/null +++ b/tests/core/collators/test_new_collator.py @@ -0,0 +1,225 @@ + +import numpy as np +import pytest + +from fastNLP.envs.imports import _NEED_IMPORT_TORCH, _NEED_IMPORT_PADDLE, _NEED_IMPORT_JITTOR + +from fastNLP.core.collators.new_collator import Collator + + +def _assert_equal(d1, d2): + try: + if 'torch' in str(type(d1)): + if 'float64' in str(d2.dtype): + print(d2.dtype) + assert (d1 == d2).all().item() + else: + assert all(d1 == d2) + except TypeError: + assert d1 == d2 + except ValueError: + assert (d1 == d2).all() + + +def findDictDiff(d1, d2, path=""): + for k in d1: + if k in d2: + if isinstance(d1[k], dict): + findDictDiff(d1[k], d2[k], "%s -> %s" % (path, k) if path else k) + else: + _assert_equal(d1[k], d2[k]) + else: + raise RuntimeError("%s%s as key not in d2\n" % ("%s: " % path if path else "", k)) + + +def findListDiff(d1, d2): + assert len(d1)==len(d2) + for _d1, _d2 in zip(d1, d2): + if isinstance(_d1, list): + findListDiff(_d1, _d2) + else: + _assert_equal(_d1, _d2) + + +class TestCollator: + def test_run(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + collator = Collator(backend='raw') + assert raw_pad_batch == collator(dict_batch) + collator = Collator(backend='raw') + raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='numpy') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': np.array([1, 2]), 'lst_int': np.array([[1, 0], [1, 2]]), + 'nest_lst_int': np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), 'float': np.array([1.1, 2.1]), + 'lst_float': np.array([[1.1], [2.1]]), 'bool': np.array([True, False]), 'numpy': np.array([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': np.array([1, 2]), + 'b': np.array([[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='numpy') + numpy_pad_lst = [['1', '2'], [['1'], ['2', '2']], np.array([1, 2]), np.array([[1, 0], [2, 2]]), + np.array([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + np.array([1.1, 2.1]), np.array([[1.1], [2.1]]), np.array([True, False]), + np.array([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(numpy_pad_lst, collator(list_batch)) + + if _NEED_IMPORT_TORCH: + import torch + collator = Collator(backend='torch') + numpy_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': torch.LongTensor([1, 2]), + 'lst_int': torch.LongTensor([[1, 0], [1, 2]]), + 'nest_lst_int': torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + 'float': torch.FloatTensor([1.1, 2.1]), + 'lst_float': torch.FloatTensor([[1.1], [2.1]]), 'bool': torch.BoolTensor([True, False]), + 'numpy': torch.FloatTensor([[1], [0]]), + 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': torch.LongTensor([1, 2]), + 'b': torch.LongTensor( + [[1, 2], [1, 2]])}} + + findDictDiff(numpy_pad_batch, collator(dict_batch)) + collator = Collator(backend='torch') + torch_pad_lst = [['1', '2'], [['1'], ['2', '2']], torch.LongTensor([1, 2]), torch.LongTensor([[1, 0], [2, 2]]), + torch.LongTensor([[[1, 0], [0, 0]], [[1, 0], [1, 2]]]), + torch.FloatTensor([1.1, 2.1]), torch.FloatTensor([[1.1], [2.1]]), torch.BoolTensor([True, False]), + torch.LongTensor([[1, 0], [1, 1]]), [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(torch_pad_lst, collator(list_batch)) + + def test_pad(self): + dict_batch = [{ + 'str': '1', + 'lst_str': ['1'], + 'int': 1, + 'lst_int': [1], + 'nest_lst_int': [[1]], + 'float': 1.1, + 'lst_float': [1.1], + 'bool': True, + 'numpy': np.ones(1), + 'dict': {'1': '1'}, + 'set': {'1'}, + 'nested_dict': {'a': 1, 'b':[1, 2]} + }, + { + 'str': '2', + 'lst_str': ['2', '2'], + 'int': 2, + 'lst_int': [1, 2], + 'nest_lst_int': [[1], [1, 2]], + 'float': 2.1, + 'lst_float': [2.1], + 'bool': False, + 'numpy': np.zeros(1), + 'dict': {'1': '2'}, + 'set': {'2'}, + 'nested_dict': {'a': 2, 'b': [1, 2]} + } + ] + + raw_pad_batch = {'str': ['1', '2'], 'lst_str': [['1'], ['2', '2']], 'int': [1, 2], 'lst_int': [[1, 0], [1, 2]], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'a': [1, 2], 'b': [[1, 2], [1, 2]]}} + + # 测试 ignore + collator = Collator(backend='raw') + collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 测试 set_pad + collator = Collator(backend='raw') + collator.set_pad('str', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + # 测试设置 pad 值 + collator = Collator(backend='raw') + collator.set_pad('nest_lst_int', pad_val=100) + collator.set_ignore('str', 'int', 'lst_int', 'nested_dict@@a') + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': [1.1, 2.1], 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + # 设置 backend 和 type + collator.set_pad('float', pad_val=100, backend='numpy', dtype=int) + raw_pad_batch = {'lst_str': [['1'], ['2', '2']], 'nest_lst_int': [[[1, 100], [100, 100]], [[1, 100], [1, 2]]], + 'float': np.array([1, 2]), 'lst_float': [[1.1], [2.1]], 'bool': [True, False], 'numpy': [np.array([1.]), np.array([0.])], 'dict': {'1': ['1', '2']}, 'set': [{'1'}, {'2'}], 'nested_dict': {'b': [[1, 2], [1, 2]]}} + findDictDiff(raw_pad_batch, collator(dict_batch)) + + + # raw_pad_lst = [['1', '2'], [['1'], ['2', '2']], [1, 2], [[1, 0], [2, 2]], [[[1, 0], [0, 0]], [[1, 0], [1, 2]]], + # [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + # [{'1'}, {'2'}]] + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_4', pad_val=None) + raw_pad_lst = [[1, 2], [[[1]], [[1], [1, 2]]], + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + collator = Collator(backend='raw') + collator.set_pad('_0', pad_val=1) + with pytest.raises(BaseException): + collator(dict_batch) + + list_batch = [['1', ['1'], 1, [1], [[1]], 1.1, [1.1], True, np.ones(1), {'1': '1'}, {'1'}], + ['2', ['2', '2'], 2, [2, 2], [[1], [1, 2]], 2.1, [2.1], False, np.ones(2), {'2': '2'}, {'2'}]] + collator = Collator(backend='raw') + collator.set_ignore('_0', '_3', '_1') + collator.set_pad('_2', backend='numpy') + collator.set_pad('_4', backend='numpy', pad_val=100) + raw_pad_lst = [np.array([1, 2]), np.array([[[1, 100], [100, 100]], [[1, 100], [1, 2]]]), + [1.1, 2.1], [[1.1], [2.1]], [True, False], [np.ones(1), np.ones(2)], [{'1': '1'}, {'2': '2'}], + [{'1'}, {'2'}]] + findListDiff(raw_pad_lst, collator(list_batch)) + + # _single + collator = Collator() + collator.set_pad('_single') + findListDiff(list_batch, collator(list_batch)) + + + + + + + diff --git a/tests/core/collators/test_utils.py b/tests/core/collators/test_utils.py new file mode 100644 index 00000000..d56dacc6 --- /dev/null +++ b/tests/core/collators/test_utils.py @@ -0,0 +1,37 @@ + +from fastNLP.core.collators.utils import * + + +def test_unpack_batch_mapping(): + batch = [{'a': [1, 2], 'b': 1}, {'a': [3], 'b': 2}] + assert unpack_batch_mapping(batch)=={'a': [[1, 2], [3]], 'b': [1, 2]} + + +def test_unpack_batch_nested_mapping(): + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': 1}}, {'a': [3], 'b': 2, 'c': {'c': 2}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c': [1, 2]} + + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1}}}, {'a': [3], 'b': 2, 'c': {'c': {'c': 2}}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2]} + + batch = [{'a': [1, 2], 'b': 1, 'c': {'c': {'c': 1, 'd':[1, 1]}, 'd': [1]}}, + {'a': [3], 'b': 2, 'c': {'c': {'c': 2, 'd': [2, 2]}, 'd': [2, 2]}}] + assert unpack_batch_nested_mapping(batch) == {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], + 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + + +def test_pack_batch_nested_mapping(): + batch = {'a': [[1, 2], [3]], 'b': [1, 2], 'c@@c@@c': [1, 2], + 'c@@c@@d':[[1, 1], [2, 2]], 'c@@d': [[1], [2, 2]]} + new_batch = pack_batch_nested_mapping(batch) + assert new_batch == {'a': [[1, 2], [3]], 'b': [1, 2], + 'c': {'c':{'c': [1, 2], 'd': [[1, 1], [2, 2]]}, 'd':[[1], [2, 2]]}} + + +def test_unpack_batch_sequence(): + batch = [[1, 2, 3], [2, 4, 6]] + new_batch = unpack_batch_sequence(batch) + assert new_batch == {'_0': [1, 2], '_1': [2, 4], '_2': [3, 6]} + + +