@@ -41,6 +41,7 @@ class Evaluator: | |||
mode: str = "validate", | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
fp16: Optional[bool] = False, | |||
verbose: int = 1, | |||
**kwargs | |||
@@ -61,6 +62,9 @@ class Evaluator: | |||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | |||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | |||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||
:param fp16: 是否使用 fp16 。 | |||
:param verbose: 是否打印 evaluate 的结果。 | |||
:param kwargs: | |||
@@ -83,7 +87,7 @@ class Evaluator: | |||
self.model = model | |||
self.metrics = metrics | |||
self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs) | |||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | |||
self.device = device | |||
self.verbose = verbose | |||
@@ -47,6 +47,7 @@ class Trainer(TrainerEventTrigger): | |||
validate_every: Optional[Union[int, callable]] = -1, | |||
input_mapping: Optional[Union[Callable, Dict]] = None, | |||
output_mapping: Optional[Union[Callable, Dict]] = None, | |||
model_wo_auto_param_call: bool = False, | |||
accumulation_steps: int = 1, | |||
fp16: bool = False, | |||
marker: Optional[str] = None, | |||
@@ -99,7 +100,10 @@ class Trainer(TrainerEventTrigger): | |||
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | |||
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | |||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | |||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | |||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | |||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||
如果该值为 True,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||
为 False,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | |||
:param fp16: 是否开启混合精度训练;默认为 False; | |||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | |||
@@ -120,9 +124,7 @@ class Trainer(TrainerEventTrigger): | |||
""" | |||
# TODO 是不是可以加一个参数让用户现在关掉参数匹配。 | |||
self.marker = marker | |||
self.model = model | |||
self.driver_name = driver | |||
self.device = device | |||
self.fp16 = fp16 | |||
@@ -164,6 +166,7 @@ class Trainer(TrainerEventTrigger): | |||
validate_every=validate_every, | |||
input_mapping=input_mapping, | |||
output_mapping=output_mapping, | |||
model_wo_auto_param_call=model_wo_auto_param_call, | |||
accumulation_steps=accumulation_steps, | |||
fp16=fp16, | |||
marker=marker, | |||
@@ -484,8 +487,6 @@ class Trainer(TrainerEventTrigger): | |||
@driver.setter | |||
def driver(self, driver: Driver): | |||
driver.trainer = self | |||
driver.model = self.model | |||
self._driver = driver | |||
@property | |||
@@ -782,4 +783,21 @@ class Trainer(TrainerEventTrigger): | |||
def total_batches(self, total_batches: int): | |||
self.trainer_state.total_batches = total_batches | |||
""" driver property """ | |||
@property | |||
def model_device(self): | |||
return self.driver.model_device | |||
@property | |||
def data_device(self): | |||
return self.driver.data_device | |||
@property | |||
def model(self): | |||
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; | |||
return self.driver.model | |||
@@ -167,6 +167,7 @@ class TorchDDPDriver(TorchDriver): | |||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | |||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | |||
""" | |||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | |||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | |||
if isinstance(model, torch.nn.DataParallel): | |||
@@ -202,8 +203,8 @@ class TorchDDPDriver(TorchDriver): | |||
# 我们就直接将 model_device 置为 None; | |||
self.model_device = None | |||
def _running_fn_(batch, step_fn, signature_fn): | |||
if isinstance(batch, Dict): | |||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | |||
else: | |||
return step_fn(batch) | |||
@@ -214,7 +215,7 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `train_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `train_step` and you should note that.") | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._train_signature_fn = model.forward | |||
if hasattr(model, "validate_step"): | |||
@@ -222,7 +223,7 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `validate_step` method, which we can not call actually, " | |||
"we will call `forward` function instead of `validate_step` and you should note that.") | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._validate_signature_fn = model.forward | |||
if hasattr(model, "test_step"): | |||
@@ -230,14 +231,11 @@ class TorchDDPDriver(TorchDriver): | |||
"Notice your model is a `DistributedDataParallel` model. And your " | |||
"model also implements the `test_step` method, which we can not call actually, we will" | |||
" call `forward` function instead of `test_step` and you should note that.") | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||
# self._test_signature_fn = model.forward | |||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | |||
self._data_device = kwargs.get("data_device", None) | |||
# if self.outside_ddp and self._data_device is None: | |||
# raise RuntimeError("When you initialize your ddp out of our control, the parameter " | |||
# "`data_device` can not be None.") | |||
if isinstance(self._data_device, int): | |||
if self._data_device < 0: | |||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | |||
@@ -349,9 +347,9 @@ class TorchDDPDriver(TorchDriver): | |||
**self._ddp_kwargs | |||
) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||
self._configured = True | |||
@@ -13,7 +13,7 @@ __all__ = [ | |||
from .torch_driver import TorchDriver | |||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | |||
from fastNLP.core.utils import auto_param_call | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||
from fastNLP.core.log import logger | |||
@@ -102,7 +102,7 @@ class TorchSingleDriver(TorchDriver): | |||
def train_step(self, batch) -> Dict: | |||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
@@ -118,13 +118,13 @@ class TorchSingleDriver(TorchDriver): | |||
def validate_step(self, batch) -> Dict: | |||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | |||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
def test_step(self, batch) -> Dict: | |||
if isinstance(batch, Dict): | |||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): | |||
return replace_sampler(dataloader, sampler) | |||
if reproducible: | |||
batch_sampler = ReproducibleBatchSampler( | |||
batch_sampler = RandomBatchSampler( | |||
batch_sampler=args.batch_sampler, | |||
batch_size=args.batch_size, | |||
drop_last=args.drop_last | |||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||
from fastNLP.envs import rank_zero_call | |||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | |||
from fastNLP.core.log import logger | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||
class TorchDriver(Driver): | |||
@@ -51,6 +51,9 @@ class TorchDriver(Driver): | |||
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | |||
self.non_blocking = kwargs.get("torch_non_blocking", True) | |||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||
def zero_grad(self, set_to_none: bool = False): | |||
for optimizer in self.optimizers: | |||
self._clear_grad(optimizer, set_to_none) | |||
@@ -252,7 +255,7 @@ class TorchDriver(Driver): | |||
elif self.is_distributed(): | |||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | |||
else: | |||
sampler = ReproducibleBatchSampler( | |||
sampler = RandomBatchSampler( | |||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | |||
batch_size=dataloader_args.batch_size, | |||
drop_last=dataloader_args.drop_last | |||
@@ -140,24 +140,25 @@ class _DDPWrappingModel(Module): | |||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | |||
""" | |||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||
if _forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict): | |||
if forward_state == ForwardState.TRAIN: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | |||
else: | |||
return self._train_step(batch) | |||
elif _forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.VALIDATE: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | |||
else: | |||
return self._validate_step(batch) | |||
elif _forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict): | |||
elif forward_state == ForwardState.TEST: | |||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | |||
else: | |||
return self._test_step(batch) | |||
elif _forward_state == ForwardState.PREDICT: | |||
elif forward_state == ForwardState.PREDICT: | |||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | |||
else: | |||
raise NotImplementedError("You should direct a concrete mode.") | |||
@@ -96,6 +96,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | |||
参数值后,再传给 `fn` 进行实际的运算; | |||
:param mapping: 一个字典,用来更改其前面的字典的键值; | |||
:param wo_auto_param_call: 是否关闭默认的参数匹配行为; | |||
:return: 返回 `fn` 运行的结果; | |||
@@ -113,6 +114,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | |||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | |||
""" | |||
if signature_fn is not None: | |||
if not callable(signature_fn): | |||
raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | |||
@@ -10,7 +10,7 @@ import re | |||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||
from tests.helpers.utils import magic_argv_env_context | |||
from fastNLP.core import synchronize_safe_rm | |||
@@ -10,7 +10,7 @@ from typing import Any | |||
from pathlib import Path | |||
from fastNLP.core.controllers.trainer import Trainer | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 | |||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | |||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | |||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | |||
@@ -70,7 +70,7 @@ def model_and_optimizers(request): | |||
trainer_params.output_mapping = None | |||
# elif request.param == 1: | |||
# model = | |||
return trainer_params | |||
@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): | |||
subprocess.check_call(command) | |||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||
@magic_argv_env_context | |||
def test_torch_wo_auto_param_call( | |||
driver, | |||
device, | |||
n_epochs=10, | |||
): | |||
model = TorchNormalModel_Classification_3( | |||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||
) | |||
optimizers = SGD(model.parameters(), lr=0.001) | |||
dataset = TorchNormalDataset_Classification( | |||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||
seed=NormalClassificationTrainTorchConfig.seed | |||
) | |||
train_dataloader = DataLoader( | |||
dataset=dataset, | |||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||
shuffle=True | |||
) | |||
trainer = Trainer( | |||
model=model, | |||
driver=driver, | |||
device=device, | |||
optimizers=optimizers, | |||
train_dataloader=train_dataloader, | |||
n_epochs=n_epochs, | |||
model_wo_auto_param_call=True, | |||
output_from_new_proc="all" | |||
) | |||
trainer.run() | |||
if dist.is_initialized(): | |||
dist.destroy_process_group() | |||
@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||
x = torch.max(x, dim=-1)[1] | |||
return {"preds": x, "target": y} | |||
class TorchNormalModel_Classification_2(nn.Module): | |||
""" | |||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | |||
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module): | |||
return {"loss": loss, "preds": x, "target": y} | |||
class TorchNormalModel_Classification_3(nn.Module): | |||
""" | |||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | |||
关闭 auto_param_call,forward 只有一个 batch 参数; | |||
""" | |||
def __init__(self, num_labels, feature_dimension): | |||
super(TorchNormalModel_Classification_3, self).__init__() | |||
self.num_labels = num_labels | |||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||
self.ac1 = nn.ReLU() | |||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||
self.ac2 = nn.ReLU() | |||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||
self.loss_fn = nn.CrossEntropyLoss() | |||
def forward(self, batch): | |||
x = batch["x"] | |||
y = batch["y"] | |||
x = self.ac1(self.linear1(x)) | |||
x = self.ac2(self.linear2(x)) | |||
x = self.output(x) | |||
loss = self.loss_fn(x, y) | |||
x = torch.max(x, dim=-1)[1] | |||
return {"loss": loss, "preds": x, "target": y} | |||