@@ -41,6 +41,7 @@ class Evaluator: | |||||
mode: str = "validate", | mode: str = "validate", | ||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | |||||
fp16: Optional[bool] = False, | fp16: Optional[bool] = False, | ||||
verbose: int = 1, | verbose: int = 1, | ||||
**kwargs | **kwargs | ||||
@@ -61,6 +62,9 @@ class Evaluator: | |||||
没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | 没有的话尝试 "validate_step" 函数,都没找到则使用 model 的前向运算函数。 | ||||
:param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | :param input_mapping: 对 dataloader 中输出的内容将通过 input_mapping 处理之后再输入到 model 以及 metric 中 | ||||
:param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | :param output_mapping: 对 model 输出的内容,将通过 output_mapping 处理之后再输入到 metric 中。 | ||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
:param fp16: 是否使用 fp16 。 | :param fp16: 是否使用 fp16 。 | ||||
:param verbose: 是否打印 evaluate 的结果。 | :param verbose: 是否打印 evaluate 的结果。 | ||||
:param kwargs: | :param kwargs: | ||||
@@ -83,7 +87,7 @@ class Evaluator: | |||||
self.model = model | self.model = model | ||||
self.metrics = metrics | self.metrics = metrics | ||||
self.driver = choose_driver(model, driver, device, fp16=fp16, **kwargs) | |||||
self.driver = choose_driver(model, driver, device, fp16=fp16, model_wo_auto_param_call=model_wo_auto_param_call, **kwargs) | |||||
self.device = device | self.device = device | ||||
self.verbose = verbose | self.verbose = verbose | ||||
@@ -47,6 +47,7 @@ class Trainer(TrainerEventTrigger): | |||||
validate_every: Optional[Union[int, callable]] = -1, | validate_every: Optional[Union[int, callable]] = -1, | ||||
input_mapping: Optional[Union[Callable, Dict]] = None, | input_mapping: Optional[Union[Callable, Dict]] = None, | ||||
output_mapping: Optional[Union[Callable, Dict]] = None, | output_mapping: Optional[Union[Callable, Dict]] = None, | ||||
model_wo_auto_param_call: bool = False, | |||||
accumulation_steps: int = 1, | accumulation_steps: int = 1, | ||||
fp16: bool = False, | fp16: bool = False, | ||||
marker: Optional[str] = None, | marker: Optional[str] = None, | ||||
@@ -99,7 +100,10 @@ class Trainer(TrainerEventTrigger): | |||||
:param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | :param output_mapping: 应当为一个字典或者函数。作用和 input_mapping 类似,区别在于其用于转换输出;如果 output_mapping 是一个 | ||||
函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | 函数,那么我们将会直接将模型的输出传给该函数;如果其是一个 `Dict`,那么我们需要 batch 必须是 `Dict` 或者 `dataclass` 类型, | ||||
如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | 如果 batch 是一个 `Dict`,那么我们会把 batch 中同样在 output_mapping 中的 key 修改为 output_mapping 的对应 key 的 value; | ||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换 | |||||
如果 batch 是一个 `dataclass`,那么我们会先将该 dataclass 转换为一个 Dict,然后再进行上述转换; | |||||
:param model_wo_auto_param_call: 是否关闭在训练时调用我们的 auto_param_call 来自动匹配 batch 和 forward 函数的参数的行为; | |||||
如果该值为 False,并且当 batch 为字典时,我们会根据 forward 所需要的参数从 batch 中提取对应的对象,传入到 forward 函数中;如果该值 | |||||
为 True,那么我们会将 batch 直接透传给 forward 函数。注意上述逻辑同样应用于 `train_step`, `validate_step` 和 `test_step`; | |||||
:param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | :param accumulation_steps: 梯度累积的步数,表示每隔几个 batch 优化器迭代一次;默认为 1; | ||||
:param fp16: 是否开启混合精度训练;默认为 False; | :param fp16: 是否开启混合精度训练;默认为 False; | ||||
:param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | :param marker: 用于标记一个 Trainer 实例,从而在用户调用 `Trainer.on` 函数时,标记该 callback 函数属于哪一个具体的 'trainer' 实例;默认为 None; | ||||
@@ -120,9 +124,7 @@ class Trainer(TrainerEventTrigger): | |||||
""" | """ | ||||
# TODO 是不是可以加一个参数让用户现在关掉参数匹配。 | |||||
self.marker = marker | self.marker = marker | ||||
self.model = model | |||||
self.driver_name = driver | self.driver_name = driver | ||||
self.device = device | self.device = device | ||||
self.fp16 = fp16 | self.fp16 = fp16 | ||||
@@ -164,6 +166,7 @@ class Trainer(TrainerEventTrigger): | |||||
validate_every=validate_every, | validate_every=validate_every, | ||||
input_mapping=input_mapping, | input_mapping=input_mapping, | ||||
output_mapping=output_mapping, | output_mapping=output_mapping, | ||||
model_wo_auto_param_call=model_wo_auto_param_call, | |||||
accumulation_steps=accumulation_steps, | accumulation_steps=accumulation_steps, | ||||
fp16=fp16, | fp16=fp16, | ||||
marker=marker, | marker=marker, | ||||
@@ -484,8 +487,6 @@ class Trainer(TrainerEventTrigger): | |||||
@driver.setter | @driver.setter | ||||
def driver(self, driver: Driver): | def driver(self, driver: Driver): | ||||
driver.trainer = self | |||||
driver.model = self.model | |||||
self._driver = driver | self._driver = driver | ||||
@property | @property | ||||
@@ -782,4 +783,21 @@ class Trainer(TrainerEventTrigger): | |||||
def total_batches(self, total_batches: int): | def total_batches(self, total_batches: int): | ||||
self.trainer_state.total_batches = total_batches | self.trainer_state.total_batches = total_batches | ||||
""" driver property """ | |||||
@property | |||||
def model_device(self): | |||||
return self.driver.model_device | |||||
@property | |||||
def data_device(self): | |||||
return self.driver.data_device | |||||
@property | |||||
def model(self): | |||||
# 返回 driver 中的 model,注意该 model 可能被分布式的模型包裹,例如 `DistributedDataParallel`; | |||||
return self.driver.model | |||||
@@ -167,6 +167,7 @@ class TorchDDPDriver(TorchDriver): | |||||
不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | 不管是什么情况,`TorchDDPDriver` 在 `setup` 函数的最后,都会将所有进程的 pid 主动记录下来,这样当一个进程出现 exception 后, | ||||
driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | driver 的 on_exception 函数就会被 trainer 调用,其会调用 os.kill 指令将其它进程 kill 掉; | ||||
""" | """ | ||||
# 在加入很多东西后,需要注意这里调用 super 函数的位置; | |||||
super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | super(TorchDDPDriver, self).__init__(model, fp16=fp16, **kwargs) | ||||
if isinstance(model, torch.nn.DataParallel): | if isinstance(model, torch.nn.DataParallel): | ||||
@@ -202,8 +203,8 @@ class TorchDDPDriver(TorchDriver): | |||||
# 我们就直接将 model_device 置为 None; | # 我们就直接将 model_device 置为 None; | ||||
self.model_device = None | self.model_device = None | ||||
def _running_fn_(batch, step_fn, signature_fn): | |||||
if isinstance(batch, Dict): | |||||
def _running_fn_(batch, step_fn, signature_fn, wo_auto_param_call): | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(step_fn, batch, signature_fn=signature_fn) | return auto_param_call(step_fn, batch, signature_fn=signature_fn) | ||||
else: | else: | ||||
return step_fn(batch) | return step_fn(batch) | ||||
@@ -214,7 +215,7 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `train_step` method, which we can not call actually, we will" | "model also implements the `train_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `train_step` and you should note that.") | " call `forward` function instead of `train_step` and you should note that.") | ||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._train_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._train_signature_fn = model.forward | # self._train_signature_fn = model.forward | ||||
if hasattr(model, "validate_step"): | if hasattr(model, "validate_step"): | ||||
@@ -222,7 +223,7 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `validate_step` method, which we can not call actually, " | "model also implements the `validate_step` method, which we can not call actually, " | ||||
"we will call `forward` function instead of `validate_step` and you should note that.") | "we will call `forward` function instead of `validate_step` and you should note that.") | ||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._validate_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._validate_signature_fn = model.forward | # self._validate_signature_fn = model.forward | ||||
if hasattr(model, "test_step"): | if hasattr(model, "test_step"): | ||||
@@ -230,14 +231,11 @@ class TorchDDPDriver(TorchDriver): | |||||
"Notice your model is a `DistributedDataParallel` model. And your " | "Notice your model is a `DistributedDataParallel` model. And your " | ||||
"model also implements the `test_step` method, which we can not call actually, we will" | "model also implements the `test_step` method, which we can not call actually, we will" | ||||
" call `forward` function instead of `test_step` and you should note that.") | " call `forward` function instead of `test_step` and you should note that.") | ||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward) | |||||
self._test_step = partial(_running_fn_, step_fn=self.model, signature_fn=model.forward, wo_auto_param_call=self.wo_auto_param_call) | |||||
# self._test_signature_fn = model.forward | # self._test_signature_fn = model.forward | ||||
# 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | # 当用户自己在外面初始化 DDP 时我们会将 model_device 置为 None,这是用户可以通过 `data_device` 将对应的数据移到指定的机器上; | ||||
self._data_device = kwargs.get("data_device", None) | self._data_device = kwargs.get("data_device", None) | ||||
# if self.outside_ddp and self._data_device is None: | |||||
# raise RuntimeError("When you initialize your ddp out of our control, the parameter " | |||||
# "`data_device` can not be None.") | |||||
if isinstance(self._data_device, int): | if isinstance(self._data_device, int): | ||||
if self._data_device < 0: | if self._data_device < 0: | ||||
raise ValueError("Parameter `data_device` can not be smaller than 0.") | raise ValueError("Parameter `data_device` can not be smaller than 0.") | ||||
@@ -349,9 +347,9 @@ class TorchDDPDriver(TorchDriver): | |||||
**self._ddp_kwargs | **self._ddp_kwargs | ||||
) | ) | ||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}) | |||||
self._train_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TRAIN}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._validate_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.VALIDATE}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._test_step = partial(self.model, **{_MODE_PARAMETER: ForwardState.TEST}, wo_auto_param_call=self.wo_auto_param_call) | |||||
self._configured = True | self._configured = True | ||||
@@ -13,7 +13,7 @@ __all__ = [ | |||||
from .torch_driver import TorchDriver | from .torch_driver import TorchDriver | ||||
from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | from fastNLP.core.drivers.torch_driver.utils import replace_sampler, replace_batch_sampler | ||||
from fastNLP.core.utils import auto_param_call | from fastNLP.core.utils import auto_param_call | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, re_instantiate_sampler, RandomBatchSampler | |||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
@@ -102,7 +102,7 @@ class TorchSingleDriver(TorchDriver): | |||||
def train_step(self, batch) -> Dict: | def train_step(self, batch) -> Dict: | ||||
# 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | # 如果 batch 是一个 Dict,我们就默认帮其做参数匹配,否则就直接传入到 `train_step` 函数中,让用户自己处理; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
@@ -118,13 +118,13 @@ class TorchSingleDriver(TorchDriver): | |||||
def validate_step(self, batch) -> Dict: | def validate_step(self, batch) -> Dict: | ||||
# 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | # 因为我们 Tester 的逻辑就是将所有的 metric 传给 tester,然后 tester 控制具体 metric 的 update 和 compute;因此不管用户是否 | ||||
# 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | # 实现 validate_step 函数,其都应该返回一个字典,具体使用哪些东西则是在 validate_batch_loop 中每一个具体的 metric 自己去拿的; | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
def test_step(self, batch) -> Dict: | def test_step(self, batch) -> Dict: | ||||
if isinstance(batch, Dict): | |||||
if isinstance(batch, Dict) and not self.wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
@@ -148,7 +148,7 @@ class TorchSingleDriver(TorchDriver): | |||||
return replace_sampler(dataloader, sampler) | return replace_sampler(dataloader, sampler) | ||||
if reproducible: | if reproducible: | ||||
batch_sampler = ReproducibleBatchSampler( | |||||
batch_sampler = RandomBatchSampler( | |||||
batch_sampler=args.batch_sampler, | batch_sampler=args.batch_sampler, | ||||
batch_size=args.batch_size, | batch_size=args.batch_size, | ||||
drop_last=args.drop_last | drop_last=args.drop_last | ||||
@@ -30,7 +30,7 @@ from fastNLP.core.utils import apply_to_collection, torch_move_data_to_device | |||||
from fastNLP.envs import rank_zero_call | from fastNLP.envs import rank_zero_call | ||||
from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | from fastNLP.envs import FASTNLP_SEED_WORKERS, FASTNLP_GLOBAL_RANK, FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME | ||||
from fastNLP.core.log import logger | from fastNLP.core.log import logger | ||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler | |||||
from fastNLP.core.samplers import ReproducibleBatchSampler, ReproducibleSampler, RandomBatchSampler | |||||
class TorchDriver(Driver): | class TorchDriver(Driver): | ||||
@@ -51,6 +51,9 @@ class TorchDriver(Driver): | |||||
# 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | # 用来设置 `torch_move_data_to_device` 中的 `non_blocking` 参数; | ||||
self.non_blocking = kwargs.get("torch_non_blocking", True) | self.non_blocking = kwargs.get("torch_non_blocking", True) | ||||
# 用来设置是否关闭 auto_param_call 中的参数匹配问题; | |||||
self.wo_auto_param_call = kwargs.get("model_wo_auto_param_call", False) | |||||
def zero_grad(self, set_to_none: bool = False): | def zero_grad(self, set_to_none: bool = False): | ||||
for optimizer in self.optimizers: | for optimizer in self.optimizers: | ||||
self._clear_grad(optimizer, set_to_none) | self._clear_grad(optimizer, set_to_none) | ||||
@@ -252,7 +255,7 @@ class TorchDriver(Driver): | |||||
elif self.is_distributed(): | elif self.is_distributed(): | ||||
raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | raise RuntimeError("It is not allowed to use checkpoint retraining when you do not use our or `ReproducibleSampler`.") | ||||
else: | else: | ||||
sampler = ReproducibleBatchSampler( | |||||
sampler = RandomBatchSampler( | |||||
batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | batch_sampler=dataloader_args.batch_sampler if dataloader_args.batch_sampler is not None else dataloader_args.sampler, | ||||
batch_size=dataloader_args.batch_size, | batch_size=dataloader_args.batch_size, | ||||
drop_last=dataloader_args.drop_last | drop_last=dataloader_args.drop_last | ||||
@@ -140,24 +140,25 @@ class _DDPWrappingModel(Module): | |||||
pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | pytorch lightning 实现了先 unwrapping_model 的操作,但是感觉对于我们来说没有什么必须要,先写个注释放这里,之后有需求了再看; | ||||
""" | """ | ||||
_forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
forward_state = kwargs.pop(_MODE_PARAMETER) | |||||
wo_auto_param_call = kwargs.pop("wo_auto_param_call") | |||||
if _forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict): | |||||
if forward_state == ForwardState.TRAIN: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | return auto_param_call(self._train_step, batch, signature_fn=self._train_signature_fn) | ||||
else: | else: | ||||
return self._train_step(batch) | return self._train_step(batch) | ||||
elif _forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.VALIDATE: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | return auto_param_call(self._validate_step, batch, signature_fn=self._validate_signature_fn) | ||||
else: | else: | ||||
return self._validate_step(batch) | return self._validate_step(batch) | ||||
elif _forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict): | |||||
elif forward_state == ForwardState.TEST: | |||||
if isinstance(batch, Dict) and not wo_auto_param_call: | |||||
return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | return auto_param_call(self._test_step, batch, signature_fn=self._test_signature_fn) | ||||
else: | else: | ||||
return self._test_step(batch) | return self._test_step(batch) | ||||
elif _forward_state == ForwardState.PREDICT: | |||||
elif forward_state == ForwardState.PREDICT: | |||||
raise NotImplementedError("'PREDICT' mode has not been implemented.") | raise NotImplementedError("'PREDICT' mode has not been implemented.") | ||||
else: | else: | ||||
raise NotImplementedError("You should direct a concrete mode.") | raise NotImplementedError("You should direct a concrete mode.") | ||||
@@ -96,6 +96,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
:param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | :param signature_fn: 函数,用来替换 `fn` 的函数签名,如果该参数不为 None,那么我们首先会从该函数中提取函数签名,然后通过该函数签名提取 | ||||
参数值后,再传给 `fn` 进行实际的运算; | 参数值后,再传给 `fn` 进行实际的运算; | ||||
:param mapping: 一个字典,用来更改其前面的字典的键值; | :param mapping: 一个字典,用来更改其前面的字典的键值; | ||||
:param wo_auto_param_call: 是否关闭默认的参数匹配行为; | |||||
:return: 返回 `fn` 运行的结果; | :return: 返回 `fn` 运行的结果; | ||||
@@ -113,6 +114,7 @@ def auto_param_call(fn: Callable, *args, signature_fn: Optional[Callable] = None | |||||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20})) # res: 140 | ||||
>>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | >>> print(auto_param_call(partial(test_fn, a=100), {"x": 10}, {"y": 20, "a": 200})) # res: 240 | ||||
""" | """ | ||||
if signature_fn is not None: | if signature_fn is not None: | ||||
if not callable(signature_fn): | if not callable(signature_fn): | ||||
raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | raise ValueError(f"Parameter `signature_fn` should be `Callable`.") | ||||
@@ -10,7 +10,7 @@ import re | |||||
from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | from fastNLP.core.callbacks.checkpoint_callback import ModelCheckpointCallback, TrainerCheckpointCallback | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from fastNLP.envs import FASTNLP_MODEL_FILENAME, FASTNLP_CHECKPOINT_FILENAME, FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from fastNLP.envs import FASTNLP_LAUNCH_TIME, FASTNLP_DISTRIBUTED_CHECK | |||||
from tests.helpers.utils import magic_argv_env_context | from tests.helpers.utils import magic_argv_env_context | ||||
from fastNLP.core import synchronize_safe_rm | from fastNLP.core import synchronize_safe_rm | ||||
@@ -10,7 +10,7 @@ from typing import Any | |||||
from pathlib import Path | from pathlib import Path | ||||
from fastNLP.core.controllers.trainer import Trainer | from fastNLP.core.controllers.trainer import Trainer | ||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1 | |||||
from tests.helpers.models.torch_model import TorchNormalModel_Classification_1, TorchNormalModel_Classification_3 | |||||
from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | from tests.helpers.datasets.torch_data import TorchNormalDataset_Classification | ||||
from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | from tests.helpers.callbacks.helper_callbacks import RecordLossCallback | ||||
from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | from tests.helpers.callbacks.helper_callbacks_torch import RecordAccumulationStepsCallback_Torch | ||||
@@ -70,7 +70,7 @@ def model_and_optimizers(request): | |||||
trainer_params.output_mapping = None | trainer_params.output_mapping = None | ||||
# elif request.param == 1: | # elif request.param == 1: | ||||
# model = | |||||
return trainer_params | return trainer_params | ||||
@@ -307,10 +307,47 @@ def test_torch_distributed_launch_2(version): | |||||
subprocess.check_call(command) | subprocess.check_call(command) | ||||
@pytest.mark.parametrize("driver,device", [("torch", 0), ("torch_ddp", [0, 1])]) | |||||
@magic_argv_env_context | |||||
def test_torch_wo_auto_param_call( | |||||
driver, | |||||
device, | |||||
n_epochs=10, | |||||
): | |||||
model = TorchNormalModel_Classification_3( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension | |||||
) | |||||
optimizers = SGD(model.parameters(), lr=0.001) | |||||
dataset = TorchNormalDataset_Classification( | |||||
num_labels=NormalClassificationTrainTorchConfig.num_labels, | |||||
feature_dimension=NormalClassificationTrainTorchConfig.feature_dimension, | |||||
each_label_data=NormalClassificationTrainTorchConfig.each_label_data, | |||||
seed=NormalClassificationTrainTorchConfig.seed | |||||
) | |||||
train_dataloader = DataLoader( | |||||
dataset=dataset, | |||||
batch_size=NormalClassificationTrainTorchConfig.batch_size, | |||||
shuffle=True | |||||
) | |||||
trainer = Trainer( | |||||
model=model, | |||||
driver=driver, | |||||
device=device, | |||||
optimizers=optimizers, | |||||
train_dataloader=train_dataloader, | |||||
n_epochs=n_epochs, | |||||
model_wo_auto_param_call=True, | |||||
output_from_new_proc="all" | |||||
) | |||||
trainer.run() | |||||
if dist.is_initialized(): | |||||
dist.destroy_process_group() | |||||
@@ -37,6 +37,7 @@ class TorchNormalModel_Classification_1(nn.Module): | |||||
x = torch.max(x, dim=-1)[1] | x = torch.max(x, dim=-1)[1] | ||||
return {"preds": x, "target": y} | return {"preds": x, "target": y} | ||||
class TorchNormalModel_Classification_2(nn.Module): | class TorchNormalModel_Classification_2(nn.Module): | ||||
""" | """ | ||||
只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | 只实现一个 forward 函数,来测试用户自己在外面初始化 DDP 的场景; | ||||
@@ -61,5 +62,31 @@ class TorchNormalModel_Classification_2(nn.Module): | |||||
return {"loss": loss, "preds": x, "target": y} | return {"loss": loss, "preds": x, "target": y} | ||||
class TorchNormalModel_Classification_3(nn.Module): | |||||
""" | |||||
只实现一个 forward 函数,用来测试关闭 auto_param_call 参数自动匹配的场景; | |||||
关闭 auto_param_call,forward 只有一个 batch 参数; | |||||
""" | |||||
def __init__(self, num_labels, feature_dimension): | |||||
super(TorchNormalModel_Classification_3, self).__init__() | |||||
self.num_labels = num_labels | |||||
self.linear1 = nn.Linear(in_features=feature_dimension, out_features=10) | |||||
self.ac1 = nn.ReLU() | |||||
self.linear2 = nn.Linear(in_features=10, out_features=10) | |||||
self.ac2 = nn.ReLU() | |||||
self.output = nn.Linear(in_features=10, out_features=num_labels) | |||||
self.loss_fn = nn.CrossEntropyLoss() | |||||
def forward(self, batch): | |||||
x = batch["x"] | |||||
y = batch["y"] | |||||
x = self.ac1(self.linear1(x)) | |||||
x = self.ac2(self.linear2(x)) | |||||
x = self.output(x) | |||||
loss = self.loss_fn(x, y) | |||||
x = torch.max(x, dim=-1)[1] | |||||
return {"loss": loss, "preds": x, "target": y} | |||||