Browse Source

添加了关闭参数匹配的逻辑;添加了 trainer 中获取 driver 参数的接口

tags/v1.0.0alpha
YWMditto 2 years ago
parent
commit
5419b6a042
10 changed files with 125 additions and 35 deletions
  1. +5
    -1
      fastNLP/core/controllers/evaluator.py
  2. +23
    -5
      fastNLP/core/controllers/trainer.py
  3. +9
    -11
      fastNLP/core/drivers/torch_driver/ddp.py
  4. +5
    -5
      fastNLP/core/drivers/torch_driver/single_device.py
  5. +5
    -2
      fastNLP/core/drivers/torch_driver/torch_driver.py
  6. +9
    -8
      fastNLP/core/drivers/torch_driver/utils.py
  7. +2
    -0
      fastNLP/core/utils/utils.py
  8. +1
    -1
      tests/core/callbacks/test_checkpoint_callback_torch.py
  9. +39
    -2
      tests/core/controllers/test_trainer_wo_evaluator_torch.py
  10. +27
    -0
      tests/helpers/models/torch_model.py

+ 5
- 1
fastNLP/core/controllers/evaluator.py View File

@@ -41,6 +41,7 @@ class Evaluator:
mode: str = "validate",
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
fp16: Optional[bool] = False,
verbose: int = 1,
**kwargs
@@ -61,6 +62,9 @@ class Evaluator:
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`;
:param fp16: 是否使用 fp16 。
:param verbose: 是否打印 evaluate 的结果。
:param kwargs:
@@ -83,7 +87,7 @@ class Evaluator:
self.model = model
self.metrics = metrics

self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs)
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs)

self.device = device
self.verbose = verbose


+ 23
- 5
fastNLP/core/controllers/trainer.py View File

@@ -47,6 +47,7 @@ class Trainer(TrainerEventTrigger):
validate_every: Optional[Union[int, callable]] = -1,
input_mapping: Optional[Union[Callable, Dict]] = None,
output_mapping: Optional[Union[Callable, Dict]] = None,
model_wo_auto_param_call: bool = False,
accumulation_steps: int = 1,
fp16: bool = False,
marker: Optional[str] = None,
@@ -99,7 +100,10 @@ class Trainer(TrainerEventTrigger):
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型,
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value;
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换;
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为;
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`;
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1;
:param fp16: 是否开启混合精度训练;默认为 False;
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None;
@@ -120,9 +124,7 @@ class Trainer(TrainerEventTrigger):

"""

# TODO 是不是可以加一个参数让用户现在关掉参数匹配。
self.marker = marker
self.model = model
self.driver_name = driver
self.device = device
self.fp16 = fp16
@@ -164,6 +166,7 @@ class Trainer(TrainerEventTrigger):
validate_every=validate_every,
input_mapping=input_mapping,
output_mapping=output_mapping,
model_wo_auto_param_call=model_wo_auto_param_call,
accumulation_steps=accumulation_steps,
fp16=fp16,
marker=marker,
@@ -484,8 +487,6 @@ class Trainer(TrainerEventTrigger):

@driver.setter
def driver(self, driver: Driver):
driver.trainer = self
driver.model = self.model
self._driver = driver

@property
@@ -782,4 +783,21 @@ class Trainer(TrainerEventTrigger):
def total_batches(self, total_batches: int):
self.trainer_state.total_batches = total_batches

""" driver property """

@property
def model_device(self):
return self.driver.model_device

@property
def data_device(self):
return self.driver.data_device

@property
def model(self):
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`;
return self.driver.model





+ 9
- 11
fastNLP/core/drivers/torch_driver/ddp.py View File

@@ -167,6 +167,7 @@ class TorchDDPDriver(TorchDriver):
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后,
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉;
"""
# 在加入很多东西后,需要注意这里调用 super 函数的位置;
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs)

if isinstance(model, torch.nn.DataParallel):
@@ -202,8 +203,8 @@ class TorchDDPDriver(TorchDriver):
# 我们就直接将 model_device 置为 None;
self.model_device = None

def _running_fn_(batch, step_fn, signature_fn):
if isinstance(batch, Dict):
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call):
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(step_fn, batch, signature_fn=signature_fn)
else:
return step_fn(batch)
@@ -214,7 +215,7 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `train_step` method, which we can not call actually, we will"
" call `forward` function instead of `train_step` and you should note that.")
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._train_signature_fn = model.forward

