Browse Source

update

tags/v1.0.0alpha
x54-729 3 years ago
parent
commit
a5b2ccf759
4 changed files with 33 additions and 26 deletions
  1. +8
    -10
      fastNLP/core/drivers/paddle_driver/fleet.py
  2. +5
    -2
      fastNLP/core/drivers/paddle_driver/paddle_driver.py
  3. +10
    -5
      fastNLP/core/drivers/paddle_driver/single_device.py
  4. +10
    -9
      fastNLP/core/drivers/paddle_driver/utils.py

+ 8
- 10
fastNLP/core/drivers/paddle_driver/fleet.py View File

@@ -104,8 +104,8 @@ class PaddleFleetDriver(PaddleDriver):
# 我们就直接将 model_device 置为 None;
self._model_device = None

def _running_fn_(batch, step_fn, signature_fn):
if isinstance(batch, Dict):
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return self._validate_step(batch)
@@ -116,23 +116,21 @@ class PaddleFleetDriver(PaddleDriver):
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
# self._train_signature_fn = model.forward
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "validate_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
# self._validate_signature_fn = model.forward
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

if hasattr(model, "test_step"):
logger.warning(
"Notice your model is a `paddle.DataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)

# 当参数 `device` 为 None 时并且该参数不为 None,表示将对应的数据移到指定的机器上;
self._data_device = kwargs.get("data_device", None)
@@ -277,9 +275,9 @@ class PaddleFleetDriver(PaddleDriver):
**self._fleet_kwargs
)

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN})
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE})
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST})
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True



+ 5
- 2
fastNLP/core/drivers/paddle_driver/paddle_driver.py View File

@@ -19,7 +19,7 @@ from fastNLP.envs import (
rank_zero_call,
)
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler

if _NEED_IMPORT_PADDLE:
import paddle
@@ -56,6 +56,9 @@ class PaddleDriver(Driver):
self.auto_cast, _grad_scaler = _build_fp16_env(dummy=not fp16)
self.grad_scaler = _grad_scaler()

# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def zero_grad(self, set_to_none: bool = False):
r"""
实现深度学习中的梯度的置零操作,应当直接通过优化器 optimizers 来将梯度置零;
@@ -301,7 +304,7 @@ class PaddleDriver(Driver):
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
sampler = ReproducibleBatchSampler(
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last


+ 10
- 5
fastNLP/core/drivers/paddle_driver/single_device.py View File

@@ -11,7 +11,12 @@ from fastNLP.core.utils import (
get_paddle_device_id,
paddle_move_data_to_device,
)
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.samplers import (
ReproducibleBatchSampler,
RandomBatchSampler,
ReproducibleSampler,
re_instantiate_sampler,
)
from fastNLP.core.log import logger

if _NEED_IMPORT_PADDLE:
@@ -102,7 +107,7 @@ class PaddleSingleDriver(PaddleDriver):

def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
@@ -116,13 +121,13 @@ class PaddleSingleDriver(PaddleDriver):
self.grad_scaler.update()

def validate_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
@@ -159,7 +164,7 @@ class PaddleSingleDriver(PaddleDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
batch_sampler = ReproducibleBatchSampler(
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 10
- 9
fastNLP/core/drivers/paddle_driver/utils.py View File

@@ -85,7 +85,7 @@ class ForwardState(IntEnum):
TEST = 2
PREDICT = 3

_MODE_PARAMETER = "_forward_state"
_MODE_PARAMETER = "forward_state"

class _FleetWrappingModel(Layer):
"""
@@ -151,24 +151,25 @@ class _FleetWrappingModel(Layer):

def forward(self, batch, **kwargs) -> Dict:

_forward_state = kwargs.pop(_MODE_PARAMETER)
forward_state = kwargs.pop(_MODE_PARAMETER)
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if _forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict):
if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif _forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict):
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif _forward_state == ForwardState.TEST:
if isinstance(batch, Dict):
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif _forward_state == ForwardState.PREDICT:
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
else:
raise NotImplementedError("You should direct a concrete mode.")


Loading…
Cancel
Save