From 5419b6a04295ebd35ab221701757d7b1afeadb9c Mon Sep 17 00:00:00 2001 From: YWMditto Date: Tue, 12 Apr 2022 17:00:07 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A1=AB=E4=BA=86=E4=BA=86=E5=85=B3=E9=97=AD?= =?UTF-8?q?=E5=8F=82=E6=95=B0=E5=8C=B9=E9=85=8D=E7=9A=84=E9=80=BB=E8=BE=91?= =?UTF-8?q?=EF=BC=9B=E6=B7=BB=E5=8A=A0=E4=BA=86=20trainer=20=E4=B8=AD?= =?UTF-8?q?=E8=8E=B7=E5=8F=96=20driver=20=E5=8F=82=E6=95=B0=E7=9A=84?= =?UTF-8?q?=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- fastNLP/core/controllers/evaluator.py | 6 ++- fastNLP/core/controllers/trainer.py | 28 ++++++++++--- fastNLP/core/drivers/torch_driver/ddp.py | 20 ++++----- .../drivers/torch_driver/single_device.py | 10 ++--- .../core/drivers/torch_driver/torch_driver.py | 7 +++- fastNLP/core/drivers/torch_driver/utils.py | 17 ++++---- fastNLP/core/utils/utils.py | 2 + .../test_checkpoint_callback_torch.py | 2 +- .../test_trainer_wo_evaluator_torch.py | 41 ++++++++++++++++++- tests/helpers/models/torch_model.py | 27 ++++++++++++ 10 files changed, 125 insertions(+), 35 deletions(-) diff --git a/fastNLP/core/controllers/evaluator.py b/fastNLP/core/controllers/evaluator.py index 865acc89..b193f877 100644 --- a/fastNLP/core/controllers/evaluator.py +++ b/fastNLP/core/controllers/evaluator.py @@ -41,6 +41,7 @@ class Evaluator: mode: str = "validate", input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, + model_wo_auto_param_call: bool = False, fp16: Optional[bool] = False, verbose: int = 1, **kwargs @@ -61,6 +62,9 @@ class Evaluator: 没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 + :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; + 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 
中提取对应的对象,传入到 forward 函数中;如果该值 + 为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; :param fp16: 是否使用 fp16 。 :param verbose: 是否打印 evaluate 的结果。 :param kwargs: @@ -83,7 +87,7 @@ class Evaluator: self.model = model self.metrics = metrics - self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs) + self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) self.device = device self.verbose = verbose diff --git a/fastNLP/core/controllers/trainer.py b/fastNLP/core/controllers/trainer.py index d710f967..a7c38b27 100644 --- a/fastNLP/core/controllers/trainer.py +++ b/fastNLP/core/controllers/trainer.py @@ -47,6 +47,7 @@ class Trainer(TrainerEventTrigger): validate_every: Optional[Union[int, callable]] = -1, input_mapping: Optional[Union[Callable, Dict]] = None, output_mapping: Optional[Union[Callable, Dict]] = None, + model_wo_auto_param_call: bool = False, accumulation_steps: int = 1, fp16: bool = False, marker: Optional[str] = None, @@ -99,7 +100,10 @@ class Trainer(TrainerEventTrigger): :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; - 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 + 如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; + :param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; + 如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 + 为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; :param fp16: 是否开启混合精度训练;默认为 False; :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 
None; @@ -120,9 +124,7 @@ class Trainer(TrainerEventTrigger): """ - # TODO 是不是可以加一个参数让用户现在关掉参数匹配。 self.marker = marker - self.model = model self.driver_name = driver self.device = device self.fp16 = fp16 @@ -164,6 +166,7 @@ class Trainer(TrainerEventTrigger): validate_every=validate_every, input_mapping=input_mapping, output_mapping=output_mapping, + model_wo_auto_param_call=model_wo_auto_param_call, accumulation_steps=accumulation_steps, fp16=fp16, marker=marker, @@ -484,8 +487,6 @@ class Trainer(TrainerEventTrigger): @driver.setter def driver(self, driver: Driver): - driver.trainer = self - driver.model = self.model self._driver = driver @property @@ -782,4 +783,21 @@ class Trainer(TrainerEventTrigger): def total_batches(self, total_batches: int): self.trainer_state.total_batches = total_batches + """ driver property """ + + @property + def model_device(self): + return self.driver.model_device + + @property + def data_device(self): + return self.driver.data_device + + @property + def model(self): + # 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; + return self.driver.model + + + diff --git a/fastNLP/core/drivers/torch_driver/ddp.py b/fastNLP/core/drivers/torch_driver/ddp.py index 44cabcf4..4cf207cd 100644 --- a/fastNLP/core/drivers/torch_driver/ddp.py +++ b/fastNLP/core/drivers/torch_driver/ddp.py @@ -167,6 +167,7 @@ class TorchDDPDriver(TorchDriver): 不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; """ + # 在加入很多东西后,需要注意这里调用 super 函数的位置; super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) if isinstance(model, torch.nn.DataParallel): @@ -202,8 +203,8 @@ class TorchDDPDriver(TorchDriver): # 我们就直接将 model_device 置为 None; self.model_device = None - def _running_fn_(batch, step_fn, signature_fn): - if isinstance(batch, Dict): + def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): + if isinstance(batch, Dict) and not 
wo_auto_param_call: return auto_param_call(step_fn, batch, signature_fn=signature_fn) else: return step_fn(batch) @@ -214,7 +215,7 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. And your " "model also implements the `train_step` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._train_signature_fn = model.forward if hasattr(model, "validate_step"): @@ -222,7 +223,7 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. And your " "model also implements the `validate_step` method, which we can not call actually, " "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._validate_signature_fn = model.forward if hasattr(model, "test_step"): @@ -230,14 +231,11 @@ class TorchDDPDriver(TorchDriver): "Notice your model is a `DistributedDataParallel` model. 
And your " "model also implements the `test_step` method, which we can not call actually, we will" " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # self._test_signature_fn = model.forward # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) - # if self.outside_ddp and self._data_device is None: - # raise RuntimeError("When you initialize your ddp out of our control, the parameter " - # "`data_device` can not be None.") if isinstance(self._data_device, int): if self._data_device < 0: raise ValueError("Parameter `data_device` can not be smaller than 0.") @@ -349,9 +347,9 @@ class TorchDDPDriver(TorchDriver): **self._ddp_kwargs ) - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) + self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) + self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) + self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) self._configured = True diff --git a/fastNLP/core/drivers/torch_driver/single_device.py b/fastNLP/core/drivers/torch_driver/single_device.py index 19e687b8..8cbb7acd 100644 --- a/fastNLP/core/drivers/torch_driver/single_device.py +++ b/fastNLP/core/drivers/torch_driver/single_device.py @@ -13,7 +13,7 @@ __all__ = [ from .torch_driver import TorchDriver from 
fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler from fastNLP.core.utils import auto_param_call -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler from fastNLP.core.log import logger @@ -102,7 +102,7 @@ class TorchSingleDriver(TorchDriver): def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) @@ -118,13 +118,13 @@ class TorchSingleDriver(TorchDriver): def validate_step(self, batch) -> Dict: # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) def test_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) @@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/torch_driver.py b/fastNLP/core/drivers/torch_driver/torch_driver.py index b200f1fd..d2ffbac1 100644 --- 
a/fastNLP/core/drivers/torch_driver/torch_driver.py +++ b/fastNLP/core/drivers/torch_driver/torch_driver.py @@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device from fastNLP.envs import rank_zero_call from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler class TorchDriver(Driver): @@ -51,6 +51,9 @@ class TorchDriver(Driver): # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; self.non_blocking = kwargs.get("torch_non_blocking", True) + # 用来设置是否关闭 auto_param_call 中的参数匹配问题; + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + def zero_grad(self, set_to_none: bool = False): for optimizer in self.optimizers: self._clear_grad(optimizer, set_to_none) @@ -252,7 +255,7 @@ class TorchDriver(Driver): elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") else: - sampler = ReproducibleBatchSampler( + sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last diff --git a/fastNLP/core/drivers/torch_driver/utils.py b/fastNLP/core/drivers/torch_driver/utils.py index 406e030b..4210dac5 100644 --- a/fastNLP/core/drivers/torch_driver/utils.py +++ b/fastNLP/core/drivers/torch_driver/utils.py @@ -140,24 +140,25 @@ class _DDPWrappingModel(Module): pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; """ - _forward_state = kwargs.pop(_MODE_PARAMETER) + forward_state = kwargs.pop(_MODE_PARAMETER) + wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if 
_forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict): + if forward_state == ForwardState.TRAIN: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) - elif _forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict): + elif forward_state == ForwardState.VALIDATE: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) - elif _forward_state == ForwardState.TEST: - if isinstance(batch, Dict): + elif forward_state == ForwardState.TEST: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) - elif _forward_state == ForwardState.PREDICT: + elif forward_state == ForwardState.PREDICT: raise NotImplementedError("'PREDICT' mode has not been implemented.") else: raise NotImplementedError("You should direct a concrete mode.") diff --git a/fastNLP/core/utils/utils.py b/fastNLP/core/utils/utils.py index 0d497bc2..5c497606 100644 --- a/fastNLP/core/utils/utils.py +++ b/fastNLP/core/utils/utils.py @@ -96,6 +96,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 参数值后,再传给 `fn` 进行实际的运算; :param mapping: 一个字典,用来更改其前面的字典的键值; + :param wo_auto_param_call: 是否关闭默认的参数匹配行为; :return: 返回 `fn` 运行的结果; @@ -113,6 +114,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 """ + if signature_fn is not None: if not callable(signature_fn): raise ValueError(f"Parameter 
`signature_fn` should be `Callable`.") diff --git a/tests/core/callbacks/test_checkpoint_callback_torch.py b/tests/core/callbacks/test_checkpoint_callback_torch.py index 1f404bb8..557c31b2 100644 --- a/tests/core/callbacks/test_checkpoint_callback_torch.py +++ b/tests/core/callbacks/test_checkpoint_callback_torch.py @@ -10,7 +10,7 @@ import re from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback from fastNLP.core.controllers.trainer import Trainer -from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK +from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK from tests.helpers.utils import magic_argv_env_context from fastNLP.core import synchronize_safe_rm diff --git a/tests/core/controllers/test_trainer_wo_evaluator_torch.py b/tests/core/controllers/test_trainer_wo_evaluator_torch.py index 0a280a0c..0da8c976 100644 --- a/tests/core/controllers/test_trainer_wo_evaluator_torch.py +++ b/tests/core/controllers/test_trainer_wo_evaluator_torch.py @@ -10,7 +10,7 @@ from typing import Any from pathlib import Path from fastNLP.core.controllers.trainer import Trainer -from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 +from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification from tests.helpers.callbacks.helper_callbacks import RecordLossCallback from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch @@ -70,7 +70,7 @@ def model_and_optimizers(request): trainer_params.output_mapping = None # elif request.param == 1: - # model = + return trainer_params @@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): subprocess.check_call(command) +@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) 
+@magic_argv_env_context +def test_torch_wo_auto_param_call( + driver, + device, + n_epochs=10, +): + + model = TorchNormalModel_Classification_3( + num_labels=NormalClassificationTrainTorchConfig.num_labels, + feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension + ) + optimizers = SGD(model.parameters(), lr=0.001) + dataset = TorchNormalDataset_Classification( + num_labels=NormalClassificationTrainTorchConfig.num_labels, + feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, + each_label_data=NormalClassificationTrainTorchConfig.each_label_data, + seed=NormalClassificationTrainTorchConfig.seed + ) + train_dataloader = DataLoader( + dataset=dataset, + batch_size=NormalClassificationTrainTorchConfig.batch_size, + shuffle=True + ) + trainer = Trainer( + model=model, + driver=driver, + device=device, + optimizers=optimizers, + train_dataloader=train_dataloader, + n_epochs=n_epochs, + model_wo_auto_param_call=True, + output_from_new_proc="all" + ) + trainer.run() + if dist.is_initialized(): + dist.destroy_process_group() diff --git a/tests/helpers/models/torch_model.py b/tests/helpers/models/torch_model.py index 2912224f..b949a26f 100644 --- a/tests/helpers/models/torch_model.py +++ b/tests/helpers/models/torch_model.py @@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module): x = torch.max(x, dim=-1)[1] return {"preds": x, "target": y} + class TorchNormalModel_Classification_2(nn.Module): """ 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; @@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module): return {"loss": loss, "preds": x, "target": y} +class TorchNormalModel_Classification_3(nn.Module): + """ + 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; + 关闭 auto_param_call,forward 只有一个 batch 参数; + """ + def __init__(self, num_labels, feature_dimension): + super(TorchNormalModel_Classification_3, self).__init__() + self.num_labels = num_labels + + self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) + 
self.ac1 = nn.ReLU() + self.linear2 = nn.Linear(in_features=10, out_features=10) + self.ac2 = nn.ReLU() + self.output = nn.Linear(in_features=10, out_features=num_labels) + self.loss_fn = nn.CrossEntropyLoss() + + def forward(self, batch): + x = batch["x"] + y = batch["y"] + x = self.ac1(self.linear1(x)) + x = self.ac2(self.linear2(x)) + x = self.output(x) + loss = self.loss_fn(x, y) + x = torch.max(x, dim=-1)[1] + return {"loss": loss, "preds": x, "target": y} +