if hasattr(model, "validate_step"):
@@ -222,7 +223,7 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `validate_step` method, which we can not call actually, "
"we will call `forward` function instead of `validate_step` and you should note that.")
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._validate_signature_fn = model.forward

if hasattr(model, "test_step"):
@@ -230,14 +231,11 @@ class TorchDDPDriver(TorchDriver):
"Notice your model is a `DistributedDataParallel` model. And your "
"model also implements the `test_step` method, which we can not call actually, we will"
" call `forward` function instead of `test_step` and you should note that.")
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward)
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call)
# self._test_signature_fn = model.forward

# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上;
self._data_device = kwargs.get("data_device", None)
# if self.outside_ddp and self._data_device is None:
# raise RuntimeError("When you initialize your ddp out of our control, the parameter "
# "`data_device` can not be None.")
if isinstance(self._data_device, int):
if self._data_device < 0:
raise ValueError("Parameter `data_device` can not be smaller than 0.")
@@ -349,9 +347,9 @@ class TorchDDPDriver(TorchDriver):
**self._ddp_kwargs
)

self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN})
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE})
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST})
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call)
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call)
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call)

self._configured = True



+ 5
- 5
fastNLP/core/drivers/torch_driver/single_device.py View File

@@ -13,7 +13,7 @@ __all__ = [
from .torch_driver import TorchDriver
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler
from fastNLP.core.utils import auto_param_call
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler
from fastNLP.core.log import logger


@@ -102,7 +102,7 @@ class TorchSingleDriver(TorchDriver):

def train_step(self, batch) -> Dict:
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
@@ -118,13 +118,13 @@ class TorchSingleDriver(TorchDriver):
def validate_step(self, batch) -> Dict:
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的;
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)

def test_step(self, batch) -> Dict:
if isinstance(batch, Dict):
if isinstance(batch, Dict) and not self.wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver):
return replace_sampler(dataloader, sampler)

