@@ -104,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self._model_device = None | self._model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
    """Run ``step_fn`` on ``batch``, optionally matching parameters by name.

    :param batch: the data handed to the step; only ``Dict`` batches are
        eligible for automatic parameter matching.
    :param step_fn: the callable that actually consumes the batch.
    :param signature_fn: the callable whose signature ``auto_param_call``
        matches the batch keys against.
    :param wo_auto_param_call: when truthy, skip parameter matching and pass
        the batch to ``step_fn`` untouched, even if it is a ``Dict``.
    :return: whatever ``step_fn`` (possibly via ``auto_param_call``) returns.
    """
    if isinstance(batch, Dict) and not wo_auto_param_call:
        return auto_param_call(step_fn, batch, signature_fn=signature_fn)
    # Use the callable we were given. Falling back to `self._validate_step`
    # would ignore `step_fn` for the train/test paths and, because that
    # attribute appears to be a partial of this very function, would re-enter
    # it with the same arguments.
    return step_fn(batch)
@@ -116,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver): | |||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `train_step` method, which we can not call actually, we will" | "model also implements the `train_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._train_signature_fn = model.forward | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | "model also implements the `validate_step` method, which we can not call actually, " | ||||
"we will call `forward` function instead of `validate_step` and you should note that.") | "we will call `forward` function instead of `validate_step` and you should note that.") | ||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
# self._validate_signature_fn = model.forward | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
logger.warning( | logger.warning( | ||||
"Notice your model is a `paddle.DataParallel` model. And your " | "Notice your model is a `paddle.DataParallel` model. And your " | ||||
"model also implements the `test_step` method, which we can not call actually, we will" | "model also implements the `test_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `test_step` and you should note that.") | " call `forward` function instead of `test_step` and you should note that.") | ||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | # 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
@@ -277,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver): | |||||
**self._fleet_kwargs | **self._fleet_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | self._configured = True | ||||
@@ -19,7 +19,7 @@ from fastNLP.envs import ( | |||||
rank_zero_call, | rank_zero_call, | ||||
) | ) | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
import paddle | import paddle | ||||
@@ -56,6 +56,9 @@ class PaddleDriver(Driver): | |||||
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16) | ||||
self.grad_scaler = _grad_scaler() | self.grad_scaler = _grad_scaler() | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
r""" | r""" | ||||
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | 实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零; | ||||
@@ -301,7 +304,7 @@ class PaddleDriver(Driver): | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = ReproducibleBatchSampler( | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -11,7 +11,12 @@ from fastNLP.core.utils import ( | |||||
get_paddle_device_id, | get_paddle_device_id, | ||||
paddle_move_data_to_device, | paddle_move_data_to_device, | ||||
) | ) | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.samplers import ( | |||||
ReproducibleBatchSampler, | |||||
RandomBatchSampler, | |||||
ReproducibleSampler, | |||||
re_instantiate_sampler, | |||||
) | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
if _NEED_IMPORT_PADDLE: | if _NEED_IMPORT_PADDLE: | ||||
@@ -102,7 +107,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
def train_step(self, batch) -> Dict:
    """Run one training step on ``batch``.

    A ``Dict`` batch is matched against ``self._train_signature_fn`` through
    ``auto_param_call`` unless ``self.wo_auto_param_call`` is set; any other
    batch is passed straight to ``self._train_step`` so the user handles it.
    """
    match_params = isinstance(batch, Dict) and not self.wo_auto_param_call
    if not match_params:
        # Non-dict batch, or matching disabled: hand the batch over as-is.
        return self._train_step(batch)
    return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
@@ -116,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver): | |||||
self.grad_scaler.update() | self.grad_scaler.update() | ||||
def validate_step(self, batch) -> Dict:
    """Run one validation step on ``batch``.

    A ``Dict`` batch is matched against ``self._validate_signature_fn``
    through ``auto_param_call`` unless ``self.wo_auto_param_call`` is set;
    any other batch is passed straight to ``self._validate_step``.
    """
    match_params = isinstance(batch, Dict) and not self.wo_auto_param_call
    if not match_params:
        # Non-dict batch, or matching disabled: hand the batch over as-is.
        return self._validate_step(batch)
    return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
def test_step(self, batch) -> Dict: | def test_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
@@ -159,7 +164,7 @@ class PaddleSingleDriver(PaddleDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -85,7 +85,7 @@ class ForwardState(IntEnum): | |||||
TEST = 2 | TEST = 2 | ||||
PREDICT = 3 | PREDICT = 3 | ||||
_MODE_PARAMETER = "_forward_state" | |||||
_MODE_PARAMETER = "forward_state" | |||||
class _FleetWrappingModel(Layer): | class _FleetWrappingModel(Layer): | ||||
""" | """ | ||||
@@ -151,24 +151,25 @@ class _FleetWrappingModel(Layer): | |||||
def forward(self, batch, **kwargs) -> Dict:
    """Route ``batch`` to the train, validate or test step chosen by the caller.

    Two control keywords are consumed from ``kwargs``:

    * the keyword named by ``_MODE_PARAMETER`` (required): a ``ForwardState``
      value selecting which step to run;
    * ``wo_auto_param_call`` (optional, defaults to ``False``): when truthy,
      the batch is passed to the step as-is even if it is a ``Dict``.

    For a ``Dict`` batch with matching enabled, the step is invoked through
    ``auto_param_call`` using the corresponding ``*_signature_fn``.

    :raises KeyError: if the mode keyword is missing.
    :raises NotImplementedError: for ``ForwardState.PREDICT`` or any value
        that is not a recognised mode.
    """
    forward_state = kwargs.pop(_MODE_PARAMETER)
    # Optional so that callers which only supply the mode keep working.
    wo_auto_param_call = kwargs.pop("wo_auto_param_call", False)
    # The decision does not depend on the mode, so compute it once.
    match_params = isinstance(batch, Dict) and not wo_auto_param_call

    if forward_state == ForwardState.TRAIN:
        if match_params:
            return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
        return self._train_step(batch)
    elif forward_state == ForwardState.VALIDATE:
        if match_params:
            return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
        return self._validate_step(batch)
    elif forward_state == ForwardState.TEST:
        if match_params:
            return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
        return self._test_step(batch)
    elif forward_state == ForwardState.PREDICT:
        raise NotImplementedError("'PREDICT' mode has not been implemented.")
    else:
        raise NotImplementedError("You should direct a concrete mode.")