From 2924e2117f564b8e85bee989624575d26a4db369 Mon Sep 17 00:00:00 2001 From: YWMditto Date: Thu, 14 Apr 2022 23:34:35 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E5=88=A0=E9=99=A4=E4=BA=86=20driver=20?= =?UTF-8?q?=E4=B8=AD=E7=9A=84=20**=5Fstep=EF=BC=8C=E4=BD=BF=E7=94=A8=20mod?= =?UTF-8?q?el=5Fcall=20=E5=92=8C=20get=5Fmodel=5Fcall=5Ffn=20=E6=9D=A5?= =?UTF-8?q?=E4=BB=A3=E6=9B=BF=EF=BC=9B=E5=88=A0=E9=99=A4=E4=BA=86=20driver?= =?UTF-8?q?=20=E4=B8=AD=E7=9A=84=E6=89=80=E6=9C=89=20dataloaders?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/checkpoint_callback.py | 2 +- fastNLP/core/controllers/evaluator.py | 49 +++---- fastNLP/core/controllers/trainer.py | 97 ++++++++------ fastNLP/core/controllers/utils/utils.py | 2 +- fastNLP/core/drivers/driver.py | 122 +++++------------- .../drivers/jittor_driver/jittor_driver.py | 14 +- .../drivers/jittor_driver/single_device.py | 10 +- fastNLP/core/drivers/paddle_driver/fleet.py | 6 +- .../drivers/paddle_driver/paddle_driver.py | 16 +-- .../drivers/paddle_driver/single_device.py | 16 +-- fastNLP/core/drivers/paddle_driver/utils.py | 14 +- fastNLP/core/drivers/torch_driver/ddp.py | 101 ++++++--------- .../drivers/torch_driver/single_device.py | 96 ++++---------- .../core/drivers/torch_driver/torch_driver.py | 19 +-- fastNLP/core/drivers/torch_driver/utils.py | 57 +------- .../torch_paddle_driver.py | 10 +- fastNLP/core/log/logger.py | 6 +- .../test_checkpoint_callback_torch.py | 16 +-- .../test_load_best_model_callback_torch.py | 2 +- .../_test_distributed_launch_torch_1.py | 2 +- .../_test_distributed_launch_torch_2.py | 2 +- .../controllers/test_trainer_event_trigger.py | 2 +- tests/core/controllers/test_trainer_fleet.py | 4 +- .../controllers/test_trainer_fleet_outside.py | 4 +- tests/core/controllers/test_trainer_paddle.py | 14 +- .../test_trainer_w_evaluator_torch.py | 16 +-- .../test_trainer_wo_evaluator_torch.py | 10 +- .../paddle_driver/test_single_device.py | 14 +- tests/helpers/models/torch_model.py | 2 +- 29 files changed, 273 insertions(+), 452 deletions(-) diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index a5be2b4c..7bbdb2fe 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -95,7 +95,7 @@ class CheckpointCallback(HasMonitorCallback): if self.save_topk is not None: super().on_after_trainer_initialized(trainer, driver) if self.save_topk is not None and trainer.evaluator is None: - logger.warning("You set `save_topk`, but `validate_dataloaders` is not set in Trainer.") + logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") def on_validate_end(self, trainer, results): if len(results) == 0: diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 5196f8c7..3013c316 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -39,7 +39,7 @@ class Evaluator: driver: Union[str, Driver] = 'single', device: Optional[Union[int, List[int], str]] = None, batch_step_fn: Optional[callable] = None, - mode: Optional[Union[str, callable]] = 'validate', # 首先尝试找 evaluate_step, 找不到 forward, callable + evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, @@ -58,14 +58,13 @@ class Evaluator: :param batch_step_fn: callable的对象,接受 (evaluator, batch) 作为参数,其中 evaluator 为 Evaluator 对象,batch 为 DataLoader 中返回的对象。一个 batch_step_fn 的例子可参考 fastNLP.core.controller.loops.evaluate_batch_loop 的 batch_step_fn 函数。 - :param mode: 可选 ["validate", "test"], 当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 - 寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, - 没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 + :param evaluate_fn: 用来控制 `Evaluator` 在评测的前向传播过程中是调用哪一个函数,例如是 `model.evaluate_step` 还是 `model.forward`; + 默认为 None,如果该值是 None,那么我们会默认使用 `evaluate_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 - 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; + 为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `evaluate_step` 和 `test_step`; :param fp16: 是否使用 fp16 。 :param verbose: 是否打印 evaluate 的结果。 :param kwargs: @@ -87,9 +86,11 @@ class Evaluator: self.model = model self.metrics = metrics - self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) + if dataloaders is None: + raise ValueError("Parameter `dataloaders` can not be None.") + self.dataloaders = dataloaders self.device = device self.verbose = verbose @@ -97,21 +98,12 @@ class Evaluator: _check_valid_parameters_number(batch_step_fn, ['trainer', 'batch'], fn_name='batch_step_fn') self.batch_step_fn = batch_step_fn - self.mode = mode - assert mode in {'validate', 'test'}, "Parameter `mode` should only be 'validate' or 'test'." - self.input_mapping = input_mapping self.output_mapping = output_mapping if not isinstance(dataloaders, dict): dataloaders = {None: dataloaders} - if mode == "validate": - self._evaluate_step = self.driver.validate_step - self.driver.set_dataloader(validate_dataloaders=dataloaders) - else: - self._evaluate_step = self.driver.test_step - self.driver.set_dataloader(test_dataloaders=dataloaders) - self.mode = mode + self.evaluate_batch_loop = EvaluateBatchLoop(batch_step_fn=batch_step_fn) self.separator = kwargs.get('separator', '#') self.model_use_eval_mode = kwargs.get('model_use_eval_mode', True) @@ -123,10 +115,14 @@ class Evaluator: self._metric_wrapper = None _ = self.metrics_wrapper # 触发检查 - assert self.driver.has_validate_dataloaders() or self.driver.has_test_dataloaders() self.driver.setup() self.driver.barrier() + if evaluate_fn is not None and not isinstance(evaluate_fn, str): + raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") + self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) + self.evaluate_fn = evaluate_fn + self.dataloaders = {} for name, dl in dataloaders.items(): # 替换为正确的 sampler dl = self.driver.set_dist_repro_dataloader(dataloader=dl, dist=self._dist_sampler, reproducible=False) @@ -136,9 +132,10 @@ class Evaluator: if self.progress_bar == 'auto': self.progress_bar = 'rich' if (sys.stdin and sys.stdin.isatty()) else 'raw' - self.driver.check_evaluator_mode(self.mode) self.driver.barrier() + self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) + def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 @@ -156,11 +153,6 @@ class Evaluator: assert isinstance(num_eval_batch_per_dl, int), "num_eval_batch_per_dl must be of int type." assert num_eval_batch_per_dl > 0 or num_eval_batch_per_dl == -1, "num_eval_batch_per_dl must be -1 or larger than 0." - if self.mode == 'validate': - assert self.driver.has_validate_dataloaders() - else: - assert self.driver.has_test_dataloaders() - metric_results = {} self.reset() evaluate_context = self.driver.get_evaluate_context() @@ -235,13 +227,6 @@ class Evaluator: f_rich_progress.destroy_task(self._rich_task_id) delattr(self, '_rich_task_id') - @property - def eval_dataloaders(self): - if self.mode == "validate": - return self.driver.validate_dataloaders - else: - return self.driver.test_dataloaders - @property def evaluate_batch_loop(self): return self._evaluate_batch_loop @@ -296,13 +281,13 @@ class Evaluator: def evaluate_step(self, batch): """ - 将 batch 传递到model中进行处理,根据当前 mode 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 + 将 batch 传递到model中进行处理,根据当前 evaluate_fn 选择进行 evaluate 还是 test 。会将返回结果经过 output_mapping 处理后再 返回。 :param batch: :return: """ - outputs = self._evaluate_step(batch) + outputs = self.driver.model_call(batch, self._evaluate_step, self._evaluate_step_signature_fn) outputs = match_and_substitute_params(self.output_mapping, outputs) return outputs diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 66e88827..5154c580 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -41,19 +41,20 @@ class Trainer(TrainerEventTrigger): optimizers, device: Optional[Union[int, List[int], str]] = "cpu", n_epochs: int = 20, - validate_dataloaders=None, + evaluate_dataloaders=None, batch_step_fn: Optional[Callable] = None, - validate_batch_step_fn: Optional[Callable] = None, - validate_mode: Union[str, callable] = 'validate', + evaluate_batch_step_fn: Optional[Callable] = None, + train_fn: Optional[str] = None, + evaluate_fn: Optional[str] = None, callbacks: Union[List[Callback], Callback, None] = None, metrics: Optional[dict] = None, - validate_every: Optional[Union[int, callable]] = -1, + evaluate_every: Optional[Union[int, Callable]] = -1, input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, model_wo_auto_param_call: bool = False, accumulation_steps: int = 1, fp16: bool = False, - monitor: Union[str, callable] = None, + monitor: Union[str, Callable] = None, larger_better: bool = True, marker: Optional[str] = None, **kwargs @@ -79,19 +80,19 @@ class Trainer(TrainerEventTrigger): 4. list(int):如果多于1个device,应当通过该种方式进行设定;当 `device` 为一个 list 时,我们默认使用 `TorchDDPDriver`; 5. None: 为None则不对模型进行任何处理; :param n_epochs: 训练总共的 epoch 的数量,默认为 20; - :param validate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 + :param evaluate_dataloaders: 验证数据集,其可以是单独的一个数据集,也可以是多个数据集;当为多个数据集时,注意其必须是 Dict;默认 为 None; :param batch_step_fn: 用来替换 `TrainBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的两个参数必须为 `trainer` 和 `batch`;默认为 None; - :param validate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 + :param evaluate_batch_step_fn: 用来替换 'Evaluator' 中的 `EvaluateBatchLoop` 中的 `batch_step_fn` 函数,注意该函数的 两个参数必须为 `evaluator` 和 `batch`;默认为 None; - :param validate_mode: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,其值应当为以下之一:["validate", "test"]; - 默认为 "validate";当为 "validate" 时将首先尝试寻找 model 是否有 validate_step 函数,没有的话则尝试 - 寻找 test_step 函数,都没找到则使用 model 的前向运算函数。当为 "test" 是将首先尝试寻找 model 是否有 test_step 函数, - 没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 + :param train_fn: 用来控制 `Trainer` 在训练的前向传播过程中是调用哪一个函数,例如是 `model.train_step` 还是 `model.forward`; + 默认为 None,如果该值是 None,那么我们会默认使用 `train_step` 当做前向传播的函数,如果在模型中没有找到该方法,则使用 `model.forward` 函数; + :param evaluate_fn: 用来控制 `Trainer` 中内置的 `Evaluator` 的模式,应当为 None 或者一个字符串;其使用方式和 train_fn 类似; + 注意该参数我们会直接传给 Trainer 中内置的 Evaluator(如果不为 None); :param callbacks: 训练当中触发的 callback 类,该参数应当为一个列表,其中的每一个元素都应当继承 `Callback` 类; :param metrics: 应当为一个字典,其中 key 表示 monitor,例如 {"acc1": AccMetric(), "acc2": AccMetric()}; - :param validate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; + :param evaluate_every: 可以为负数、正数或者函数;为负数时表示每隔几个 epoch validate 一次;为正数则表示每隔几个 batch validate 一次; 为函数时表示用户自己传入的用于控制 Trainer 中的 validate 的频率的函数,该函数的应该接受当前 trainer 对象作为参数,并 返回一个 bool 值,返回为 True 说明需要进行 validate ;将在每个 batch 结束后调用该函数判断是否需要 validate 。 :param input_mapping: 应当为一个字典或者一个函数,表示在当前 step 拿到一个 batch 的训练数据后,应当做怎样的映射处理;如果其是 @@ -105,10 +106,10 @@ class Trainer(TrainerEventTrigger): 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 - 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `validate_step` 和 `test_step`; + 为 True,那么我们会将 batch 直接透传给模型。注意该参数应用于 `train_step`, `evaluate_step` 和 `test_step`; :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; - :param monitor: 当存在 validate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 + :param monitor: 当存在 evaluate_dataloaders 时,默认的 monitor metric 的名字。传入的 callback 如果有 monitor 参数且没有 在 callback 初始化设定的,将采取这个值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 的那个作为 monitor 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 @@ -136,10 +137,15 @@ class Trainer(TrainerEventTrigger): else: self.driver_name = driver.__class__.__name__ self.device = device + if train_dataloader is None: + raise ValueError("Parameter `train_dataloader` can not be None.") + self.train_dataloader = train_dataloader + self.evaluate_dataloaders = evaluate_dataloaders self.optimizers = optimizers self.fp16 = fp16 self.input_mapping = input_mapping self.output_mapping = output_mapping + self.evaluate_fn = evaluate_fn self.batch_step_fn = batch_step_fn if batch_step_fn is not None: @@ -168,13 +174,13 @@ class Trainer(TrainerEventTrigger): optimizers=optimizers, device=device, n_epochs=n_epochs, - validate_dataloaders=validate_dataloaders, + validate_dataloaders=evaluate_dataloaders, batch_step_fn=batch_step_fn, - validate_batch_step_fn=validate_batch_step_fn, - validate_mode=validate_mode, + validate_batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, callbacks=callbacks, metrics=metrics, - validate_every=validate_every, + validate_every=evaluate_every, input_mapping=input_mapping, output_mapping=output_mapping, model_wo_auto_param_call=model_wo_auto_param_call, @@ -185,9 +191,6 @@ class Trainer(TrainerEventTrigger): ) self.driver.set_optimizers(optimizers=optimizers) - if train_dataloader is not None: - self.driver.set_dataloader(train_dataloader=train_dataloader) - # 初始化 callback manager; self.callback_manager = CallbackManager(callbacks, kwargs.get('progress_bar', 'auto')) # 添加所有的函数式 callbacks; @@ -213,25 +216,25 @@ class Trainer(TrainerEventTrigger): _dist_sampler = None """ 设置内部的 Evaluator """ - if metrics is None and validate_dataloaders is not None: + if metrics is None and evaluate_dataloaders is not None: raise ValueError("You have set 'validate_dataloader' but forget to set 'metrics'.") - if metrics is not None and validate_dataloaders is None: + if metrics is not None and evaluate_dataloaders is None: raise ValueError("You have set 'metrics' but forget to set 'validate_dataloader'.") self.evaluator = None self.monitor = monitor self.larger_better = larger_better - if metrics is not None and validate_dataloaders is not None: - check_validate_every(validate_every) + if metrics is not None and evaluate_dataloaders is not None: + check_validate_every(evaluate_every) self.evaluator = Evaluator( model=model, - dataloaders=validate_dataloaders, + dataloaders=evaluate_dataloaders, metrics=metrics, driver=self.driver, device=device, - batch_step_fn=validate_batch_step_fn, - mode=validate_mode, + batch_step_fn=evaluate_batch_step_fn, + evaluate_fn=evaluate_fn, input_mapping=input_mapping, output_mapping=output_mapping, fp16=fp16, @@ -241,12 +244,16 @@ class Trainer(TrainerEventTrigger): ) self.metrics = metrics - self.validate_every = validate_every + self.validate_every = evaluate_every - assert self.driver.has_train_dataloader() self.driver.setup() self.driver.barrier() + if train_fn is not None and not isinstance(train_fn, str): + raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") + self._train_step, self._train_step_signature_fn = self.driver.get_model_call_fn("train_step" if train_fn is None else train_fn) + self.train_fn = train_fn + self.dataloader = self.train_dataloader self.driver.set_deterministic_dataloader(self.dataloader) @@ -257,6 +264,7 @@ class Trainer(TrainerEventTrigger): self.on_after_trainer_initialized(self.driver) self.driver.barrier() + self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, @@ -273,6 +281,7 @@ class Trainer(TrainerEventTrigger): 行。默认如果非 distributed 的 driver 会 catch ,distributed 不会 catch (无法 catch ) :return: """ + if catch_KeyboardInterrupt is None: catch_KeyboardInterrupt = not self.driver.is_distributed() else: @@ -343,7 +352,8 @@ class Trainer(TrainerEventTrigger): _validate_res: dict = validate_fn() trainer.on_validate_end(_validate_res) - self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) + if self.evaluator is not None: + self.run_evaluate = partial(_validate_fn, self, partial(self.evaluator.run, num_eval_batch_per_dl)) def step_validate(self): """ @@ -489,11 +499,6 @@ class Trainer(TrainerEventTrigger): self.has_checked_train_batch_loop = True """ Trainer 需要的一些 property """ - - @property - def train_dataloader(self): - return self.driver.train_dataloader - @property def driver(self): return self._driver @@ -684,7 +689,7 @@ class Trainer(TrainerEventTrigger): def train_step(self, batch): with self.driver.auto_cast(): - outputs = self.driver.train_step(batch) + outputs = self.driver.model_call(batch, self._train_step, self._train_step_signature_fn) outputs = match_and_substitute_params(self.output_mapping, outputs) return outputs @@ -814,6 +819,24 @@ class Trainer(TrainerEventTrigger): def data_device(self): return self.driver.data_device + """ dataloader property """ + + @property + def train_dataloader(self): + return self._train_dataloader + + @train_dataloader.setter + def train_dataloader(self, train_dataloader): + self._train_dataloader = train_dataloader + + @property + def evaluate_dataloaders(self): + return self._evaluate_dataloaders + + @evaluate_dataloaders.setter + def evaluate_dataloaders(self, evaluate_dataloaders): + self._evaluate_dataloaders = evaluate_dataloaders + diff --git a/fastNLP/core/controllers/utils/utils.py b/fastNLP/core/controllers/utils/utils.py index 6e0824a1..3d25fd6b 100644 --- a/fastNLP/core/controllers/utils/utils.py +++ b/fastNLP/core/controllers/utils/utils.py @@ -128,6 +128,6 @@ class _TruncatedDataLoader: def check_validate_every(validate_every): if not callable(validate_every) and (not isinstance(validate_every, int) or validate_every == 0): - raise ValueError("Parameter 'validate_every' should be set to 'int' type and either < 0 or > 0.") + raise ValueError("Parameter 'evaluate_every' should be set to 'int' type and either < 0 or > 0.") if callable(validate_every): _check_valid_parameters_number(validate_every, expected_params=['trainer']) diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index 019e6fad..b1015b47 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -1,7 +1,7 @@ import os import signal import sys -from typing import Any, Sequence, List, Optional, Callable, Dict, Union +from typing import Any, Sequence, List, Optional, Callable, Dict, Union, Tuple from abc import ABC, abstractmethod from datetime import datetime from pathlib import Path @@ -79,41 +79,44 @@ class Driver(ABC): """ @abstractmethod - def train_step(self, batch): + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: """ - 通过调用模型自带的 `train_step` 或者 `forward` 方法来实现训练的前向过程; - 如果检测到用户模型实现了 train_step + 通过调用 `fn` 来实现训练时的前向传播过程; + 注意 Trainer 和 Evaluator 会调用该函数来实现网络的前向传播过程,其中传入该函数的参数 `fn` 是函数 `get_model_call_fn` 所返回的 + 函数; :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; - :return: 返回由模型的 `train_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); + :param fn: 由 Trainer 传入的用于网络前向传播一次的函数; + :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call + 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; + :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); """ - raise NotImplementedError("Each specific driver should implemented its own `train_step` function.") + raise NotImplementedError("Each specific driver should implemented its own `model_call` function.") - def validate_step(self, batch): + @abstractmethod + def get_model_call_fn(self, fn: str) -> Tuple: """ - 通过调用模型自带的 `validate_step` 或者 `forward` 方法来实现模型评测的前向过程; + 该函数会接受 Trainer 的 train_fn 或者 Evaluator 的 evaluate_fn,返回一个实际用于调用 driver.model_call 时传入的函数参数; + 该函数会在 Trainer 和 Evaluator 在 driver.setup 函数之后调用; - :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; - :return: 返回由模型的 `validate_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); - """ - raise NotImplementedError("Each specific driver should implemented its own `validate_step` function.") + 之所以设置该函数的目的在于希望将具体的 model_call function 从 driver 中抽离出来,然后将其附着在 Trainer 或者 Evaluator 身上; + 这样是因为在新版的设计中,使用 model 的哪种方法来进行 `train step` 或者 `evaluate step` 是通过额外的参数 `train_fn` 和 + `evaluate_fn` 来确定的,而二者又分别是通过 Trainer 和 Evaluator 来控制的;因此不能将确定具体的 `train step fn` 和 + `evaluate step fn` 的逻辑放在每一个 driver 的初始化的时候(因此在 Trainer 初始化第一个 driver 时,Evaluator 还没有初始化,但是 + `evaluate step fn` 的确定却需要 Evaluator 的初始化),因此我们将这一逻辑抽象到这一函数当中; - def test_step(self, batch): - """ - 通过调用模型自带的 `test_step` 或者 `forward` 方法来实现模型评测的前向过程; + 这一函数应当通过参数 `fn` 来判断应当返回的实际的调用的函数,具体逻辑如下所示: + 1. 如果 fn == "train_step" or "evaluate_step",那么对传入的模型进行检测,如果模型没有定义方法 `fn`,则默认调用模型的 `forward` + 函数,然后给出 warning; + 2. 如果 fn 是其他字符串,那么如果模型没有定义方法 `fn` 则直接报错; + 注意不同的 driver 需要做额外的检测处理,例如在 DDPDriver 中,当传入的模型本身就是 DistributedDataParallel 中,我们只能调用模型的 + forward 函数,因此需要额外的 warning;这一点特别需要注意的问题在于 driver 自己在 setup 时也会对模型进行改变(DDPDriver),因此 + 可能需要额外标记最初传入 driver 的模型是哪种形式的; - :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; - :return: 返回由模型的 `test_step` 或者 `forward` 方法返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); + :param fn: 应当为一个字符串,该函数通过该字符串判断要返回模型的哪种方法; + :return: 返回一个元组,包含两个函数,用于在调用 driver.model_call 时传入; """ - raise NotImplementedError("Each specific driver should implemented its own `test_step` function.") - - def check_evaluator_mode(self, mode: str): - r""" - 因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; - 因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 - 我们应当提醒用户这一行为; - """ - raise NotImplementedError("Each specific driver should implemented its own `check_evaluator_mode` function.") + raise NotImplementedError("Each specific driver should implemented its own `get_model_call_fn` function.") @property def model(self): @@ -123,59 +126,8 @@ class Driver(ABC): def model(self, model): self._model = model - @property - def train_dataloader(self): - return self._train_dataloader - - @train_dataloader.setter - def train_dataloader(self, train_dataloader: Any): - self._train_dataloader = train_dataloader - - @property - def validate_dataloaders(self): - return self._validate_dataloaders - - @validate_dataloaders.setter - def validate_dataloaders(self, validate_dataloaders: Any): - self._validate_dataloaders = validate_dataloaders - - @property - def test_dataloaders(self): - return self._test_dataloaders - - @test_dataloaders.setter - def test_dataloaders(self, test_dataloaders: Any): - self._test_dataloaders = test_dataloaders - - @property - def predict_dataloaders(self): - return self._predict_dataloaders - - @predict_dataloaders.setter - def predict_dataloaders(self, predict_dataloaders: Any): - self._predict_dataloaders = predict_dataloaders - - def set_dataloader(self, **kwargs): - r""" - 设置训练或者检验过程中的数据;用于在 trainer 和 evaluator 中将数据 dataloader 挂载到每一个具体的 driver 上; - - :param kwargs: 输入的数据,应当使用 'keyword-only' 的参数进行设置; - """ - if "train_dataloader" in kwargs: - self.train_dataloader = kwargs["train_dataloader"] - self._check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) - if "validate_dataloaders" in kwargs: - self.validate_dataloaders = kwargs["validate_dataloaders"] - self._check_dataloader_legality(self.validate_dataloaders, "validate_dataloaders", is_train=False) - if "test_dataloaders" in kwargs: - self.test_dataloaders = kwargs["test_dataloaders"] - self._check_dataloader_legality(self.test_dataloaders, "test_dataloaders", is_train=False) - if "predict_dataloaders" in kwargs: - self.predict_dataloaders = kwargs["predict_dataloaders"] - self._check_dataloader_legality(self.predict_dataloaders, "predict_dataloaders", is_train=False) - @staticmethod - def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): r""" 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 行为是不相同的; @@ -183,19 +135,7 @@ class Driver(ABC): :param dataloader: 需要检测的输入的 `dataloader`; :param dataloader_name: """ - raise NotImplementedError("Each specific driver should implemented its own `_check_dataloader_legality` function.") - - def has_train_dataloader(self): - return "_train_dataloader" in self.__dict__ - - def has_validate_dataloaders(self): - return "_validate_dataloaders" in self.__dict__ - - def has_test_dataloaders(self): - return "_test_dataloaders" in self.__dict__ - - def has_predict_dataloaders(self): - return "_predict_dataloaders" in self.__dict__ + raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") @property def optimizers(self) -> List: diff --git a/fastNLP/core/drivers/jittor_driver/jittor_driver.py b/fastNLP/core/drivers/jittor_driver/jittor_driver.py index 411fdf69..84e3f002 100644 --- a/fastNLP/core/drivers/jittor_driver/jittor_driver.py +++ b/fastNLP/core/drivers/jittor_driver/jittor_driver.py @@ -39,7 +39,7 @@ class JittorDriver(Driver): self.grad_scaler = _grad_scaler() @staticmethod - def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): # 在fastnlp中实现了JittorDataLoader # TODO: 是否允许传入Dataset? if is_train: @@ -64,18 +64,18 @@ class JittorDriver(Driver): def check_evaluator_mode(self, mode: str): model = self.unwrap_model() if mode == "validate": - if not hasattr(model, "validate_step"): + if not hasattr(model, "evaluate_step"): if hasattr(model, "test_step"): logger.warning_once( - "Your model does not have 'validate_step' method but has 'test_step' method, but you" - "are using 'mode=validate', we are going to use 'test_step' to substitute for" - "'validate_step'.") + "Your model does not have 'evaluate_step' method but has 'test_step' method, but you" + "are using 'evaluate_fn=validate', we are going to use 'test_step' to substitute for" + "'evaluate_step'.") else: if not hasattr(model, "test_step"): - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" - "are using 'mode=test', we are going to use 'validate_step' to substitute for" + "are using 'evaluate_fn=test', we are going to use 'evaluate_step' to substitute for" "'test_step'.") def save_model(self, filepath: str, only_state_dict: bool = False, model_save_fn: Optional[Callable]=None): diff --git a/fastNLP/core/drivers/jittor_driver/single_device.py b/fastNLP/core/drivers/jittor_driver/single_device.py index 4c99a2f5..84bdb28b 100644 --- a/fastNLP/core/drivers/jittor_driver/single_device.py +++ b/fastNLP/core/drivers/jittor_driver/single_device.py @@ -35,8 +35,8 @@ class JittorSingleDriver(JittorDriver): model = self.unwrap_model() self._train_signature_fn = model.execute - if hasattr(self.model, "validate_step"): - self._validate_step = self.model.validate_step + if hasattr(self.model, "evaluate_step"): + self._validate_step = self.model.evaluate_step self._validate_signature_fn = None elif hasattr(self.model, "test_step"): self._validate_step = self.model.test_step @@ -49,9 +49,9 @@ class JittorSingleDriver(JittorDriver): if hasattr(self.model, "test_step"): self._test_step = self.model.test_step self._test_signature_fn = None - elif hasattr(self.model, "validate_step"): - self._test_step = self.model.validate_step - self._test_signature_fn = self.model.validate_step + elif hasattr(self.model, "evaluate_step"): + self._test_step = self.model.evaluate_step + self._test_signature_fn = self.model.evaluate_step else: self._test_step = self.model model = self.unwrap_model() diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 3f29e4dd..1b29fd07 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -118,11 +118,11 @@ class PaddleFleetDriver(PaddleDriver): " call `forward` function instead of `train_step` and you should note that.") self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `validate_step` method, which we can not call actually, " - "we will call `forward` function instead of `validate_step` and you should note that.") + "model also implements the `evaluate_step` method, which we can not call actually, " + "we will call `forward` function instead of `evaluate_step` and you should note that.") self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) if hasattr(model, "test_step"): diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 22f28743..de0af7f2 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -72,7 +72,7 @@ class PaddleDriver(Driver): optimizer.clear_grad() @staticmethod - def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): r""" 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性。 要求传入的 dataloader 必须为 `paddle.io.DataLoader` 或包含该类型的字典。 @@ -117,24 +117,24 @@ class PaddleDriver(Driver): def check_evaluator_mode(self, mode: str): r""" - 因为我们在具体的 driver 的 validate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; - 因此如果用户的 evaluator mode 是 validate,但是传入的 model 却没有实现 validate_step 函数,而是实现了 test_step 函数,那么 + 因为我们在具体的 driver 的 evaluate_step 和 test_step 的逻辑是如果模型没有实现本函数,那么就去检测模型是否实现了另一个函数; + 因此如果用户的 evaluator evaluate_fn 是 validate,但是传入的 model 却没有实现 evaluate_step 函数,而是实现了 test_step 函数,那么 我们应当提醒用户这一行为; """ model = self.unwrap_model() if mode == "validate": - if not hasattr(model, "validate_step"): + if not hasattr(model, "evaluate_step"): if hasattr(model, "test_step"): logger.warning( - "Your model does not have 'validate_step' method but has 'test_step' method, but you" + "Your model does not have 'evaluate_step' method but has 'test_step' method, but you" "are using 'Evaluator.validate', we are going to use 'test_step' to substitute for" - "'validate_step'.") + "'evaluate_step'.") else: if not hasattr(model, "test_step"): - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): logger.warning_once("Your model does not have 'test_step' method but has 'validate' method, but you" - "are using 'Evaluator.test', we are going to use 'validate_step' to substitute for" + "are using 'Evaluator.test', we are going to use 'evaluate_step' to substitute for" "'test_step'.") @staticmethod diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 796f4809..f11cb49a 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -50,10 +50,10 @@ class PaddleSingleDriver(PaddleDriver): self._train_step = self.model self._train_signature_fn = model.forward - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): logger.warning("Notice your model is a `paddle.DataParallel` model. And your model also " - "implements the `validate_step` method, which we can not call actually, we " - "will call `forward` function instead of `validate_step` and you should note that.") + "implements the `evaluate_step` method, which we can not call actually, we " + "will call `forward` function instead of `evaluate_step` and you should note that.") self._validate_step = self.model self._validate_signature_fn = model.forward @@ -73,8 +73,8 @@ class PaddleSingleDriver(PaddleDriver): model = self.unwrap_model() self._train_signature_fn = model.forward - if hasattr(self.model, "validate_step"): - self._validate_step = self.model.validate_step + if hasattr(self.model, "evaluate_step"): + self._validate_step = self.model.evaluate_step self._validate_signature_fn = None elif hasattr(self.model, "test_step"): self._validate_step = self.model.test_step @@ -87,9 +87,9 @@ class PaddleSingleDriver(PaddleDriver): if hasattr(self.model, "test_step"): self._test_step = self.model.test_step self._test_signature_fn = None - elif hasattr(self.model, "validate_step"): - self._test_step = self.model.validate_step - self._test_signature_fn = self.model.validate_step + elif hasattr(self.model, "evaluate_step"): + self._test_step = self.model.evaluate_step + self._test_signature_fn = self.model.evaluate_step else: self._test_step = self.model model = self.unwrap_model() diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 895ec703..2f74cc65 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -108,11 +108,11 @@ class _FleetWrappingModel(Layer): self._train_step = self.model self._train_signature_fn = model.forward - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " - "model also implements the `validate_step` method, which we can not call actually, " - "we will call `forward` function instead of `validate_step` and you should note that.") + "model also implements the `evaluate_step` method, which we can not call actually, " + "we will call `forward` function instead of `evaluate_step` and you should note that.") self._validate_step = self.model self._validate_signature_fn = model.forward @@ -131,7 +131,7 @@ class _FleetWrappingModel(Layer): self._train_step = model self._train_signature_fn = model.forward - if hasattr(model, "validate_step"): + if hasattr(model, "evaluate_step"): self._validate_step = model.validate_step self._validate_signature_fn = None elif hasattr(model, "test_step"): @@ -144,7 +144,7 @@ class _FleetWrappingModel(Layer): if hasattr(model, "test_step"): self._test_step = model.test_step self._test_signature_fn = None - elif hasattr(model, "validate_step"): + elif hasattr(model, "evaluate_step"): self._test_step = model.validate_step self._test_signature_fn = None else: @@ -172,9 +172,9 @@ class _FleetWrappingModel(Layer): else: return self._test_step(batch) elif forward_state == ForwardState.PREDICT: - raise NotImplementedError("'PREDICT' mode has not been implemented.") + raise NotImplementedError("'PREDICT' evaluate_fn has not been implemented.") else: - raise NotImplementedError("You should direct a concrete mode.") + raise NotImplementedError("You should direct a concrete evaluate_fn.") class DummyGradScaler: """ diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index c673fe62..55af3367 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -4,7 +4,7 @@ import __main__ import socket import numpy as np from time import sleep -from typing import List, Optional, Union, Dict +from typing import List, Optional, Union, Dict, Tuple, Callable from functools import partial from fastNLP.envs.imports import _NEED_IMPORT_TORCH @@ -21,8 +21,6 @@ __all__ = [ from .torch_driver import TorchDriver from fastNLP.core.drivers.torch_driver.utils import ( _DDPWrappingModel, - ForwardState, - _MODE_PARAMETER, reset_seed, replace_sampler, replace_batch_sampler @@ -158,10 +156,10 @@ class TorchDDPDriver(TorchDriver): ———————————————————————————————————————————————————————————————————————————————————————————————————————— 3. _DDPWrappingModel 的作用; - 因为我们即需要调用模型的 `train_step`、`validate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 + 因为我们即需要调用模型的 `train_step`、`evaluate_step`、`test_step` 方法,又需要通过 `DistributedDataParallel` 的 forward 函数来帮助我们同步各个设备上的梯度,因此我们需要先将模型单独包裹一层,然后在 forward 的时候,其先经过 `DistributedDataParallel` 的 forward 方法,然后再经过 `_DDPWrappingModel` 的 forward 方法,我们会在该 forward 函数中进行判断,确定调用的是模型自己的 - forward 函数,还是 `train_step`、`validate_step`、`test_step` 方法。 + forward 函数,还是 `train_step`、`evaluate_step`、`test_step` 方法。 4. 当某一个进程出现 exception 后,`TorchDDPDriver` 的处理; @@ -204,37 +202,6 @@ class TorchDDPDriver(TorchDriver): # 我们就直接将 model_device 置为 None; self.model_device = None - def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(step_fn, batch, signature_fn=signature_fn) - else: - return step_fn(batch) - - model = model.module - if hasattr(model, "train_step"): - logger.warning( - "Notice your model is a `DistributedDataParallel` model. And your " - "model also implements the `train_step` method, which we can not call actually, we will" - " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - # self._train_signature_fn = model.forward - - if hasattr(model, "validate_step"): - logger.warning( - "Notice your model is a `DistributedDataParallel` model. And your " - "model also implements the `validate_step` method, which we can not call actually, " - "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - # self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - logger.warning( - "Notice your model is a `DistributedDataParallel` model. And your " - "model also implements the `test_step` method, which we can not call actually, we will" - " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) - # self._test_signature_fn = model.forward - # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) if isinstance(self._data_device, int): @@ -253,7 +220,6 @@ class TorchDDPDriver(TorchDriver): # world_size 表示的就是全局的显卡的数量; self.world_size = None # int(os.environ.get("WORLD_SIZE")) len(self.parallel_device) self.global_rank = 0 - self._configured = False # 防止重复调用 configure_ddp() 函数使用的 self._ddp_kwargs = kwargs.get("torch_ddp_kwargs", {}) check_user_specific_params(self._ddp_kwargs, DistributedDataParallel.__init__) @@ -268,8 +234,8 @@ class TorchDDPDriver(TorchDriver): os.makedirs(name=self.output_from_new_proc, exist_ok=True) self.output_from_new_proc = os.path.abspath(self.output_from_new_proc) - # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; - self._has_setup = False + self._has_setup = False # 设置这一参数是因为 evaluator 中也会进行 setup 操作,但是显然是不需要的也不应该的; + self._has_ddpwrapped = False # 判断传入的模型是否经过 _has_ddpwrapped 包裹; def setup(self): if self._has_setup: @@ -341,24 +307,16 @@ class TorchDDPDriver(TorchDriver): self._pids = self.tensor_to_numeric(self._pids) def configure_ddp(self): - if not self._configured and not isinstance(self.model, DistributedDataParallel): + if not isinstance(self.model, DistributedDataParallel): self.model = DistributedDataParallel( # 注意这里的 self.model_device 是 `torch.device` type,因此 self.model_device.index; _DDPWrappingModel(self.model), device_ids=[self.model_device.index], **self._ddp_kwargs ) - - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) - - self._configured = True + self._has_ddpwrapped = True def open_subprocess(self): if self.local_rank == 0: - # self._consensus_file = Path(tempfile.mkstemp()[1]) - # self._consensus_file.unlink() - # Script called as `python a/b/c.py` if __main__.__spec__ is None: # pragma: no-cover # pull out the commands used to run the script and resolve the abs file path @@ -432,18 +390,39 @@ class TorchDDPDriver(TorchDriver): return self._data_device return self.model_device - def train_step(self, batch): - # 注意这里的 self.model 已经是 'fastNLP.drivers.utils._DDPWrappingModel'; - # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TRAIN}) - return self._train_step(batch) - - def validate_step(self, batch): - # return self.model(batch, **{_MODE_PARAMETER: ForwardState.VALIDATE}) - return self._validate_step(batch) - - def test_step(self, batch): - # return self.model(batch, **{_MODE_PARAMETER: ForwardState.TEST}) - return self._test_step(batch) + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: + if self._has_ddpwrapped: + return self.model(batch, fastnlp_fn=fn, fastnlp_signature_fn=signature_fn, + wo_auto_param_call=self.wo_auto_param_call) + else: + if isinstance(batch, Dict) and not self.wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) + else: + return fn(batch) + + def get_model_call_fn(self, fn: str) -> Tuple: + model = self.unwrap_model() + if self._has_ddpwrapped: + if hasattr(model, fn): + fn = getattr(model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + return fn, None + elif fn in {"train_step", "evaluate_step"}: + return model, model.forward + else: + raise RuntimeError(f"There is no `{fn}` method in your model.") + else: + if hasattr(model, fn): + logger.warning("Notice your model is a `DistributedDataParallel` model. And your model also implements " + f"the `{fn}` method, which we can not call actually, we will" + " call `forward` function instead of `train_step` and you should note that.") + elif fn not in {"train_step", "evaluate_step"}: + raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " + "`DistributedDataParallel` model, which means that we will only call model.forward " + "function when we are in forward propagation.") + + return self.model, model.forward def set_dist_repro_dataloader(self, dataloader, dist: Optional[Union[str, ReproducibleSampler, ReproducibleBatchSampler]]=None, reproducible: bool = False): diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index eda438d7..b16bb309 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -1,5 +1,5 @@ import os -from typing import Dict, Union +from typing import Dict, Union, Callable, Tuple, Optional from fastNLP.envs.imports import _NEED_IMPORT_TORCH if _NEED_IMPORT_TORCH: import torch @@ -42,84 +42,40 @@ class TorchSingleDriver(TorchDriver): self.global_rank = 0 self.world_size = 1 - if isinstance(model, DataParallel): - model = self.unwrap_model() - if hasattr(model, "train_step"): - logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " - "model also implements the `train_step` method, which we can not call actually, we will" - " call `forward` function instead of `train_step` and you should note that.") - self._train_step = self.model - self._train_signature_fn = model.forward - - if hasattr(model, "validate_step"): - logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " - "model also implements the `validate_step` method, which we can not call actually, " - "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = self.model - self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - logger.warning("Notice your model is a `DataParallel` or `DistributedDataParallel` model. And your " - "model also implements the `test_step` method, which we can not call actually, we will" - " call `forward` function instead of `test_step` and you should note that.") - self._test_step = self.model - self._test_signature_fn = model.forward - else: - if hasattr(self.model, "train_step"): - self._train_step = self.model.train_step - self._train_signature_fn = None - else: - self._train_step = self.model - # 输入的模型是 `DataParallel` 或者 `DistributedDataParallel`,我们需要保证其 signature_fn 是正确的; - model = self.unwrap_model() - self._train_signature_fn = model.forward - - if hasattr(self.model, "validate_step"): - self._validate_step = self.model.validate_step - self._validate_signature_fn = None - elif hasattr(self.model, "test_step"): - self._validate_step = self.model.test_step - self._validate_signature_fn = self.model.test_step - else: - self._validate_step = self.model - model = self.unwrap_model() - self._validate_signature_fn = model.forward - - if hasattr(self.model, "test_step"): - self._test_step = self.model.test_step - self._test_signature_fn = None - elif hasattr(self.model, "validate_step"): - self._test_step = self.model.validate_step - self._test_signature_fn = self.model.validate_step - else: - self._test_step = self.model - model = self.unwrap_model() - self._test_signature_fn = model.forward - def setup(self): if self.model_device is not None: self.model.to(self.model_device) - def train_step(self, batch) -> Dict: - # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; + def model_call(self, batch, fn: Callable, signature_fn: Optional[Callable]) -> Dict: if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - return self._train_step(batch) + return fn(batch) - def validate_step(self, batch) -> Dict: - # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 - # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; - if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) - else: - return self._validate_step(batch) + def get_model_call_fn(self, fn: str) -> Tuple: + if isinstance(self.model, DataParallel): + model = self.unwrap_model() + if hasattr(model, fn): + logger.warning("Notice your model is a `DataParallel` model. And your model also implements the " + f"`{fn}` method, which we can not call actually, we will" + " call `forward` function instead of `train_step` and you should note that.") - def test_step(self, batch) -> Dict: - if isinstance(batch, Dict) and not self.wo_auto_param_call: - return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) + elif fn not in {"train_step", "evaluate_step"}: + raise RuntimeError(f"There is no `{fn}` method in your model. And also notice that your model is a " + f"`DataParallel` model, which means that we will only call model.forward function " + f"when we are in forward propagation.") + + return self.model, model.forward else: - return self._test_step(batch) + if hasattr(self.model, fn): + fn = getattr(self.model, fn) + if not callable(fn): + raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + return fn, None + elif fn in {"train_step", "evaluate_step"}: + return self.model, self.model.forward + else: + raise RuntimeError(f"There is no `{fn}` method in your model.") def set_dist_repro_dataloader(self, dataloader, dist: Union[str, ReproducibleBatchSampler, ReproducibleSampler]=None, reproducible: bool = False): diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index f1e33d5e..b7aebec8 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -81,7 +81,7 @@ class TorchDriver(Driver): self.grad_scaler.update() @staticmethod - def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): if is_train: if not isinstance(dataloader, DataLoader): raise ValueError(f"Parameter `{dataloader_name}` should be 'DataLoader' type, not {type(dataloader)}.") @@ -108,23 +108,6 @@ class TorchDriver(Driver): raise ValueError(f"Each optimizer of parameter `optimizers` should be 'Optimizer' type, " f"not {type(each_optimizer)}.") - def check_evaluator_mode(self, mode: str): - model = self.unwrap_model() - if mode == "validate": - if not hasattr(model, "validate_step"): - if hasattr(model, "test_step"): - logger.warning_once( - "Your model does not have 'validate_step' method but has 'test_step' method, but you" - "are using 'mode=validate', we are going to use 'test_step' to substitute for" - "'validate_step'.") - - else: - if not hasattr(model, "test_step"): - if hasattr(model, "validate_step"): - logger.warning("Your model does not have 'test_step' method but has 'validate' method, but you" - "are using 'mode=test', we are going to use 'validate_step' to substitute for" - "'test_step'.") - @staticmethod def tensor_to_numeric(tensor, reduce=None): if tensor is None: diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index cdc6cea9..941e4445 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -90,14 +90,11 @@ class ForwardState(IntEnum): PREDICT = 3 -_MODE_PARAMETER = "_forward_state" - - class _DDPWrappingModel(Module): """ 该函数用于 DDP 训练时处理用户自己定制的 train_step 等函数; 之所以要使用这一额外的包裹模型,是因为在使用 DDP 时,必须使用 DistributedDataParallel 的 forward 函数才能实现正常的运行; - 另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'validate_step' 等; + 另一方面,我们要求用户在使用我们的框架时,需要针对不用的模式实现不同的处理函数,例如 'train_step', 'evaluate_step' 等; 然而,当使用 DistributedDataParallel 包裹 model 后,模型看不见其除了 forward 之外的方法;并且当我们尝试在训练过程中主动提取 `model = model.module`,这同样会导致错误,会使得每一个gpu上的模型参数不同; @@ -109,60 +106,18 @@ class _DDPWrappingModel(Module): super(_DDPWrappingModel, self).__init__() self.model = model - if hasattr(model, "train_step"): - self._train_step = model.train_step - self._train_signature_fn = None - else: - self._train_step = model - self._train_signature_fn = model.forward - - if hasattr(model, "validate_step"): - self._validate_step = model.validate_step - self._validate_signature_fn = None - elif hasattr(model, "test_step"): - self._validate_step = model.test_step - self._validate_signature_fn = None - else: - self._validate_step = model - self._validate_signature_fn = model.forward - - if hasattr(model, "test_step"): - self._test_step = model.test_step - self._test_signature_fn = None - elif hasattr(model, "validate_step"): - self._test_step = model.validate_step - self._test_signature_fn = None - else: - self._test_step = model - self._test_signature_fn = model.forward - def forward(self, batch, **kwargs) -> Dict: """ pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; """ - - forward_state = kwargs.pop(_MODE_PARAMETER) + fn = kwargs.pop("fastnlp_fn") + signature_fn = kwargs.pop("fastnlp_signature_fn") wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) - else: - return self._train_step(batch) - elif forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) - else: - return self._validate_step(batch) - elif forward_state == ForwardState.TEST: - if isinstance(batch, Dict) and not wo_auto_param_call: - return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) - else: - return self._test_step(batch) - elif forward_state == ForwardState.PREDICT: - raise NotImplementedError("'PREDICT' mode has not been implemented.") + if isinstance(batch, Dict) and not wo_auto_param_call: + return auto_param_call(fn, batch, signature_fn=signature_fn) else: - raise NotImplementedError("You should direct a concrete mode.") + return fn(batch) class DummyGradScaler: diff --git a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py index 59fde526..2f4526ac 100644 --- a/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py +++ b/fastNLP/core/drivers/torch_paddle_driver/torch_paddle_driver.py @@ -55,8 +55,8 @@ class TorchPaddleDriver(Driver): self._train_step = self.model self._train_signature_fn = self.model.forward - if hasattr(self.model, "validate_step"): - self._validate_step = self.model.validate_step + if hasattr(self.model, "evaluate_step"): + self._validate_step = self.model.evaluate_step self._validate_signature_fn = None elif hasattr(self.model, "test_step"): self._validate_step = self.model.test_step @@ -68,8 +68,8 @@ class TorchPaddleDriver(Driver): if hasattr(self.model, "test_step"): self._test_step = self.model.test_step self._test_signature_fn = None - elif hasattr(self.model, "validate_step"): - self._test_step = self.model.validate_step + elif hasattr(self.model, "evaluate_step"): + self._test_step = self.model.evaluate_step self._test_signature_fn = self.model.forward else: self._test_step = self.model @@ -81,7 +81,7 @@ class TorchPaddleDriver(Driver): self.model.to(self.model_device) @staticmethod - def _check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): + def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): if is_train: if not isinstance(dataloader, (TorchDataLoader, PaddleDataLoader)): raise ValueError(f"Parameter `{dataloader_name}` should be 'torch.util.data.DataLoader' or `paddle.io.dataloader` type, not {type(dataloader)}.") diff --git a/fastNLP/core/log/logger.py b/fastNLP/core/log/logger.py index 004bfb16..086089ea 100644 --- a/fastNLP/core/log/logger.py +++ b/fastNLP/core/log/logger.py @@ -211,9 +211,9 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] raise TypeError("Parameter `remove_other_handlers` can only be `bool` type.") if not isinstance(mode, str): - raise TypeError("Parameter 'mode' can only be `str` type.") + raise TypeError("Parameter 'evaluate_fn' can only be `str` type.") if mode not in {"w", "a"}: - raise ValueError("Parameter `mode` can only be one of these values: ('w', 'a').") + raise ValueError("Parameter `evaluate_fn` can only be one of these values: ('w', 'a').") for h in _logger.handlers: if isinstance(h, logging.FileHandler): @@ -230,7 +230,7 @@ def _add_file_handler(_logger: logging.Logger, path: Optional[Union[str, Path]] dirname = os.path.abspath(os.path.dirname(path)) os.makedirs(dirname, exist_ok=True) - # 这里只要检测到是分布式训练,我们就将 mode 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 + # 这里只要检测到是分布式训练,我们就将 evaluate_fn 改为 "a";这样会导致的一个问题在于,如果第二次训练也是分布式训练,logger记录的log不会重新 # 覆盖掉原文件,而是会接着上一次的 log 继续添加; # 这样做主要是为了解决这样的情形所导致的问题:在分布式训练中,进程 1 比 进程 0 先运行到这里,然后使得进程 0 将进程 1 的 log 覆盖掉; if is_cur_env_distributed():# and int(os.environ.get(FASTNLP_GLOBAL_RANK, 0)) != 0: diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index fe0a3582..98987181 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -124,7 +124,7 @@ def test_model_checkpoint_callback_1( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -204,7 +204,7 @@ def test_model_checkpoint_callback_1( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -264,7 +264,7 @@ def test_model_checkpoint_callback_2( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -302,7 +302,7 @@ def test_model_checkpoint_callback_2( device=4, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -370,7 +370,7 @@ def test_trainer_checkpoint_callback_1( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -448,7 +448,7 @@ def test_trainer_checkpoint_callback_1( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -626,7 +626,7 @@ def test_trainer_checkpoint_callback_2( train_dataloader=test_bert_dataloader_train, optimizers=test_bert_optimizers, - validate_dataloaders=test_bert_dataloader_validate, + evaluate_dataloaders=test_bert_dataloader_validate, input_mapping=bert_input_mapping, output_mapping=bert_output_mapping, metrics={"acc": acc}, @@ -700,7 +700,7 @@ def test_trainer_checkpoint_callback_2( train_dataloader=test_bert_dataloader_train, optimizers=test_bert_optimizers, - validate_dataloaders=test_bert_dataloader_validate, + evaluate_dataloaders=test_bert_dataloader_validate, input_mapping=bert_input_mapping, output_mapping=bert_output_mapping, metrics={"acc": acc}, diff --git a/tests/core/callbacks/test_load_best_model_callback_torch.py b/tests/core/callbacks/test_load_best_model_callback_torch.py index 1d82361f..91ddc2da 100644 --- a/tests/core/callbacks/test_load_best_model_callback_torch.py +++ b/tests/core/callbacks/test_load_best_model_callback_torch.py @@ -92,7 +92,7 @@ def test_load_best_model_callback( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=lambda output: output if ('loss' in output) else {'pred':output['preds'], 'target': output['target']}, metrics=model_and_optimizers.metrics, diff --git a/tests/core/controllers/_test_distributed_launch_torch_1.py b/tests/core/controllers/_test_distributed_launch_torch_1.py index 56261922..f9b3312c 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_1.py +++ b/tests/core/controllers/_test_distributed_launch_torch_1.py @@ -89,7 +89,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( device=None, optimizers=optimizers, train_dataloader=train_dataloader, - validate_dataloaders=validate_dataloaders, + evaluate_dataloaders=validate_dataloaders, metrics=metrics, n_epochs=2, diff --git a/tests/core/controllers/_test_distributed_launch_torch_2.py b/tests/core/controllers/_test_distributed_launch_torch_2.py index 13d88248..c61b6d48 100644 --- a/tests/core/controllers/_test_distributed_launch_torch_2.py +++ b/tests/core/controllers/_test_distributed_launch_torch_2.py @@ -77,7 +77,7 @@ def _test_trainer_torch_with_evaluator_fp16_accumulation_steps( device=None, optimizers=optimizers, train_dataloader=train_dataloader, - validate_dataloaders=validate_dataloaders, + evaluate_dataloaders=validate_dataloaders, metrics=metrics, n_epochs=2, diff --git a/tests/core/controllers/test_trainer_event_trigger.py b/tests/core/controllers/test_trainer_event_trigger.py index 6ee0054f..2a3c60dc 100644 --- a/tests/core/controllers/test_trainer_event_trigger.py +++ b/tests/core/controllers/test_trainer_event_trigger.py @@ -82,7 +82,7 @@ def test_trainer_event_trigger( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, diff --git a/tests/core/controllers/test_trainer_fleet.py b/tests/core/controllers/test_trainer_fleet.py index a294ad1f..46201c67 100644 --- a/tests/core/controllers/test_trainer_fleet.py +++ b/tests/core/controllers/test_trainer_fleet.py @@ -64,8 +64,8 @@ def test_trainer_fleet( device=device, optimizers=optimizers, train_dataloader=train_dataloader, - validate_dataloaders=validate_dataloaders, - validate_every=validate_every, + evaluate_dataloaders=validate_dataloaders, + evaluate_every=validate_every, input_mapping=None, output_mapping=None, metrics=metrics, diff --git a/tests/core/controllers/test_trainer_fleet_outside.py b/tests/core/controllers/test_trainer_fleet_outside.py index d461e211..9f58d599 100644 --- a/tests/core/controllers/test_trainer_fleet_outside.py +++ b/tests/core/controllers/test_trainer_fleet_outside.py @@ -70,8 +70,8 @@ def test_trainer_fleet( device=device, optimizers=optimizers, train_dataloader=train_dataloader, - validate_dataloaders=validate_dataloaders, - validate_every=validate_every, + evaluate_dataloaders=validate_dataloaders, + evaluate_every=validate_every, input_mapping=None, output_mapping=None, metrics=metrics, diff --git a/tests/core/controllers/test_trainer_paddle.py b/tests/core/controllers/test_trainer_paddle.py index ed102c99..0f8657b2 100644 --- a/tests/core/controllers/test_trainer_paddle.py +++ b/tests/core/controllers/test_trainer_paddle.py @@ -68,13 +68,13 @@ class TrainerParameters: # shuffle=True # ) # val_dataloader = DataLoader( -# dataset=PaddleDataset_MNIST(mode="test"), +# dataset=PaddleDataset_MNIST(evaluate_fn="test"), # batch_size=MNISTTrainPaddleConfig.batch_size, # shuffle=True # ) # trainer_params.train_dataloader = train_dataloader -# trainer_params.validate_dataloaders = val_dataloader -# trainer_params.validate_every = MNISTTrainPaddleConfig.validate_every +# trainer_params.evaluate_dataloaders = val_dataloader +# trainer_params.evaluate_every = MNISTTrainPaddleConfig.evaluate_every # trainer_params.metrics = {"acc": Accuracy()} # return trainer_params @@ -121,8 +121,8 @@ def test_trainer_paddle( device=device, optimizers=trainer_params.optimizers, train_dataloader=trainer_params.train_dataloader, - validate_dataloaders=trainer_params.validate_dataloaders, - validate_every=trainer_params.validate_every, + evaluate_dataloaders=trainer_params.validate_dataloaders, + evaluate_every=trainer_params.validate_every, input_mapping=trainer_params.input_mapping, output_mapping=trainer_params.output_mapping, metrics=trainer_params.metrics, @@ -139,8 +139,8 @@ def test_trainer_paddle( device=device, optimizers=trainer_params.optimizers, train_dataloader=trainer_params.train_dataloader, - validate_dataloaders=trainer_params.validate_dataloaders, - validate_every=trainer_params.validate_every, + evaluate_dataloaders=trainer_params.validate_dataloaders, + evaluate_every=trainer_params.validate_every, input_mapping=trainer_params.input_mapping, output_mapping=trainer_params.output_mapping, metrics=trainer_params.metrics, diff --git a/tests/core/controllers/test_trainer_w_evaluator_torch.py b/tests/core/controllers/test_trainer_w_evaluator_torch.py index 70d03f8c..2f7b522c 100644 --- a/tests/core/controllers/test_trainer_w_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_w_evaluator_torch.py @@ -98,16 +98,16 @@ def model_and_optimizers(request): # 测试一下普通的情况; -@pytest.mark.parametrize("driver,device", [("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) +@pytest.mark.parametrize("driver,device", [("torch", "cpu"), ("torch", 1), ("torch", [0, 1])]) # ("torch", "cpu"), ("torch", 1), ("torch", [0, 1]) @pytest.mark.parametrize("callbacks", [[RecordMetricCallback(monitor="acc", metric_threshold=0.2, larger_better=True)]]) -@pytest.mark.parametrize("validate_every", [-3]) +@pytest.mark.parametrize("evaluate_every", [-3, -1, 100]) @magic_argv_env_context def test_trainer_torch_with_evaluator( model_and_optimizers: TrainerParameters, driver, device, callbacks, - validate_every, + evaluate_every, n_epochs=10, ): trainer = Trainer( @@ -116,11 +116,11 @@ def test_trainer_torch_with_evaluator( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, - validate_every=validate_every, + evaluate_every=evaluate_every, n_epochs=n_epochs, callbacks=callbacks, @@ -152,7 +152,7 @@ def test_trainer_torch_with_evaluator_fp16_accumulation_steps( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -193,14 +193,14 @@ def test_trainer_validate_every( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, n_epochs=n_epochs, output_from_new_proc="all", - validate_every=validate_every + evaluate_every=validate_every ) trainer.run() diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 82fa3af0..8aa76eb2 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -91,7 +91,7 @@ def test_trainer_torch_without_evaluator( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -126,7 +126,7 @@ def test_trainer_torch_without_evaluator_fp16_accumulation_steps( device=device, optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -163,7 +163,7 @@ def test_trainer_torch_without_evaluator_accumulation_steps( optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -202,7 +202,7 @@ def test_trainer_output_from_new_proc( optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, @@ -267,7 +267,7 @@ def test_trainer_on_exception( optimizers=model_and_optimizers.optimizers, train_dataloader=model_and_optimizers.train_dataloader, - validate_dataloaders=model_and_optimizers.validate_dataloaders, + evaluate_dataloaders=model_and_optimizers.validate_dataloaders, input_mapping=model_and_optimizers.input_mapping, output_mapping=model_and_optimizers.output_mapping, metrics=model_and_optimizers.metrics, diff --git a/tests/core/drivers/paddle_driver/test_single_device.py b/tests/core/drivers/paddle_driver/test_single_device.py index b9681121..ec5bb846 100644 --- a/tests/core/drivers/paddle_driver/test_single_device.py +++ b/tests/core/drivers/paddle_driver/test_single_device.py @@ -401,12 +401,12 @@ class TestPaddleDriverFunctions: 测试is_train参数为True时,_check_dataloader_legality函数的表现 """ dataloader = paddle.io.DataLoader(PaddleNormalDataset()) - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) # batch_size 和 batch_sampler 均为 None 的情形 dataloader = paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) with pytest.raises(ValueError): - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) # 创建torch的dataloader dataloader = torch.utils.data.DataLoader( @@ -414,7 +414,7 @@ class TestPaddleDriverFunctions: batch_size=32, shuffle=True ) with pytest.raises(ValueError): - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", True) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", True) def test_check_dataloader_legality_in_test(self): """ @@ -425,7 +425,7 @@ class TestPaddleDriverFunctions: "train": paddle.io.DataLoader(PaddleNormalDataset()), "test":paddle.io.DataLoader(PaddleNormalDataset()) } - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) # batch_size 和 batch_sampler 均为 None 的情形 dataloader = { @@ -433,12 +433,12 @@ class TestPaddleDriverFunctions: "test":paddle.io.DataLoader(PaddleNormalDataset(), batch_size=None) } with pytest.raises(ValueError): - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) # 传入的不是dict,应该报错 dataloader = paddle.io.DataLoader(PaddleNormalDataset()) with pytest.raises(ValueError): - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) # 创建torch的dataloader train_loader = torch.utils.data.DataLoader( @@ -451,7 +451,7 @@ class TestPaddleDriverFunctions: ) dataloader = {"train": train_loader, "test": test_loader} with pytest.raises(ValueError): - PaddleSingleDriver._check_dataloader_legality(dataloader, "dataloader", False) + PaddleSingleDriver.check_dataloader_legality(dataloader, "dataloader", False) def test_tensor_to_numeric(self): """ diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index b949a26f..236ffda5 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -28,7 +28,7 @@ class TorchNormalModel_Classification_1(nn.Module): x = self(x) return {"loss": self.loss_fn(x, y)} - def validate_step(self, x, y): + def evaluate_step(self, x, y): """ 如果不加参数 y,那么应该在 trainer 中设置 output_mapping = {"y": "target"}; """ From a4b2e0fac57aff040b0cfe2f18487d72d8c9971a Mon Sep 17 00:00:00 2001 From: yh_cc Date: Fri, 15 Apr 2022 00:01:29 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E8=8B=A5=E5=B9=B2bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/callbacks/callback.py | 21 +++++- fastNLP/core/callbacks/checkpoint_callback.py | 68 +++++++++++++------ fastNLP/core/callbacks/early_stop_callback.py | 14 ++-- .../callbacks/load_best_model_callback.py | 5 +- fastNLP/core/callbacks/progress_callback.py | 14 ++-- fastNLP/core/controllers/evaluator.py | 9 ++- .../controllers/loops/train_batch_loop.py | 6 +- fastNLP/core/controllers/trainer.py | 10 ++- fastNLP/core/controllers/utils/state.py | 6 +- fastNLP/core/drivers/driver.py | 13 +--- fastNLP/core/drivers/torch_driver/ddp.py | 2 +- .../core/drivers/torch_driver/torch_driver.py | 1 + tests/core/utils/test_utils.py | 1 + tests/helpers/datasets/normal_data.py | 16 ++++- 14 files changed, 121 insertions(+), 65 deletions(-) diff --git a/fastNLP/core/callbacks/callback.py b/fastNLP/core/callbacks/callback.py index 902421c8..b37eda63 100644 --- a/fastNLP/core/callbacks/callback.py +++ b/fastNLP/core/callbacks/callback.py @@ -390,4 +390,23 @@ class HasMonitorCallback(Callback): if (self.larger_better and monitor_value1 > monitor_value2) or \ (not self.larger_better and monitor_value1 < monitor_value2): better = True - return better \ No newline at end of file + return better + + @property + def monitor_name(self): + """ + 返回 monitor 的名字,如果 monitor 是个 callable 的函数,则返回该函数的名称。 + + :return: + """ + if callable(self.monitor): + try: + monitor_name = self.monitor.__qualname__ + except: + monitor_name = self.monitor.__name__ + elif self.monitor is None: + return None + else: + # 这里是能是monitor,而不能是real_monitor,因为用户再次运行的时候real_monitor被初始化为monitor了 + monitor_name = str(self.monitor) + return monitor_name diff --git a/fastNLP/core/callbacks/checkpoint_callback.py b/fastNLP/core/callbacks/checkpoint_callback.py index 7bbdb2fe..d2d97294 100644 --- a/fastNLP/core/callbacks/checkpoint_callback.py +++ b/fastNLP/core/callbacks/checkpoint_callback.py @@ -19,11 +19,11 @@ from fastNLP.core.utils import synchronize_safe_rm, synchronize_mkdir class CheckpointCallback(HasMonitorCallback): def __init__( self, - monitor, + monitor:Optional[Union[str, Callable]]=None, save_folder: Optional[Union[str, Path]] = None, save_every_n_epochs: Optional[int] = None, save_every_n_batches: Optional[int] = None, - save_last: bool = True, + save_last: bool = False, save_topk: Optional[int] = None, save_on_exception: Optional[Union[BaseException, Sequence[BaseException]]] = None, larger_better: bool = True, @@ -31,12 +31,32 @@ class CheckpointCallback(HasMonitorCallback): model_save_fn: Optional[Callable] = None, **kwargs, ): + """ + 请使用 ModelCheckpointCallback 与 TrainerCheckpointCallback 。 + + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 + 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 + :param save_every_n_epochs: 多少个 epoch 保存一次。 + :param save_every_n_batches: 多少个 batch 保存一次。 + :param save_last: 如果为 True ,将在每次 epoch 运行结束都保存一次,会覆盖之前的保存。 + :param save_topk: 保存 monitor 结果 topK 个。 + :param save_on_exception: 在出异常信息时,是否保存。传入需要捕获的异常的类。 + :param larger_better: monitor 的值是否时越大越好。 + :param only_state_dict: 保存模型时是否只保存 state_dict 。当 model_save_fn 不为 None 时,该参数无效。 + :param model_save_fn: 个性化的保存函数,当触发保存操作时,就调用这个函数,这个函数应当接受一个文件夹作为参数,不返回任何东西。 + 如果传入了 model_save_fn 函数,fastNLP 将不再进行模型相关的保存。在多卡场景下,我们只在 rank 0 上会运行该函数。 + :param kwargs: + """ super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=save_topk is not None) if save_folder is None: logger.warning( "Parameter `path` is None, and we will use the current work directory to find and load your model.") save_folder = Path.cwd() + save_folder = Path(save_folder) if not save_folder.exists(): raise NotADirectoryError(f"Path '{save_folder.absolute()}' is not existed!") elif save_folder.is_file(): @@ -71,7 +91,7 @@ class CheckpointCallback(HasMonitorCallback): else: save_on_exception = [] - self.save_folder = Path(save_folder) + self.save_folder = save_folder self.save_every_n_epochs = save_every_n_epochs self.save_every_n_batches = save_every_n_batches self.save_last = save_last @@ -88,8 +108,7 @@ class CheckpointCallback(HasMonitorCallback): # 注意这里应当保证只有进程 0 在执行这个操作,因为当用户使用 python -m torch.distributed.launch 来拉起进程的时候, # FASTNLP_LAUNCH_TIME 在每一个进程上的值是不一样的; self.timestamp_path = self.save_folder.joinpath(os.environ[FASTNLP_LAUNCH_TIME]) - # 我们只需要保证这个创建文件夹的操作只在进程 0 上进行即可;因为后续的实际的保存操作,其它进程实际并不会去执行; - synchronize_mkdir(self.timestamp_path) + # 该 folder 只在保存真的要发生的时候再创建。 def on_after_trainer_initialized(self, trainer, driver): if self.save_topk is not None: @@ -98,8 +117,6 @@ class CheckpointCallback(HasMonitorCallback): logger.warning("You set `save_topk`, but `evaluate_dataloaders` is not set in Trainer.") def on_validate_end(self, trainer, results): - if len(results) == 0: - return self._save_topk(trainer, results) def on_train_epoch_end(self, trainer: "fastNLP.Trainer"): @@ -136,16 +153,17 @@ class CheckpointCallback(HasMonitorCallback): states['timestamp_path'] = str(self.timestamp_path.absolute()) states['_topk_model'] = deepcopy(self._topk_model) states['save_topk'] = 0 if self.save_topk is None else self.save_topk - states['_real_monitor'] = self._real_monitor + if isinstance(self._real_monitor, str): + states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states: Optional[Dict]): timestamp_path = states['timestamp_path'] if not os.path.exists(timestamp_path): - logger.info(f"The resuming save folder {timestamp_path} is not exists, will checkpoint save to " + logger.info(f"The resuming checkpoint folder {timestamp_path} is not exists, will checkpoint save to " f" {self.timestamp_path.absolute()}.") else: - logger.info(f"Resume to save in path: {timestamp_path}.") + logger.info(f"Resume to checkpoint in path: {timestamp_path}.") self.timestamp_path = Path(timestamp_path) _topk_model = states['_topk_model'] save_topk = None if int(states['save_topk']) == 0 else int(states['save_topk']) @@ -153,7 +171,8 @@ class CheckpointCallback(HasMonitorCallback): assert self.save_topk == save_topk, f"The checkpoint set save_topk={save_topk}, while this callback set it " \ f"as {save_topk}." self._topk_model.update(self._topk_model) - self._real_monitor = states["real_monitor"] + + self._real_monitor = states["_real_monitor"] def _save_topk(self, trainer: "fastNLP.Trainer", results: Dict): """ @@ -231,9 +250,9 @@ class ModelCheckpointCallback(CheckpointCallback): model_save_fn 为 None ,则以上每个 folder 中,将生成 fastnlp_model.pkl.tar 文件。 若 model_save_fn 不为 None,则 fastNLP 将 folder 绝对路径传递给该函数,fastNLP 不在该 folder 下创建任何文件。 - :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), - 返回一个 float 值作为 monitor 的结果。 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -249,6 +268,11 @@ class ModelCheckpointCallback(CheckpointCallback): """ @property def save_fn_name(self): + """ + 调用 Trainer 中的哪个函数。 + + :return: + """ return 'save_model' @property @@ -257,7 +281,7 @@ class ModelCheckpointCallback(CheckpointCallback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"model_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + return f"model_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" @property def folder_prefix(self): @@ -279,9 +303,9 @@ class TrainerCheckpointCallback(CheckpointCallback): model_save_fn 为 None ,则以上每个 folder 中,将生成两个文件:fastnlp_trainer.pkl.tar 以及 fastnlp_model.pkl.tar 。 若 model_save_fn 不为 None,则 fastNLP 只会在每个 folder 下生成 fastnlp_trainer.pkl.tar 文件。 - :param monitor: 监控的 metric 的名称。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 - 的那个作为 monitor 。如果为 None 将尝试从 Trainer 中获取该值。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型), - 返回一个 float 值作为 monitor 的结果。 + :param monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param save_folder: 保存的文件夹,fastNLP 将在该文件下以时间戳创建子文件夹,并在里面保存。因此不同次运行可以将被保存到不同的 时间戳文件夹中。如果为 None ,默认使用当前文件夹。 :param save_every_n_epochs: 多少个 epoch 保存一次。 @@ -297,6 +321,11 @@ class TrainerCheckpointCallback(CheckpointCallback): """ @property def save_fn_name(self): + """ + 调用 Trainer 中的哪个函数。 + + :return: + """ return 'save' @property @@ -305,7 +334,8 @@ class TrainerCheckpointCallback(CheckpointCallback): 通过该值决定两个 CheckpointCallback 实例是否可以共用断点重训的状态; :return: """ - return f"trainer_checkpoint#monitor-{self.monitor}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" + + return f"trainer_checkpoint#monitor-{self.monitor_name}#topK-{self.save_topk}#only_state_dict-{self.only_state_dict}" @property def folder_prefix(self): diff --git a/fastNLP/core/callbacks/early_stop_callback.py b/fastNLP/core/callbacks/early_stop_callback.py index b1842d43..c679ad7e 100644 --- a/fastNLP/core/callbacks/early_stop_callback.py +++ b/fastNLP/core/callbacks/early_stop_callback.py @@ -12,8 +12,9 @@ class EarlyStopCallback(HasMonitorCallback): def __init__(self, monitor:Union[str, Callable]=None, larger_better:bool=True, patience:int=10): """ - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 - evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: monitor 的值是否是越大越好。 :param patience: 多少次 validate 不没有提升就停止。 """ @@ -46,17 +47,20 @@ class EarlyStopCallback(HasMonitorCallback): states = { 'patience': self.patience, 'wait': self.wait, - 'monitor': self.monitor, 'monitor_value': self.monitor_value } + if not callable(self._real_monitor): + states['_real_monitor'] = self._real_monitor return states def on_load_checkpoint(self, trainer, states): self.patience = states['patience'] self.wait = states['wait'] - self.monitor = states['monitor'] self.monitor_value = float(states['monitor_value']) + if '_real_monitor' in states: + self._real_monitor = states['_real_monitor'] + @property def callback_name(self): - return f'EarlyStopCallback#monitor-{self.monitor}#patience-{self.patience}' + return f'EarlyStopCallback#monitor-{self.monitor_name}#patience-{self.patience}' diff --git a/fastNLP/core/callbacks/load_best_model_callback.py b/fastNLP/core/callbacks/load_best_model_callback.py index e068326b..09f85d01 100644 --- a/fastNLP/core/callbacks/load_best_model_callback.py +++ b/fastNLP/core/callbacks/load_best_model_callback.py @@ -21,8 +21,9 @@ class LoadBestModelCallback(HasMonitorCallback): """ 保存最佳的 monitor 值最佳的模型,并在训练结束的时候重新加载模型。仅在训练正常结束的时候才能加载最好的模型。 - :param str monitor: 监控的 metric 值。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 - evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param str monitor: 监控的 metric 值。如果在 evaluation 结果中没有找到完全一致的名称,将使用 最短公共字符串算法 找到最匹配 + 的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor 。也可以传入一个函数,接受参数为 evaluation 的结 + 果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 该 metric 值是否是越大越好。 :param save_folder: 保存的文件夹,如果为空,则保存在内存中。不为空,则保存一份权重到文件中,当为多机训练,且本值不为空时,请确保 不同的机器均可访问当该路径。当 model_save_fn 不为 None 时该值一定不能为空。 diff --git a/fastNLP/core/callbacks/progress_callback.py b/fastNLP/core/callbacks/progress_callback.py index 67176387..f351f204 100644 --- a/fastNLP/core/callbacks/progress_callback.py +++ b/fastNLP/core/callbacks/progress_callback.py @@ -44,10 +44,11 @@ class RichCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。如果为 None ,会尝试使用 trainer 中设置的 monitor 。 - 也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 - :param larger_better: 是否是monitor的结果越大越好。 - :param format_json: 是否format json再打印 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 + 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 + :param larger_better: 是否是 monitor 的结果越大越好。 + :param format_json: 是否格式化 json 再打印 """ super().__init__(monitor=monitor, larger_better=larger_better, must_have_monitor=False) self.print_every = print_every @@ -136,8 +137,9 @@ class RawTextCallback(ProgressCallback): :param print_every: 多少个 batch 更新一次显示。 :param loss_round_ndigit: 显示的 loss 保留多少位有效数字 - :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。也可以传入一个函数,接受参数为 evaluation 的结果( - 字典类型),返回一个 float 值作为 monitor 的结果。 + :param monitor: 当检测到这个key的结果更好时,会打印出不同的颜色进行提示。监控的 metric 值。如果在 evaluation 结果中没有找到 + 完全一致的名称,将使用 最短公共字符串算法 找到最匹配的那个作为 monitor 。如果为 None,将尝试使用 Trainer 设置的 monitor + 。也可以传入一个函数,接受参数为 evaluation 的结果(字典类型),返回一个 float 值作为 monitor 的结果。 :param larger_better: 是否是monitor的结果越大越好。 :param format_json: 是否format json再打印 """ diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 3013c316..d447a0f2 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -36,7 +36,7 @@ class Evaluator: model, dataloaders, metrics: Optional[Union[Dict, Metric]] = None, - driver: Union[str, Driver] = 'single', + driver: Union[str, Driver] = 'torch', device: Optional[Union[int, List[int], str]] = None, batch_step_fn: Optional[callable] = None, evaluate_fn: Optional[str] = None, # 首先尝试找 evaluate_step, 找不到 forward, callable @@ -49,8 +49,8 @@ class Evaluator: ): """ - :param dataloaders: :param model: + :param dataloaders: :param metrics: 使用的 metric 。必须为 dict 类型,其中 key 为 metric 的名称,value 为一个 Metric 对象。支持 fastNLP 的 metric ,torchmetrics,allennlpmetrics等。 :param driver: 使用 driver 。 @@ -120,7 +120,8 @@ class Evaluator: if evaluate_fn is not None and not isinstance(evaluate_fn, str): raise TypeError("Parameter `train_fn` can only be `str` type when it is not None.") - self._evaluate_step, self._evaluate_step_signature_fn = self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) + self._evaluate_step, self._evaluate_step_signature_fn = \ + self.driver.get_model_call_fn("evaluate_step" if evaluate_fn is None else evaluate_fn) self.evaluate_fn = evaluate_fn self.dataloaders = {} @@ -134,8 +135,6 @@ class Evaluator: self.driver.barrier() - self.driver.check_dataloader_legality(self.dataloaders, "dataloaders", is_train=False) - def run(self, num_eval_batch_per_dl: int = -1, **kwargs) -> Dict: """ 返回一个字典类型的数据,其中key为metric的名字,value为对应metric的结果。 diff --git a/fastNLP/core/controllers/loops/train_batch_loop.py b/fastNLP/core/controllers/loops/train_batch_loop.py index a3219e6d..7dbe9775 100644 --- a/fastNLP/core/controllers/loops/train_batch_loop.py +++ b/fastNLP/core/controllers/loops/train_batch_loop.py @@ -20,7 +20,7 @@ class TrainBatchLoop(Loop): else lambda *args, **kwargs: None dataloader = iter(dataloader) indices = None - while True: + while trainer.batch_idx_in_epoch<=trainer.num_batches_per_epoch: try: trainer.on_fetch_data_begin() batch = next(dataloader) @@ -30,10 +30,8 @@ class TrainBatchLoop(Loop): batch = trainer.move_data_to_device(batch) except StopIteration: break - except EarlyStopException: # 在 Trainer 处理 earlystop 的 exception - break except BaseException as e: - if indices: + if indices and not isinstance(e, EarlyStopException): logger.debug(f"The following exception happens when running on samples: {indices}") raise e diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index 5154c580..b4695c00 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -174,9 +174,8 @@ class Trainer(TrainerEventTrigger): optimizers=optimizers, device=device, n_epochs=n_epochs, - validate_dataloaders=evaluate_dataloaders, batch_step_fn=batch_step_fn, - validate_batch_step_fn=evaluate_batch_step_fn, + z=evaluate_batch_step_fn, evaluate_fn=evaluate_fn, callbacks=callbacks, metrics=metrics, @@ -264,7 +263,6 @@ class Trainer(TrainerEventTrigger): self.on_after_trainer_initialized(self.driver) self.driver.barrier() - self.driver.check_dataloader_legality(self.train_dataloader, "train_dataloader", is_train=True) def run(self, num_train_batch_per_epoch: int = -1, num_eval_batch_per_dl: int = -1, num_eval_sanity_batch: int = 2, resume_from: str = None, resume_training: bool = True, @@ -310,7 +308,7 @@ class Trainer(TrainerEventTrigger): self.num_batches_per_epoch = len(self.dataloader) self.total_batches = self.num_batches_per_epoch * self.n_epochs - + self.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + self.batch_idx_in_epoch self.on_train_begin() self.driver.barrier() self.driver.zero_grad(self.set_grad_to_none) @@ -637,6 +635,8 @@ class Trainer(TrainerEventTrigger): :param folder: 保存断点重训 states 的文件地址; :param resume_training: 是否从上次的 batch 开始训练,或者只从最近的 epoch 开始训练;注意如果 resume_training=True,那么我们 只会加载 model 和 optimizers 的状态;而其余的对象的值则根据用户的 Trainer 的初始化直接重置; + :param only_state_dict: 保存的 model 是否只包含了权重。 + :param model_load_fn: 使用的模型加载函数,参数应为一个 文件夹,不返回任何内容。 """ self.driver.barrier() if isinstance(folder, str): @@ -675,8 +675,6 @@ class Trainer(TrainerEventTrigger): # 这里的原则就是应当使得 '还会产生的batch数量' + 'batch_idx_in_epoch' = '原来不断点训练的batch的总数'。其中由于 # '还会产生的batch数量' 是由还剩多少 sample 决定的,因此只能通过调整 'batch_idx_in_epoch' 使得等式成立 self.trainer_state.batch_idx_in_epoch = states.pop('batch_idx_in_epoch') - self.trainer_state.global_forward_batches = self.num_batches_per_epoch * self.cur_epoch_idx + \ - self.batch_idx_in_epoch # 这个是防止用户在 Trainer.load 之后还没结束当前 epoch 又继续 save self.start_batch_idx_in_epoch = self.trainer_state.batch_idx_in_epoch diff --git a/fastNLP/core/controllers/utils/state.py b/fastNLP/core/controllers/utils/state.py index 2327c1e5..496533d2 100644 --- a/fastNLP/core/controllers/utils/state.py +++ b/fastNLP/core/controllers/utils/state.py @@ -65,10 +65,10 @@ class TrainerState: """ n_epochs: Optional[int] = None # 无论如何重新算 - cur_epoch_idx: Optional[int] = None # 断点重训; 仅当 resume=False 时为0; - global_forward_batches: Optional[int] = None # 断点重训 + cur_epoch_idx: Optional[int] = 0 # 断点重训; 仅当 resume=False 时为0; + global_forward_batches: Optional[int] = 0 # 断点重训 - batch_idx_in_epoch: Optional[int] = None # 断点重训 + batch_idx_in_epoch: Optional[int] = 0 # 断点重训 num_batches_per_epoch: Optional[int] = None # 无论如何重新算 diff --git a/fastNLP/core/drivers/driver.py b/fastNLP/core/drivers/driver.py index b1015b47..0ef7f053 100644 --- a/fastNLP/core/drivers/driver.py +++ b/fastNLP/core/drivers/driver.py @@ -86,7 +86,7 @@ class Driver(ABC): 函数; :param batch: 当前的一个 batch 的数据;可以为字典或者其它类型; - :param fn: 由 Trainer 传入的用于网络前向传播一次的函数; + :param fn: 调用该函数进行一次计算。 :param signature_fn: 由 Trainer 传入的用于网络前向传播一次的签名函数,因为当 batch 是一个 Dict 的时候,我们会自动调用 auto_param_call 函数,而一些被包裹的模型需要暴露其真正的函数签名,例如 DistributedDataParallel 的调用函数是 forward,但是需要其函数签名为 model.module.forward; :return: 返回由 `fn` 返回的结果(应当为一个 dict 或者 dataclass,但是不需要我们去检查); @@ -126,17 +126,6 @@ class Driver(ABC): def model(self, model): self._model = model - @staticmethod - def check_dataloader_legality(dataloader, dataloader_name, is_train: bool = False): - r""" - 该函数会在 trainer 或者 evaluator 设置 dataloader 后检测 dataloader 的合法性,因为不同的深度学习的框架需要的 dataloader 的 - 行为是不相同的; - - :param dataloader: 需要检测的输入的 `dataloader`; - :param dataloader_name: - """ - raise NotImplementedError("Each specific driver should implemented its own `check_dataloader_legality` function.") - @property def optimizers(self) -> List: r""" diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 55af3367..a37525f4 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -406,7 +406,7 @@ class TorchDDPDriver(TorchDriver): if hasattr(model, fn): fn = getattr(model, fn) if not callable(fn): - raise RuntimeError(f"The `{fn}` attribute is not `Callable`.") + raise RuntimeError(f"The `{fn}` attribute of model is not `Callable`.") return fn, None elif fn in {"train_step", "evaluate_step"}: return model, model.forward diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index b7aebec8..233d7040 100644 --- a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -199,6 +199,7 @@ class TorchDriver(Driver): num_consumed_batches = sampler_states['num_consumed_samples'] sampler_states['num_consumed_samples'] = num_consumed_samples_array[num_consumed_batches] assert sampler_states['num_consumed_samples'] != -1, "This is a bug, please report." + states['sampler_states'] = sampler_states else: raise RuntimeError( 'The sampler has no `state_dict()` method, it will fail to recover to the specific batch.') diff --git a/tests/core/utils/test_utils.py b/tests/core/utils/test_utils.py index a7aeffb1..556f85ff 100644 --- a/tests/core/utils/test_utils.py +++ b/tests/core/utils/test_utils.py @@ -181,6 +181,7 @@ class TestCheckNumberOfParameters: def test_get_fun_msg(): + # 测试运行 def demo(x): pass diff --git a/tests/helpers/datasets/normal_data.py b/tests/helpers/datasets/normal_data.py index ba1af370..714ec676 100644 --- a/tests/helpers/datasets/normal_data.py +++ b/tests/helpers/datasets/normal_data.py @@ -1,3 +1,6 @@ +import numpy as np + + class NormalIterator: def __init__(self, num_of_data=1000): self._num_of_data = num_of_data @@ -15,4 +18,15 @@ class NormalIterator: return self._data def __len__(self): - return self._num_of_data \ No newline at end of file + return self._num_of_data + + +class RandomDataset: + def __init__(self, num_data=10): + self.data = np.random.rand(num_data) + + def __len__(self): + return len(self.data) + + def __getitem__(self, item): + return self.data[item] \ No newline at end of file