if reproducible:
batch_sampler = ReproducibleBatchSampler(
batch_sampler = RandomBatchSampler(
batch_sampler=args.batch_sampler,
batch_size=args.batch_size,
drop_last=args.drop_last


+ 5
- 2
fastNLP/core/drivers/torch_driver/torch_driver.py View File

@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device
from fastNLP.envs import rank_zero_call
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME
from fastNLP.core.log import logger
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler


class TorchDriver(Driver):
@@ -51,6 +51,9 @@ class TorchDriver(Driver):
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数;
self.non_blocking = kwargs.get("torch_non_blocking", True)

# 用来设置是否关闭 auto_param_call 中的参数匹配问题;
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False)

def zero_grad(self, set_to_none: bool = False):
for optimizer in self.optimizers:
self._clear_grad(optimizer, set_to_none)
@@ -252,7 +255,7 @@ class TorchDriver(Driver):
elif self.is_distributed():
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.")
else:
sampler = ReproducibleBatchSampler(
sampler = RandomBatchSampler(
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler,
batch_size=dataloader_args.batch_size,
drop_last=dataloader_args.drop_last


+ 9
- 8
fastNLP/core/drivers/torch_driver/utils.py View File

@@ -140,24 +140,25 @@ class _DDPWrappingModel(Module):
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看;
"""

_forward_state = kwargs.pop(_MODE_PARAMETER)
forward_state = kwargs.pop(_MODE_PARAMETER)
wo_auto_param_call = kwargs.pop("wo_auto_param_call")

if _forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict):
if forward_state == ForwardState.TRAIN:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn)
else:
return self._train_step(batch)
elif _forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict):
elif forward_state == ForwardState.VALIDATE:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn)
else:
return self._validate_step(batch)
elif _forward_state == ForwardState.TEST:
if isinstance(batch, Dict):
elif forward_state == ForwardState.TEST:
if isinstance(batch, Dict) and not wo_auto_param_call:
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn)
else:
return self._test_step(batch)
elif _forward_state == ForwardState.PREDICT:
elif forward_state == ForwardState.PREDICT:
raise NotImplementedError("'PREDICT' mode has not been implemented.")
else:
raise NotImplementedError("You should direct a concrete mode.")


+ 2
- 0
fastNLP/core/utils/utils.py View File

@@ -96,6 +96,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取
参数值后,再传给 `fn` 进行实际的运算;
:param mapping: 一个字典,用来更改其前面的字典的键值;
:param wo_auto_param_call: 是否关闭默认的参数匹配行为;

:return: 返回 `fn` 运行的结果;

@@ -113,6 +114,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240
"""

if signature_fn is not None:
if not callable(signature_fn):
raise ValueError(f"Parameter `signature_fn` should be `Callable`.")


+ 1
- 1
tests/core/callbacks/test_checkpoint_callback_torch.py View File

@@ -10,7 +10,7 @@ import re

from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback
from fastNLP.core.controllers.trainer import Trainer
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK

from tests.helpers.utils import magic_argv_env_context
from fastNLP.core import synchronize_safe_rm


+ 39
- 2
tests/core/controllers/test_trainer_wo_evaluator_torch.py View File

@@ -10,7 +10,7 @@ from typing import Any
from pathlib import Path

from fastNLP.core.controllers.trainer import Trainer
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch
@@ -70,7 +70,7 @@ def model_and_optimizers(request):
trainer_params.output_mapping = None

# elif request.param == 1:
# model =

return trainer_params

@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version):
subprocess.check_call(command)


@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])])
@magic_argv_env_context
def test_torch_wo_auto_param_call(
    driver,
    device,
    n_epochs=10,
):
    """Train with ``model_wo_auto_param_call=True`` on single-device and DDP drivers."""

    # This model's ``forward`` takes the whole batch dict as its only argument,
    # which is exactly the calling convention used when auto_param_call is off.
    model = TorchNormalModel_Classification_3(
        num_labels=NormalClassificationTrainTorchConfig.num_labels,
        feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension,
    )
    dataset = TorchNormalDataset_Classification(
        num_labels=NormalClassificationTrainTorchConfig.num_labels,
        feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension,
        each_label_data=NormalClassificationTrainTorchConfig.each_label_data,
        seed=NormalClassificationTrainTorchConfig.seed,
    )
    train_dataloader = DataLoader(
        dataset=dataset,
        batch_size=NormalClassificationTrainTorchConfig.batch_size,
        shuffle=True,
    )
    optimizers = SGD(model.parameters(), lr=0.001)

    trainer = Trainer(
        model=model,
        driver=driver,
        device=device,
        optimizers=optimizers,
        train_dataloader=train_dataloader,
        n_epochs=n_epochs,
        # Disable automatic parameter matching: the driver forwards the raw
        # batch dict straight into the model's ``forward``.
        model_wo_auto_param_call=True,
        output_from_new_proc="all",
    )

    trainer.run()

    # The DDP case initializes a process group; tear it down so subsequent
    # tests in this session start from a clean distributed state.
    if dist.is_initialized():
        dist.destroy_process_group()





+ 27
- 0
tests/helpers/models/torch_model.py View File

@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module):
x = torch.max(x, dim=-1)[1]
return {"preds": x, "target": y}


class TorchNormalModel_Classification_2(nn.Module):
"""
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景;
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module):
return {"loss": loss, "preds": x, "target": y}


class TorchNormalModel_Classification_3(nn.Module):
    """
    Classification model whose ``forward`` takes a single ``batch`` argument.

    Intended for testing ``model_wo_auto_param_call=True``: with automatic
    parameter matching disabled, the driver hands the whole batch dict to
    ``forward`` untouched, so the model unpacks ``x`` and ``y`` itself.
    """

    def __init__(self, num_labels, feature_dimension):
        super(TorchNormalModel_Classification_3, self).__init__()
        self.num_labels = num_labels

        # Two ReLU hidden layers of width 10, then the classification head.
        self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10)
        self.ac1 = nn.ReLU()
        self.linear2 = nn.Linear(in_features=10, out_features=10)
        self.ac2 = nn.ReLU()
        self.output = nn.Linear(in_features=10, out_features=num_labels)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, batch):
        # auto_param_call is off, so the raw batch dict arrives here intact.
        features, target = batch["x"], batch["y"]
        hidden = self.ac1(self.linear1(features))
        hidden = self.ac2(self.linear2(hidden))
        logits = self.output(hidden)
        loss = self.loss_fn(logits, target)
        preds = torch.max(logits, dim=-1)[1]
        return {"loss": loss, "preds": preds, "target": target}




Loading…
Cancel
Save