@@ -104,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver):
             # in that case we simply set model_device to None;
             self._model_device = None
 
-            def _running_fn_(batch, step_fn, signature_fn):
-                if isinstance(batch, Dict):
+            def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
+                if isinstance(batch, Dict) and not wo_auto_param_call:
                     return auto_param_call(step_fn, batch, signature_fn=signature_fn)
                 else:
                     return step_fn(batch)
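
For reference, `auto_param_call` inspects the target's signature and feeds it only the matching entries of the dict batch. The sketch below is a simplified illustration of that idea, not fastNLP's actual implementation (which also handles defaults and error reporting); `signature_fn` matters because `step_fn` may be the `paddle.DataParallel` wrapper, whose own signature is just `(*inputs, **kwargs)`:

```python
import inspect
from typing import Callable, Dict, Optional

def naive_auto_param_call(fn: Callable, batch: Dict, signature_fn: Optional[Callable] = None):
    # Match against the signature of `signature_fn` when given (e.g. the
    # user's `forward`), because `fn` itself may be an opaque wrapper.
    params = inspect.signature(signature_fn or fn).parameters
    # Keep only the batch entries whose keys name a parameter of the target.
    return fn(**{k: v for k, v in batch.items() if k in params})
```

The new `wo_auto_param_call` flag short-circuits exactly this step, so a dict batch reaches the user's function untouched.
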
@@ -116,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver):
                     "Notice your model is a `paddle.DataParallel` model. And your "
                     "model also implements the `train_step` method, which we cannot actually call; we will "
                     "call the `forward` function instead of `train_step`, and you should note that.")
-                self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
-                # self._train_signature_fn = model.forward
+                self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
 
             if hasattr(model, "validate_step"):
                 logger.warning(
                     "Notice your model is a `paddle.DataParallel` model. And your "
                     "model also implements the `validate_step` method, which we cannot actually call; "
                     "we will call the `forward` function instead of `validate_step`, and you should note that.")
-                self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
-                # self._validate_signature_fn = model.forward
+                self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
 
             if hasattr(model, "test_step"):
                 logger.warning(
                     "Notice your model is a `paddle.DataParallel` model. And your "
                     "model also implements the `test_step` method, which we cannot actually call; we will "
                     "call the `forward` function instead of `test_step`, and you should note that.")
-                self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
+                self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
 
             # when the `device` parameter is None and this argument is not None, the corresponding data is moved to the specified device;
             self._data_device = kwargs.get("data_device", None)
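
These `partial` calls freeze everything except `batch`, so call sites keep the one-argument shape `self._train_step(batch)` whether or not the flag is set. A minimal, self-contained illustration of the binding (the callables here are placeholders):

```python
from functools import partial

def running_fn(batch, step_fn, signature_fn, wo_auto_param_call):
    # branching elided; see `_running_fn_` above for the real dispatch
    return step_fn(batch)

# bind everything but `batch`, mirroring the assignments above
train_step = partial(running_fn, step_fn=len, signature_fn=len,
                     wo_auto_param_call=False)
print(train_step({"input_ids": [1, 2, 3]}))  # prints 1; only `batch` varies per call
```

Note also the asymmetry the diff preserves: `step_fn` is the wrapped `self.model`, while `signature_fn` is the user's `model.forward`, so parameter matching sees the real signature rather than the wrapper's.
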
@@ -277,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver):
                 **self._fleet_kwargs
             )
 
-            self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN})
-            self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE})
-            self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST})
+            self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
+            self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
+            self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)
 
             self._configured = True
@@ -19,7 +19,7 @@ from fastNLP.envs import (
     rank_zero_call,
 )
 from fastNLP.core.log import logger
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
+from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler
 
 if _NEED_IMPORT_PADDLE:
     import paddle
@@ -56,6 +56,9 @@ class PaddleDriver(Driver):
         self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
         self.grad_scaler = _grad_scaler()
 
+        # whether to disable the automatic parameter matching done by auto_param_call;
+        self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)
+
     def zero_grad(self, set_to_none: bool = False):
         r"""
         Zero out the gradients; this should be done directly through the optimizers;
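
Since the flag is read from the driver's `kwargs`, it would be switched on by forwarding `model_wo_auto_param_call=True` from the front end. A hedged usage sketch follows; the kwarg name comes from the line above, but how kwargs reach the driver, and the driver's other arguments, are assumptions:

```python
# Hypothetical sketch: construct a single-device driver with parameter
# matching disabled; `model` is assumed to be a paddle.nn.Layer.
driver = PaddleSingleDriver(
    model=model,
    device="gpu:0",
    model_wo_auto_param_call=True,  # dict batches now bypass auto_param_call
)
```
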
@@ -301,7 +304,7 @@ class PaddleDriver(Driver):
             elif self.is_distributed():
                 raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our `ReproducibleBatchSampler` or `ReproducibleSampler`.")
             else:
-                sampler = ReproducibleBatchSampler(
+                sampler = RandomBatchSampler(
                     batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
                     batch_size=dataloader_args.batch_size,
                     drop_last=dataloader_args.drop_last
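
`RandomBatchSampler` keeps the constructor contract shown in this call: it wraps an existing batch sampler (or plain sampler) so the batch sequence can be replayed when retraining from a checkpoint. A hedged usage sketch; only the three kwargs are taken from the diff, the rest is illustrative:

```python
from paddle.io import BatchSampler
from fastNLP.core.samplers import RandomBatchSampler

# `train_dataset` is a placeholder paddle.io.Dataset
inner = BatchSampler(dataset=train_dataset, batch_size=32, drop_last=False)
resumable = RandomBatchSampler(batch_sampler=inner, batch_size=32, drop_last=False)
```
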
@@ -11,7 +11,12 @@ from fastNLP.core.utils import (
     get_paddle_device_id,
     paddle_move_data_to_device,
 )
-from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
+from fastNLP.core.samplers import (
+    ReproducibleBatchSampler,
+    RandomBatchSampler,
+    ReproducibleSampler,
+    re_instantiate_sampler,
+)
 from fastNLP.core.log import logger
 
 if _NEED_IMPORT_PADDLE:
@@ -102,7 +107,7 @@ class PaddleSingleDriver(PaddleDriver):
     def train_step(self, batch) -> Dict:
         # if batch is a Dict, we do parameter matching for it by default; otherwise it is passed to `train_step` as-is for the user to handle;
-        if isinstance(batch, Dict):
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
             return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
         else:
             return self._train_step(batch)
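
Concretely, the guard separates two model styles. With matching on, a dict batch is unpacked into named parameters; with `model_wo_auto_param_call=True`, the same dict arrives as one positional argument. Both classes below are illustrative:

```python
import paddle

class ModelA(paddle.nn.Layer):
    # matched path: auto_param_call supplies batch["input_ids"], batch["labels"]
    def train_step(self, input_ids, labels):
        ...

class ModelB(paddle.nn.Layer):
    # bypass path: the user unpacks the dict themselves. Without the new flag,
    # auto_param_call would look for a batch key literally named "batch" and fail.
    def train_step(self, batch):
        input_ids, labels = batch["input_ids"], batch["labels"]
        ...
```
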
@@ -116,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver):
         self.grad_scaler.update()
 
     def validate_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
             return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
         else:
             return self._validate_step(batch)
 
     def test_step(self, batch) -> Dict:
-        if isinstance(batch, Dict):
+        if isinstance(batch, Dict) and not self.wo_auto_param_call:
             return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
         else:
             return self._test_step(batch)
@@ -159,7 +164,7 @@ class PaddleSingleDriver(PaddleDriver):
             return replace_sampler(dataloader, sampler)
 
         if reproducible:
-            batch_sampler = ReproducibleBatchSampler(
+            batch_sampler = RandomBatchSampler(
                 batch_sampler=args.batch_sampler,
                 batch_size=args.batch_size,
                 drop_last=args.drop_last
@@ -85,7 +85,7 @@ class ForwardState(IntEnum):
     TEST = 2
     PREDICT = 3
 
-_MODE_PARAMETER = "_forward_state"
+_MODE_PARAMETER = "forward_state"
 
 class _FleetWrappingModel(Layer):
     """
@@ -151,24 +151,25 @@ class _FleetWrappingModel(Layer):
     def forward(self, batch, **kwargs) -> Dict:
-        _forward_state = kwargs.pop(_MODE_PARAMETER)
+        forward_state = kwargs.pop(_MODE_PARAMETER)
+        wo_auto_param_call = kwargs.pop("wo_auto_param_call")
 
-        if _forward_state == ForwardState.TRAIN:
-            if isinstance(batch, Dict):
+        if forward_state == ForwardState.TRAIN:
+            if isinstance(batch, Dict) and not wo_auto_param_call:
                 return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
             else:
                 return self._train_step(batch)
-        elif _forward_state == ForwardState.VALIDATE:
-            if isinstance(batch, Dict):
+        elif forward_state == ForwardState.VALIDATE:
+            if isinstance(batch, Dict) and not wo_auto_param_call:
                 return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
             else:
                 return self._validate_step(batch)
-        elif _forward_state == ForwardState.TEST:
-            if isinstance(batch, Dict):
+        elif forward_state == ForwardState.TEST:
+            if isinstance(batch, Dict) and not wo_auto_param_call:
                 return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
             else:
                 return self._test_step(batch)
-        elif _forward_state == ForwardState.PREDICT:
+        elif forward_state == ForwardState.PREDICT:
             raise NotImplementedError("'PREDICT' mode has not been implemented.")
         else:
             raise NotImplementedError("You should specify a concrete mode.")
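
Putting the fleet pieces together: the keywords bound by `partial` in the driver ride through the `DataParallel` wrapper into this `forward`, which pops them before dispatching on the mode. A self-contained mock of that flow (the wrapper class is a stand-in, not paddle's):

```python
from functools import partial

_MODE_PARAMETER = "forward_state"

class MockWrappedModel:
    # stand-in for paddle.DataParallel(_FleetWrappingModel(...))
    def __call__(self, batch, **kwargs):
        forward_state = kwargs.pop(_MODE_PARAMETER)
        wo_auto_param_call = kwargs.pop("wo_auto_param_call")
        return forward_state, wo_auto_param_call, batch

train_step = partial(MockWrappedModel(), **{_MODE_PARAMETER: "TRAIN"},
                     wo_auto_param_call=False)
print(train_step({"x": 1}))  # ('TRAIN', False, {'x': 1}) -> mode arrives intact
```
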