From a5b2ccf7590dc9fb5149a7697ceeae61a4c839b8 Mon Sep 17 00:00:00 2001 From: x54-729 <17307130121@fudan.edu.cn> Date: Tue, 12 Apr 2022 09:36:51 +0000 Subject: [PATCH] update --- fastNLP/core/drivers/paddle_driver/fleet.py | 18 ++++++++---------- .../drivers/paddle_driver/paddle_driver.py | 7 +++++-- .../drivers/paddle_driver/single_device.py | 15 ++++++++++----- fastNLP/core/drivers/paddle_driver/utils.py | 19 ++++++++++--------- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/fastNLP/core/drivers/paddle_driver/fleet.py b/fastNLP/core/drivers/paddle_driver/fleet.py index 86198959..582ce542 100644 --- a/fastNLP/core/drivers/paddle_driver/fleet.py +++ b/fastNLP/core/drivers/paddle_driver/fleet.py @@ -104,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver): # 我们就直接将 model_device 置为 None; self._model_device = None - def _running_fn_(batch, step_fn, signature_fn): - if isinstance(batch, Dict): + def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(step_fn, batch, signature_fn=signature_fn) else: return self._validate_step(batch) @@ -116,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver): "Notice your model is a `paddle.DataParallel` model. And your " "model also implements the `train_step` method, which we can not call actually, we will" " call `forward` function instead of `train_step` and you should note that.") - self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) - # self._train_signature_fn = model.forward + self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) if hasattr(model, "validate_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " "model also implements the `validate_step` method, which we can not call actually, " "we will call `forward` function instead of `validate_step` and you should note that.") - self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) - # self._validate_signature_fn = model.forward + self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) if hasattr(model, "test_step"): logger.warning( "Notice your model is a `paddle.DataParallel` model. And your " "model also implements the `test_step` method, which we can not call actually, we will" " call `forward` function instead of `test_step` and you should note that.") - self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) + self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; self._data_device = kwargs.get("data_device", None) @@ -277,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver): **self._fleet_kwargs ) - self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) - self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) - self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) + self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) + self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) + self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) self._configured = True diff --git a/fastNLP/core/drivers/paddle_driver/paddle_driver.py b/fastNLP/core/drivers/paddle_driver/paddle_driver.py index 89e88aef..a407a7b7 100644 --- a/fastNLP/core/drivers/paddle_driver/paddle_driver.py +++ b/fastNLP/core/drivers/paddle_driver/paddle_driver.py @@ -19,7 +19,7 @@ from fastNLP.envs import ( rank_zero_call, ) from fastNLP.core.log import logger -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler +from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler if _NEED_IMPORT_PADDLE: import paddle @@ -56,6 +56,9 @@ class PaddleDriver(Driver): self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) self.grad_scaler = _grad_scaler() + # 用来设置是否关闭 auto_param_call 中的参数匹配问题; + self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) + def zero_grad(self, set_to_none: bool = False): r""" 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; @@ -301,7 +304,7 @@ class PaddleDriver(Driver): elif self.is_distributed(): raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") else: - sampler = ReproducibleBatchSampler( + sampler = RandomBatchSampler( batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, batch_size=dataloader_args.batch_size, drop_last=dataloader_args.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/single_device.py b/fastNLP/core/drivers/paddle_driver/single_device.py index 64656124..796f4809 100644 --- a/fastNLP/core/drivers/paddle_driver/single_device.py +++ b/fastNLP/core/drivers/paddle_driver/single_device.py @@ -11,7 +11,12 @@ from fastNLP.core.utils import ( get_paddle_device_id, paddle_move_data_to_device, ) -from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler +from fastNLP.core.samplers import ( + ReproducibleBatchSampler, + RandomBatchSampler, + ReproducibleSampler, + re_instantiate_sampler, +) from fastNLP.core.log import logger if _NEED_IMPORT_PADDLE: @@ -102,7 +107,7 @@ class PaddleSingleDriver(PaddleDriver): def train_step(self, batch) -> Dict: # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) @@ -116,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): self.grad_scaler.update() def validate_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) def test_step(self, batch) -> Dict: - if isinstance(batch, Dict): + if isinstance(batch, Dict) and not self.wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) @@ -159,7 +164,7 @@ class PaddleSingleDriver(PaddleDriver): return replace_sampler(dataloader, sampler) if reproducible: - batch_sampler = ReproducibleBatchSampler( + batch_sampler = RandomBatchSampler( batch_sampler=args.batch_sampler, batch_size=args.batch_size, drop_last=args.drop_last diff --git a/fastNLP/core/drivers/paddle_driver/utils.py b/fastNLP/core/drivers/paddle_driver/utils.py index 47c0f1b9..36982b4c 100644 --- a/fastNLP/core/drivers/paddle_driver/utils.py +++ b/fastNLP/core/drivers/paddle_driver/utils.py @@ -85,7 +85,7 @@ class ForwardState(IntEnum): TEST = 2 PREDICT = 3 -_MODE_PARAMETER = "_forward_state" +_MODE_PARAMETER = "forward_state" class _FleetWrappingModel(Layer): """ @@ -151,24 +151,25 @@ class _FleetWrappingModel(Layer): def forward(self, batch, **kwargs) -> Dict: - _forward_state = kwargs.pop(_MODE_PARAMETER) + forward_state = kwargs.pop(_MODE_PARAMETER) + wo_auto_param_call = kwargs.pop("wo_auto_param_call") - if _forward_state == ForwardState.TRAIN: - if isinstance(batch, Dict): + if forward_state == ForwardState.TRAIN: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) else: return self._train_step(batch) - elif _forward_state == ForwardState.VALIDATE: - if isinstance(batch, Dict): + elif forward_state == ForwardState.VALIDATE: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) else: return self._validate_step(batch) - elif _forward_state == ForwardState.TEST: - if isinstance(batch, Dict): + elif forward_state == ForwardState.TEST: + if isinstance(batch, Dict) and not wo_auto_param_call: return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) else: return self._test_step(batch) - elif _forward_state == ForwardState.PREDICT: + elif forward_state == ForwardState.PREDICT: raise NotImplementedError("'PREDICT' mode has not been implemented.") else: raise NotImplementedError("You should direct a concrete mode.")